Available Wrappers
Sometimes when running RL experiments you need a specific environment to behave slightly differently. Maybe you want its observations in a different shape, or its rewards automatically bounded to a fixed range.
Wrappers are a simple way to make these types of changes and are useful for extending, enhancing, or updating a portion of an environment without modifying its source directly.
They take an inner JaxEnv, change one or more of its inputs/outputs (observations, rewards, state, done flag, info metadata), and expose the same reset/step interface so everything downstream - the VecEnv classes, make() methods, your training loop - keeps working without any changes.
This tutorial covers the pre-built wrappers provided with Envrax and how to use them in your projects.
Want to create your own? See the Creating a Custom Wrapper tutorial.
Types of Wrappers
Every Envrax wrapper falls into one of two categories - pass-through (stateless) and stateful.
The main difference is whether the wrapper introduces its own state alongside the inner environment's. If it does, it's classed as a stateful wrapper.
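Simply put, stateful wrappers need to remember something across timesteps, such as a rolling frame buffer or a running episode total. Instead of storing that memory on the wrapper object, they carry it in an outer state alongside the inner environment's state. This keeps reset/step pure functions of their inputs, so the environment stays compatible with JAX's transforms (jit, vmap, scan). We'll discuss this in more depth shortly.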
First though, we'll explore the simpler variant of the two: pass-through wrappers!
Pass-through Wrappers
These wrappers don't introduce any new state. They only transform the relevant inputs/outputs and otherwise pass reset/step calls straight through to the inner environment.
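To make the pattern concrete, here's a minimal pure-Python sketch of a pass-through wrapper. ToyEnv, PassThroughWrapper, and its transform_obs hook are hypothetical names, not Envrax classes. The (obs, state) and (obs, state, reward, done, info) return layouts are also assumptions chosen for illustration. Envrax's real wrappers work on JAX arrays and its own state types.

```python
from typing import Any, Callable, Tuple


class ToyEnv:
    """Hypothetical stand-in for an inner environment (not an Envrax class)."""

    def reset(self, key: int) -> Tuple[int, int]:
        state = 0
        return state, state  # (obs, state)

    def step(self, state: int, action: int) -> Tuple[int, int, float, bool, dict]:
        new_state = state + action
        reward = float(action)
        done = new_state >= 3
        return new_state, new_state, reward, done, {}  # (obs, state, reward, done, info)


class PassThroughWrapper:
    """Hypothetical base wrapper: forwards reset/step and only transforms observations."""

    def __init__(self, env: Any, transform_obs: Callable[[Any], Any]):
        self.env = env
        self.transform_obs = transform_obs

    def reset(self, key):
        obs, state = self.env.reset(key)
        # The inner state is returned untouched: no new state is introduced.
        return self.transform_obs(obs), state

    def step(self, state, action):
        obs, state, reward, done, info = self.env.step(state, action)
        return self.transform_obs(obs), state, reward, done, info


env = PassThroughWrapper(ToyEnv(), transform_obs=lambda obs: obs * 10)
obs, state = env.reset(key=0)
obs, state, reward, done, info = env.step(state, action=2)  # obs == 20, state == 2
```

The inner state passes through unchanged. Any code that already handles ToyEnv's state can therefore handle the wrapped environment's state too.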
Here's a quick overview of available pass-through wrappers:
| Wrapper | Input obs | Output obs | Description |
|---|---|---|---|
| JitWrapper | any | same | Compiles reset + step with jax.jit; caches kernels to disk |
| GrayscaleObservation | uint8[H, W, 3] | uint8[H, W] | NTSC luminance conversion |
| ResizeObservation(h, w) | uint8[H, W] or uint8[H, W, C] | uint8[h, w] or uint8[h, w, C] | Bilinear resize (default 84×84) |
| NormalizeObservation | uint8[...] | float32[...] in [0, 1] | Divide by 255 |
| ClipReward | any | same | Sign-clips rewards to float32 ∈ {−1, 0, +1} |
| ExpandDims | any | same | Adds trailing size-1 dim to reward and done |
| EpisodeDiscount | any | same | Converts done bool to float32 discount (1.0 / 0.0) |
| RecordVideo | any | same | Saves episode frames to MP4 (not JIT-compatible) |
We'll dig into each one below.
JitWrapper
This wrapper JIT-compiles the reset and step methods and caches the resulting XLA executables to disk.
The make() methods apply it automatically when you pass jit_compile=True.
You should rarely need to construct it manually. If you do, note that pre_warm=True by default, so compilation happens up front when the wrapper is constructed.
It also exposes a compile() method so you can trigger the XLA compilation manually. This is useful when you've constructed the wrapper with pre_warm=False and want to defer the compilation cost to a separate setup phase.
compile() is also safe to call multiple times. Because the compiled executables are cached, subsequent calls should be near-instant in wall-clock time, which makes restarting after a failed run that little bit smoother!
GrayscaleObservation
Combining with Other Wrappers
When using this with the NormalizeObservation wrapper, you should always apply this before it. Grayscale expects uint8 values, not float.
This wrapper converts an RGB observation to grayscale using NTSC luminance weights (0.299 R + 0.587 G + 0.114 B).
| Input obs | Output obs |
|---|---|
| uint8[H, W, 3] | uint8[H, W] |
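Here's a minimal NumPy sketch of the same luminance transform, which can help when sanity-checking expected pixel values. It is an illustration, not Envrax's JAX implementation. to_grayscale is a hypothetical helper, and rounding back to uint8 is an assumption (the real wrapper may truncate instead).

```python
import numpy as np

# NTSC luminance weights for the R, G and B channels.
NTSC_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def to_grayscale(obs: np.ndarray) -> np.ndarray:
    """Convert a uint8[H, W, 3] RGB frame to a uint8[H, W] grayscale frame."""
    if obs.dtype != np.uint8 or obs.ndim != 3 or obs.shape[-1] != 3:
        raise ValueError(f"expected uint8[H, W, 3], got {obs.dtype}{list(obs.shape)}")
    # Weighted sum over the channel axis, computed in float32.
    luma = obs.astype(np.float32) @ NTSC_WEIGHTS
    # Round (an assumption) and clip back into the uint8 range.
    return np.clip(np.round(luma), 0, 255).astype(np.uint8)


frame = np.zeros((2, 2, 3), dtype=np.uint8)
frame[0, 0] = [255, 0, 0]      # pure red
frame[1, 1] = [255, 255, 255]  # white
gray = to_grayscale(frame)     # shape (2, 2), dtype uint8
```

The dtype check mirrors the ordering advice above: a float32 frame that has already been normalized is rejected rather than silently misread.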
ResizeObservation
This wrapper bilinearly resizes 2-D or 3-D uint8 observations to a target height and width (h, w). The channel dimension (C), if present, is preserved.
| Input obs | Output obs |
|---|---|
| uint8[H, W] | uint8[h, w] |
| uint8[H, W, C] | uint8[h, w, C] |
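Bilinear resizing interpolates each output pixel from its four nearest input pixels. The NumPy sketch below uses the common half-pixel-centre convention and applies no antialiasing. bilinear_resize is a hypothetical helper, and Envrax's wrapper may differ in edge handling, rounding, or antialiasing when downsampling.

```python
import numpy as np


def bilinear_resize(obs: np.ndarray, h: int, w: int) -> np.ndarray:
    """Resize uint8[H, W] or uint8[H, W, C] to uint8[h, w] or uint8[h, w, C]."""
    squeeze = obs.ndim == 2
    img = obs[..., None] if squeeze else obs  # treat 2-D input as one channel
    in_h, in_w = img.shape[:2]
    img = img.astype(np.float32)

    # Map output pixel centres back to input coordinates (half-pixel convention),
    # clamping so that edge pixels reuse the border values.
    ys = np.clip((np.arange(h) + 0.5) * in_h / h - 0.5, 0, in_h - 1)
    xs = np.clip((np.arange(w) + 0.5) * in_w / w - 0.5, 0, in_w - 1)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None, None]  # vertical interpolation weights
    wx = (xs - x0)[None, :, None]  # horizontal interpolation weights

    # Interpolate horizontally on the two neighbouring rows, then vertically.
    top = img[y0][:, x0] * (1 - wx) + img[y0][:, x1] * wx
    bottom = img[y1][:, x0] * (1 - wx) + img[y1][:, x1] * wx
    out = top * (1 - wy) + bottom * wy

    out = np.clip(np.round(out), 0, 255).astype(np.uint8)
    return out[..., 0] if squeeze else out


atari_frame = np.zeros((210, 160, 3), dtype=np.uint8)
small = bilinear_resize(atari_frame, 84, 84)  # shape (84, 84, 3)
```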
NormalizeObservation
Combining with Other Wrappers
When using this with shape-transform wrappers like GrayscaleObservation and ResizeObservation, you should always apply this after them. Those wrappers expect uint8 values, not float.
This wrapper casts uint8 observations to float32 and divides them by 255, normalizing their values into the range [0, 1].
| Input obs | Output obs |
|---|---|
| uint8[...] | float32[...] in [0, 1] |
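The transform itself is a single cast and divide. Here's a NumPy sketch; normalize is a hypothetical helper, not the wrapper's code.

```python
import numpy as np


def normalize(obs: np.ndarray) -> np.ndarray:
    """Map uint8 values in [0, 255] to float32 values in [0, 1]."""
    return obs.astype(np.float32) / np.float32(255.0)


frame = np.array([[0, 128, 255]], dtype=np.uint8)
out = normalize(frame)  # float32, 0.0 ... 1.0
```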
ClipReward
This wrapper sign-clips rewards to {-1, 0, +1}. It's useful as a stabilisation step when reward magnitudes can vary wildly between episodes or across environments.
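Sign clipping keeps only the sign of each reward and discards its magnitude. Here's a NumPy sketch of the idea; clip_reward is a hypothetical helper.

```python
import numpy as np


def clip_reward(reward) -> np.ndarray:
    """Sign-clip rewards to {-1, 0, +1} as float32."""
    return np.sign(np.asarray(reward, dtype=np.float32))


rewards = np.array([-7.5, 0.0, 0.2, 300.0])
clipped = clip_reward(rewards)  # [-1.0, 0.0, 1.0, 1.0]
```

Note that a reward of 0.2 and a reward of 300 both become +1. This is why clipping stabilises training, but it also hides differences in reward scale from the agent.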
EpisodeDiscount
This wrapper converts the done boolean to a float32 discount factor (1.0 when not done, 0.0 when done). Useful for value bootstrapping where you want value(s') * discount to zero out at terminal states.
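The following NumPy sketch shows how such a discount is typically used in a one-step TD target. episode_discount is a hypothetical helper. The separate gamma factor is a common convention assumed here and is not something the wrapper is described as providing.

```python
import numpy as np


def episode_discount(done) -> np.ndarray:
    """Convert done flags to float32 discounts: 1.0 while running, 0.0 at termination."""
    return 1.0 - np.asarray(done, dtype=np.float32)


gamma = 0.99  # assumed RL discount factor, separate from the episode discount
reward = np.array([1.0, 1.0], dtype=np.float32)
next_value = np.array([10.0, 10.0], dtype=np.float32)
done = np.array([False, True])

discount = episode_discount(done)
# One-step TD target: the bootstrap term vanishes at terminal states.
target = reward + gamma * discount * next_value  # approx. [10.9, 1.0]
```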
ExpandDims
This wrapper adds a trailing size-1 dimension to reward and done so they broadcast cleanly against batched value heads.
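To see why the trailing dimension matters, here's a NumPy sketch with a hypothetical batch of 4 environments and a value head that outputs shape (4, 1). Without the extra dimension, broadcasting silently produces the wrong shape.

```python
import numpy as np

num_envs = 4
reward = np.zeros(num_envs, dtype=np.float32)      # shape (4,)
done = np.zeros(num_envs, dtype=bool)              # shape (4,)
value = np.zeros((num_envs, 1), dtype=np.float32)  # e.g. a value head's output, shape (4, 1)

# Without a trailing dim, (4,) + (4, 1) silently broadcasts to (4, 4).
wrong = reward + value

# Adding a trailing size-1 dim keeps the shapes aligned.
reward_e = reward[..., None]  # shape (4, 1)
done_e = done[..., None]      # shape (4, 1)
right = reward_e + value      # shape (4, 1)
```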
RecordVideo
JIT and vmap Incompatibility
RecordVideo is not JIT or vmap compatible because it writes files Python-side. Use it for evaluation, logging, or training visualisation only. Never use it inside jax.jit/jax.vmap.
Calling reset or step inside any jax.jit, jax.vmap, or jax.lax.scan boundary raises a RuntimeError.
The wrapped environment must also implement a render(state) method. Otherwise, RecordVideo will raise a TypeError at construction.
This wrapper saves episode frames to MP4 via imageio and can be customized based on three optional trigger controls.
Output files are stored in <output_dir>/episode_<NNNN>.mp4. The wrapper requires imageio with the ffmpeg plugin, which you can install with pip install "imageio[ffmpeg]".
Trigger Controls
To make this wrapper more flexible, you can configure specific triggers based on your requirements to control when recording is active. If no triggers are provided, every episode is recorded.
Here are your options:
- episode_trigger: Callable[[int], bool], called at each reset() with the current episode index
- step_trigger: Callable[[int], bool], called at each step() with the global step count; starts recording mid-episode
- recording_trigger: Callable[[], bool], a zero-argument callable checked at each reset(), useful for external control via a custom flag
Episode Trigger
Use this when you want to record on a regular cadence (e.g., every Nth episode).
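For example, a trigger that records every 100th episode could be a plain lambda matching the Callable[[int], bool] signature, which you would pass to RecordVideo as episode_trigger. The interval of 100 is arbitrary. The sketch only evaluates the callable itself, so it runs without Envrax.

```python
from typing import Callable

# Record episodes 0, 100, 200, ...; the interval of 100 is an arbitrary choice.
episode_trigger: Callable[[int], bool] = lambda episode: episode % 100 == 0

recorded = [ep for ep in range(350) if episode_trigger(ep)]  # [0, 100, 200, 300]
```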
Step Trigger
Use this when you want to start recording mid-episode after a global step threshold. It's also handy for skipping the first N warmup steps and only recording afterwards.
Once the trigger fires, recording continues until that episode ends.
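Here are two possible step triggers, written as plain callables. WARMUP_STEPS and the 50,000-step period are hypothetical values. The first returns True on every step from the threshold onward. The second fires only at periodic step counts.

```python
from typing import Callable

WARMUP_STEPS = 10_000  # hypothetical warmup length

# True on every global step from the warmup threshold onward.
step_trigger: Callable[[int], bool] = lambda step: step >= WARMUP_STEPS

# Alternative: fire only periodically, e.g. every 50,000 global steps.
periodic_step_trigger: Callable[[int], bool] = lambda step: step > 0 and step % 50_000 == 0
```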
Recording Trigger
Use this when an external system (e.g. a meta-learning loop or evaluation harness) controls when recording is active via a custom flag.
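One way to build such a flag is a small callable object that an external system toggles. RecordFlag is a hypothetical class. All that matters is that it satisfies the zero-argument Callable[[], bool] signature expected by recording_trigger.

```python
from typing import Callable


class RecordFlag:
    """Hypothetical external switch, e.g. toggled by an evaluation harness."""

    def __init__(self) -> None:
        self.enabled = False

    def __call__(self) -> bool:
        # Zero-argument callable, matching the recording_trigger signature.
        return self.enabled


flag = RecordFlag()
recording_trigger: Callable[[], bool] = flag

# Later, the harness flips the flag; the next reset() check would see True.
flag.enabled = True
```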
Combining Triggers
You can also mix and match triggers. If any one of them returns True, recording starts.
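Below is a pure-Python sketch of the combination rule. It follows the behaviour described above: record every episode when no triggers are given, and otherwise record when any trigger returns True. On reset, the episode and recording triggers are checked; on each step, the step trigger is checked. should_record_on_reset and should_start_on_step are hypothetical helpers that illustrate the logic and are not RecordVideo internals.

```python
from typing import Callable, Optional


def should_record_on_reset(
    episode: int,
    episode_trigger: Optional[Callable[[int], bool]] = None,
    step_trigger: Optional[Callable[[int], bool]] = None,
    recording_trigger: Optional[Callable[[], bool]] = None,
) -> bool:
    """Decide at reset() whether the new episode should be recorded."""
    if episode_trigger is None and step_trigger is None and recording_trigger is None:
        return True  # no triggers configured: record every episode
    fired = []
    if episode_trigger is not None:
        fired.append(episode_trigger(episode))
    if recording_trigger is not None:
        fired.append(recording_trigger())
    # The step trigger is not consulted here; it is checked on each step().
    return any(fired)


def should_start_on_step(
    global_step: int, step_trigger: Optional[Callable[[int], bool]] = None
) -> bool:
    """Decide at step() whether to start recording mid-episode."""
    return step_trigger is not None and bool(step_trigger(global_step))
```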
Stateful Wrappers
These wrappers introduce their own outer state so that they can remember the information they need to carry across timesteps.
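Here's a pure-Python sketch of the stateful pattern. The wrapper's memory lives in an immutable outer state that also holds the inner environment's state. step() returns a new outer state instead of mutating the wrapper. StepCounterState, StepCounter, and ToyEnv are hypothetical names, and Envrax's stateful wrappers use their own state types.

```python
from typing import Any, NamedTuple, Tuple


class ToyEnv:
    """Hypothetical inner env whose state is a single integer."""

    def reset(self, key):
        return 0, 0  # (obs, state)

    def step(self, state, action):
        new = state + action
        return new, new, 1.0, new >= 3, {}  # (obs, state, reward, done, info)


class StepCounterState(NamedTuple):
    """Hypothetical outer state: the inner env's state plus the wrapper's own memory."""
    inner: Any
    steps: int


class StepCounter:
    """Hypothetical stateful wrapper that counts steps within the current episode."""

    def __init__(self, env: Any):
        self.env = env

    def reset(self, key) -> Tuple[Any, StepCounterState]:
        obs, inner = self.env.reset(key)
        return obs, StepCounterState(inner=inner, steps=0)

    def step(self, state: StepCounterState, action):
        obs, inner, reward, done, info = self.env.step(state.inner, action)
        # Build a new state rather than mutating self, so step stays a pure function.
        new_state = StepCounterState(inner=inner, steps=state.steps + 1)
        return obs, new_state, reward, done, info
```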
FrameStackObservation
This wrapper maintains a rolling window of the last n_stack observations, stacked along a new trailing axis.
This is useful when you need your agent to perceive motion.
| Input obs | Output obs |
|---|---|
| uint8[H, W] | uint8[H, W, n_stack] |
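The rolling window can be sketched in NumPy as two pure functions. In a stateful wrapper, the buffer would be carried in the outer state. framestack_reset and framestack_step are hypothetical helpers. Filling the buffer with copies of the first frame on reset is an assumption, since the real wrapper might zero-fill instead.

```python
import numpy as np


def framestack_reset(first_obs: np.ndarray, n_stack: int) -> np.ndarray:
    """Fill the buffer with copies of the first observation: uint8[H, W, n_stack]."""
    return np.repeat(first_obs[..., None], n_stack, axis=-1)


def framestack_step(buffer: np.ndarray, obs: np.ndarray) -> np.ndarray:
    """Drop the oldest frame and append the newest one along the trailing axis."""
    return np.concatenate([buffer[..., 1:], obs[..., None]], axis=-1)


buffer = framestack_reset(np.zeros((2, 2), dtype=np.uint8), n_stack=4)
for t in range(1, 6):
    buffer = framestack_step(buffer, np.full((2, 2), t, dtype=np.uint8))
# Each pixel now holds the last four frames: [2, 3, 4, 5]
```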
RecordEpisodeStatistics
This wrapper tracks the cumulative return and step count of each episode.
It adds an episode entry to the info metadata on every step(). The entry is only populated when done=True, at which point it holds the episode return (r) and episode length (l).
This makes it easy to log episodic metrics without having to track them yourself!
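The bookkeeping can be sketched in pure Python. The sketch keeps a running return and length, reports them under info["episode"] when done is True, then resets the totals. EpisodeStatsState and update_stats are hypothetical names. Using 0.0 / 0 as placeholders on non-terminal steps is an assumption about what "not populated" looks like.

```python
from typing import NamedTuple, Tuple


class EpisodeStatsState(NamedTuple):
    """Hypothetical running totals for the current episode."""
    episode_return: float = 0.0
    episode_length: int = 0


def update_stats(
    stats: EpisodeStatsState, reward: float, done: bool, info: dict
) -> Tuple[EpisodeStatsState, dict]:
    """Accumulate return/length and report them in info when the episode ends."""
    ret = stats.episode_return + reward
    length = stats.episode_length + 1
    # The "episode" entry is present every step but only populated when done;
    # the 0.0 / 0 placeholders are an assumption.
    info = {**info, "episode": {"r": ret if done else 0.0, "l": length if done else 0}}
    # Start fresh after a terminal step, otherwise carry the totals forward.
    next_stats = EpisodeStatsState() if done else EpisodeStatsState(ret, length)
    return next_stats, info


stats = EpisodeStatsState()
infos = []
for reward, done in [(1.0, False), (2.0, False), (0.5, True)]:
    stats, info = update_stats(stats, reward, done, {})
    infos.append(info)
# infos[-1]["episode"] == {"r": 3.5, "l": 3}
```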
Applying Wrappers
We can apply wrappers in two ways:
- Using the built-in make() methods
- Manually, through class instances
Using make() methods
The easiest way to apply wrappers is through the make() methods.
Pass each wrapper either as a bare class (e.g., ClipReward) or as a parameterised call without env= (e.g., ResizeObservation(h=84, w=84)), and the selected make() method will do the rest for you!
There are a few things to consider when using this approach:
- Order matters. Wrappers apply innermost-first, so the list is read top-down: Grayscale → Resize → Frame-stack → Clip reward. Swapping the order will produce different results.
- Parameterised wrappers must be called without env. Under the hood, parameterised wrappers (e.g., ResizeObservation(h=84, w=84)) return a _WrapperFactory, which the make() method binds to the base environment automatically once it's constructed. There's no need for functools.partial!
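Both ideas can be illustrated with a toy sketch. A factory stores a wrapper's keyword arguments and binds them to an env later, and a list of wrappers is applied innermost-first. WrapperFactory, Tag, BaseEnv, and toy_make are hypothetical names. Envrax's real _WrapperFactory and make() may be implemented quite differently.

```python
from typing import Any, Callable, List


class WrapperFactory:
    """Hypothetical stand-in for a parameterised wrapper still missing its env."""

    def __init__(self, cls: type, **kwargs: Any):
        self.cls = cls
        self.kwargs = kwargs

    def __call__(self, env: Any) -> Any:
        # Bind the stored parameters to the env once it exists.
        return self.cls(env, **self.kwargs)


class Tag:
    """Hypothetical wrapper that only records its name, to expose the wrapping order."""

    def __init__(self, env: Any, name: str = "tag"):
        self.env = env
        self.name = name

    def chain(self) -> List[str]:
        inner = self.env.chain() if hasattr(self.env, "chain") else [type(self.env).__name__]
        return inner + [self.name]


class BaseEnv:
    """Hypothetical base environment."""


def toy_make(base_env: Any, wrappers: List[Callable[[Any], Any]]) -> Any:
    env = base_env
    # The first list entry wraps the base env directly (innermost-first).
    for wrapper in wrappers:
        env = wrapper(env)  # bare classes and factories are both callables taking env
    return env


env = toy_make(
    BaseEnv(),
    [WrapperFactory(Tag, name="grayscale"), WrapperFactory(Tag, name="resize"), Tag],
)
# env.chain() == ["BaseEnv", "grayscale", "resize", "tag"]
```

Because bare classes and bound factories are both "callables that take an env", the loop treats them uniformly. This avoids the need for functools.partial.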
Manually
You can also apply wrappers manually, without a make() method, by passing each environment directly into the next wrapper's constructor.
This can be useful in unit tests or when you want to construct a wrapper chain yourself.
Accessing the Inner Environment
Every wrapper exposes an env.unwrapped field that gives you access to the innermost (base) environment.
For example, if we wrapped our BallEnv in ClipReward and wanted the base instance rather than its ClipReward variant, env.unwrapped would return the original BallEnv.
This behaviour holds no matter how many wrappers you apply!
Obs and Action Space Behaviour
observation_space/action_space delegate to the inner environment by default.
Wrappers only override them when they change the space. For example, the GrayscaleObservation wrapper drops the channel dimension, so the observation_space is modified.
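Here's a minimal sketch of both behaviours. unwrapped recurses through the wrapper chain, and observation_space is delegated to the inner environment unless a wrapper overrides it. Wrapper, ToyBallEnv, ToyClipReward, and ToyGrayscale are hypothetical classes, and spaces are modelled as plain shape tuples for simplicity.

```python
from typing import Any


class Wrapper:
    """Hypothetical minimal wrapper base illustrating unwrapped and space delegation."""

    def __init__(self, env: Any):
        self.env = env

    @property
    def unwrapped(self) -> Any:
        # Recurse until we reach an object that is not itself a wrapper.
        return self.env.unwrapped if isinstance(self.env, Wrapper) else self.env

    @property
    def observation_space(self) -> Any:
        # Delegate to the inner environment unless a subclass overrides this.
        return self.env.observation_space


class ToyBallEnv:
    """Hypothetical stand-in for the tutorial's BallEnv; the space is a shape tuple."""
    observation_space = (84, 84, 3)


class ToyClipReward(Wrapper):
    """Doesn't touch observations, so it keeps the delegated space."""


class ToyGrayscale(Wrapper):
    @property
    def observation_space(self) -> Any:
        # Drops the channel dimension, so this wrapper overrides the space.
        return self.env.observation_space[:-1]


base = ToyBallEnv()
env = ToyGrayscale(ToyClipReward(base))
# env.unwrapped is base; env.observation_space == (84, 84)
```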
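Common Pipelines
- Atari-style image preprocessing: GrayscaleObservation, then ResizeObservation(h=84, w=84), then FrameStackObservation, then ClipReward.
- Training telemetry: add RecordEpisodeStatistics to log episode returns and lengths.
- Evaluation with video: add RecordVideo, outside any jax.jit/jax.vmap boundary, to save episodes as MP4 files.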
Common Pitfalls
Here are some common "gotchas" to be mindful of:
- Applying RecordVideo inside a JIT boundary. Don't do it. It writes files Python-side and should only be used outside jax.jit/jax.vmap, e.g. for evaluation.
- Wrong input shape for GrayscaleObservation. This wrapper expects uint8[H, W, 3]. If your environment already outputs floats or grayscale frames, you get a shape/dtype error at trace time.
- Ordering NormalizeObservation before GrayscaleObservation/ResizeObservation. NormalizeObservation turns uint8 [0, 255] observations into float32 [0, 1], but the shape transforms expect uint8. Perform shape transforms first, then normalize.
Recap
Excellent job! To recap:
- We have two types of wrappers: pass-through (stateless) and stateful (introduces an outer state type wrapping the inner state).
- Envrax comes with 8 pass-through wrappers: JitWrapper, GrayscaleObservation, ResizeObservation, NormalizeObservation, ClipReward, EpisodeDiscount, ExpandDims, RecordVideo
- And 2 stateful wrappers: FrameStackObservation, RecordEpisodeStatistics
- We apply wrappers via make(wrappers=[...]) (innermost-first) or manual composition
- Parameterised wrappers can be passed to make() methods without the env parameter. No functools.partial required!
- env.unwrapped provides the innermost (base) JaxEnv
- Always do shape transforms (Grayscale, Resize) on uint8 observations before NormalizeObservation casts them to float32
Next Steps
For our last tutorial, we'll look at how to use the render() method so that you can watch your agents in their environments. See you there!
- Rendering: Learn how to use the render() method.