Creating a Custom Wrapper
You've already used Envrax's built-in wrappers from the Available Wrappers tutorial — but what if none of them fit your needs?
Maybe you want to scale your rewards by a constant, add a curriculum learning step, or track a unique statistic across each episode that none of the built-ins cover.
The easiest solution? Building your own wrappers! In this tutorial, we'll walk through how to build both kinds: a pass-through wrapper that simply transforms data flowing through reset/step, and a stateful wrapper that needs to remember something between steps.
Picking a Base Class
Every custom wrapper inherits from one of two base classes:
| Base | When to use | What changes |
|---|---|---|
| Wrapper | Pass-through: transforms obs/reward/done, leaves the inner state type unchanged | Nothing |
| StatefulWrapper | Stateful: needs to remember something between steps | Introduces a new outer state that wraps the inner state |
If you're unsure which one to use, start with Wrapper. Then, if you find yourself wanting to store a counter or a running total across steps without polluting the inner environment, that's your indicator to transition it to a StatefulWrapper.
Pass-through Wrapper
For this example, we'll build a simple wrapper that multiplies every reward by a constant. We'll call it ScaleReward.
We can build it in four key steps:
- Declaring the class
- Storing the scale parameter
- Implementing reset()
- Implementing step()
Let's tackle them one at a time.
Step 1: Declaring the Class
Just like building a JaxEnv, we start by subclassing the base class and pinning the generic types. For pass-through wrappers, we leave all four as their defaults:
Step 2: Storing the Scale Parameter
Next, we add the __init__ method to accept our scale parameter and pass the environment through to the parent class (Wrapper):
- The * marker forces the parameters after it (here, scale) to be keyword-only. It's not strictly required, but it's the recommended convention: it stops users from accidentally passing the parameter positionally where env is expected, which would break when the wrapper is used through the make() methods.
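To see the keyword-only marker in action, here is a minimal sketch in plain Python. Env and this cut-down ScaleReward are hypothetical stand-ins written only for this illustration (they are not Envrax classes); the only thing being demonstrated is how * changes the __init__ signature.

```python
class Env:
    """Hypothetical stand-in for an environment (not an Envrax class)."""


class ScaleReward:
    """Cut-down stand-in that only demonstrates the __init__ signature."""

    def __init__(self, env, *, scale=1.0):
        # Everything after `*` must be passed by keyword.
        self._env = env
        self._scale = scale


env = Env()

# The keyword form works as intended.
wrapper = ScaleReward(env, scale=0.5)

# The positional form is rejected up front with a TypeError.
try:
    ScaleReward(env, 0.5)
    positional_allowed = True
except TypeError:
    positional_allowed = False
```

Failing loudly at construction time is the point: a misplaced positional argument never gets the chance to be silently bound to the wrong parameter.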
Step 3: Implementing reset()
This one's nice and easy. ScaleReward doesn't change anything about the reset path, so we delegate it straight to the inner environment:
Step 4: Implementing step()
Lastly, our step method. Like the reset() method, we can delegate most of the logic to the inner environment and just unpack the step values.
Then, we scale the reward and return the new value:
The pattern here is pretty common across most wrappers:
- Unpack first to get full access to the inner step's return values
- Transform what you need: in this case, just reward * self._scale
- Pass everything else through unchanged: obs, new_state, done, and info all flow straight through
And that's the whole wrapper! This is far too simple to use in a production setting but gives you an insight into the key fundamentals of wrapper creation.
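Putting the four steps together, the wrapper can be sketched roughly as follows. This is a self-contained approximation, not Envrax's actual code: Wrapper, ToyEnv, and ToyState are hypothetical stand-ins, the rng is a plain integer, and the reset(rng) -> (obs, state) and step(state, action) -> (obs, new_state, reward, done, info) signatures are assumptions made for the example.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyState:
    """Hypothetical state for the toy environment below."""
    step: int
    done: bool


class ToyEnv:
    """Hypothetical env: reward equals the action, episode ends after 3 steps."""

    def reset(self, rng):
        return 0.0, ToyState(step=0, done=False)

    def step(self, state, action):
        step = state.step + 1
        done = step >= 3
        new_state = ToyState(step=step, done=done)
        # Assumed return order: obs, new_state, reward, done, info
        return float(step), new_state, float(action), done, {}


class Wrapper(ABC):
    """Stand-in for the pass-through base class (assumed shape)."""

    def __init__(self, env):
        self._env = env

    @abstractmethod
    def reset(self, rng): ...

    @abstractmethod
    def step(self, state, action): ...


class ScaleReward(Wrapper):
    def __init__(self, env, *, scale=1.0):
        super().__init__(env)
        self._scale = scale

    def reset(self, rng):
        # Nothing changes on the reset path, so delegate.
        return self._env.reset(rng)

    def step(self, state, action):
        # Unpack, transform the reward, pass everything else through.
        obs, new_state, reward, done, info = self._env.step(state, action)
        return obs, new_state, reward * self._scale, done, info


env = ScaleReward(ToyEnv(), scale=0.5)
obs, state = env.reset(rng=0)
obs, state, reward, done, info = env.step(state, action=4.0)
```

With scale=0.5 and an inner reward of 4.0, the wrapped step returns a reward of 2.0 while the state, done flag, and info come back exactly as the inner environment produced them.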
Pass-through: Noteworthy Additions
There are a few additional things worth noting:
- Type parameters: for a pure pass-through wrapper, you typically leave all four type parameters the same (Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT]) so the wrapper inherits whatever the inner environment uses.
- reset and step are both abstract: you must implement both methods, even if one just delegates (like reset does here).
- observation_space / action_space delegate automatically: you only need to override them if the wrapper changes their shape, dtype, or bounds (e.g., GrayscaleObservation drops the channel dim).
- Keyword-only parameter: scale uses * to force it as a keyword-only parameter. It's not strictly required, but it's the recommended convention since it stops users from accidentally passing the parameter positionally and breaking it for the make() methods.
Stateful Wrapper
Now let's tackle the harder case!
We'll build MaxReward — a wrapper that tracks the maximum reward seen so far in the current episode and exposes it via info["max_reward"].
We can break this down into five key steps:
- Defining the outer state
- Declaring the class
- Setting up __init__
- Implementing reset()
- Implementing step()
Like before, we'll tackle them one at a time.
Step 1: Define the Outer State
Stateful wrappers must use a new @chex.dataclass that extends EnvState and wraps the inner state in an env_state field.
We also use Generic[InnerStateT] so other wrappers/environments keep their inner state type visible to our custom wrapper:
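As a rough sketch of what such an outer state might look like, the example below uses the standard library's dataclass in place of @chex.dataclass so that it runs without any dependencies. The EnvState here is a hypothetical stand-in carrying the rng/step/done base fields mentioned in this tutorial, and max_reward is the extra field our wrapper adds.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass(frozen=True)
class EnvState:
    """Stand-in for the base state (the real one comes from the framework)."""
    rng: int      # a plain int here; a PRNG key in real JAX code
    step: int
    done: bool


InnerStateT = TypeVar("InnerStateT", bound=EnvState)


@dataclass(frozen=True)
class MaxRewardState(EnvState, Generic[InnerStateT]):
    """Outer state: base fields, the wrapped inner state, and our running max."""
    env_state: InnerStateT
    max_reward: float


inner = EnvState(rng=0, step=0, done=False)
outer = MaxRewardState(
    rng=inner.rng,
    step=inner.step,
    done=inner.done,
    env_state=inner,               # inner state stored untouched
    max_reward=float("-inf"),      # nothing seen yet this episode
)
```

Because MaxRewardState is generic over InnerStateT, a type checker can still see the concrete type behind outer.env_state instead of falling back to the base EnvState.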
Step 2: Declaring the Class
Next, we can declare the wrapper itself. Stateful wrappers take five generic type parameters instead of four — the extra one (InnerStateT) tells the framework what the inner env's state type is:
In the order they appear, the five type parameters are:
- Observation space type (inherits from inner environment)
- Action space type (inherits from inner environment)
- Outer state type, pinned to our custom MaxRewardState and generic over the inner state
- Config type (inherits from inner environment)
- Inner state type, allowing the inner environment's state to plug in cleanly
The key things to remember here: keep InnerStateT parametric for IDE support, and pin your custom state class to the StateT generic position.
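The typing pattern behind that declaration can be sketched with nothing but the standard typing module. The StatefulWrapper and MaxRewardState below are hypothetical placeholders that mirror the five-parameter layout described above; the TypeVar names follow the ones used in this tutorial, but their real definitions and bounds are assumptions.

```python
from typing import Generic, TypeVar

# Assumed TypeVars; the framework's own definitions may carry bounds.
ObsSpaceT = TypeVar("ObsSpaceT")
ActSpaceT = TypeVar("ActSpaceT")
StateT = TypeVar("StateT")
ConfigT = TypeVar("ConfigT")
InnerStateT = TypeVar("InnerStateT")


class StatefulWrapper(Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT, InnerStateT]):
    """Placeholder base with the five type parameters."""

    def __init__(self, env):
        self._env = env


class MaxRewardState(Generic[InnerStateT]):
    """Placeholder for the outer state from Step 1."""


class MaxReward(
    StatefulWrapper[
        ObsSpaceT,                    # observation space type
        ActSpaceT,                    # action space type
        MaxRewardState[InnerStateT],  # outer state, pinned to our state class
        ConfigT,                      # config type
        InnerStateT,                  # inner state type, kept parametric
    ]
):
    """Class body omitted; only the declaration is of interest here."""
```

Only the third slot is pinned to a concrete class. The others stay as type variables, so MaxReward itself remains generic and picks up whatever the inner environment uses.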
Step 3: Setting up __init__
All we do here is accept the environment in __init__ and delegate it to the parent class:
No extra parameters or logic needed! If, for example, you wanted a MaxReward(threshold=0.5) variant, this is where you'd add it (with the same * keyword-only marker we used in ScaleReward).
Step 4: Implementing reset()
Now onto the interesting part. On reset, we need to:
- Reset the inner environment to get a fresh inner state
- Build a new MaxRewardState that wraps the inner state and initialises our running max to -inf
Two details matter when building the outer state:
- Forward the base fields (rng / step / done) directly from the inner state. This is what lets VecEnv's auto-reset still see the right done flag without having to unwrap our outer state.
- Store the inner state verbatim under env_state so we can pass it back to the inner env's step() later.
Simple enough!
Step 5: Implementing step()
Lastly, the real fun begins.
Firstly, we step the inner environment using the inner state — not the outer state, because the inner environment doesn't know our outer state exists:
Then, we compute the new running max from the latest reward, calculate the reset max for when the episode ends, and store the running max in info:
Finally, we build a new MaxRewardState and return the step values:
That's it! Stateful wrapper done!
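For reference, the reset and step logic described above can be approximated in plain Python as shown below. Treat it as a sketch under stated assumptions rather than the real implementation: EnvState and ToyEnv are hypothetical stand-ins, the wrapper does not subclass the real StatefulWrapper, the step return order is assumed, and an ordinary if/else stands in for the jnp.where call a jitted version would need.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class EnvState:
    """Stand-in base state with the rng/step/done fields."""
    rng: int
    step: int
    done: bool


@dataclass(frozen=True)
class MaxRewardState(EnvState):
    env_state: Any      # the untouched inner state
    max_reward: float   # running max for the current episode


class ToyEnv:
    """Hypothetical inner env: reward equals the action, ends after 3 steps."""

    def reset(self, rng):
        return 0.0, EnvState(rng=rng, step=0, done=False)

    def step(self, state, action):
        step = state.step + 1
        new_state = EnvState(rng=state.rng, step=step, done=step >= 3)
        return float(step), new_state, float(action), new_state.done, {}


class MaxReward:
    """Sketch only: the real wrapper would subclass StatefulWrapper."""

    def __init__(self, env):
        self._env = env

    def reset(self, rng):
        obs, inner = self._env.reset(rng)
        state = MaxRewardState(
            rng=inner.rng, step=inner.step, done=inner.done,  # forward base fields
            env_state=inner,
            max_reward=float("-inf"),
        )
        return obs, state

    def step(self, state, action):
        # Step the inner env with the inner state, not the outer one.
        obs, inner, reward, done, info = self._env.step(state.env_state, action)

        new_max = max(state.max_reward, reward)
        # In jitted JAX code this would be jnp.where(done, -inf, new_max).
        stored_max = float("-inf") if done else new_max
        info = {**info, "max_reward": new_max}

        new_state = MaxRewardState(
            rng=inner.rng, step=inner.step, done=inner.done,
            env_state=inner,
            max_reward=stored_max,
        )
        return obs, new_state, reward, done, info


env = MaxReward(ToyEnv())
obs, state = env.reset(rng=0)
seen = []
for action in (1.0, 5.0, 2.0):
    obs, state, reward, done, info = env.step(state, action)
    seen.append(info["max_reward"])
```

After the three steps, info["max_reward"] has reported 1.0, 5.0, and 5.0 in turn, and because the third step ends the episode, the stored running max is back at -inf for the next one.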
Stateful: Noteworthy Additions
Like our pass-through wrappers, there are a few noteworthy additions to be aware of:
- Type parameters: stateful wrappers require five type parameters (StatefulWrapper[ObsSpaceT, ActSpaceT, OuterState, ConfigT, InnerState]). The third one pins your wrapper's outer state type, while the fifth stays parametric so the inner state still has IDE support.
- Unwrapping before stepping: call self._env.step(state.env_state, action), not self._env.step(state, action). The inner environment doesn't know about your outer state.
- Copy base fields explicitly: on both reset and step, rng / step / done come from the inner env_state, not the old wrapper state. This is how the auto-reset signal reaches the framework.
- Handle episode boundaries: when done=True, your counters should reset. Here we use jnp.where(done, -inf, new_max) so the next episode starts fresh. This is the stateful wrapper equivalent of what VecEnv already does automatically for the base state.
Overriding Spaces
Both types of wrapper delegate observation_space and action_space to the inner environment by default. You only ever need to override them when your wrapper actually changes the shape, dtype, or bounds of the data flowing through it.
For example, with the GrayscaleObservation wrapper, we drop the channel dim from (H, W, 3) down to (H, W) and return a new Box space:
The action_space stays untouched, so there's no need to override it.
The same pattern applies to stateful wrappers — override only the spaces you actually change and let the rest delegate to the inner environment.
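The override pattern might look something like the sketch below. Box, Wrapper, and ToyImageEnv are hypothetical stand-ins, and exposing the spaces as properties is an assumption made for this example; only the idea of dropping the trailing channel dimension comes from the description above.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Box:
    """Hypothetical stand-in for a Box space; field names are assumptions."""
    low: float
    high: float
    shape: Tuple[int, ...]


class ToyImageEnv:
    """Hypothetical env with an RGB observation space."""
    observation_space = Box(low=0.0, high=255.0, shape=(84, 84, 3))
    action_space = "Discrete(4)"  # placeholder value for the example


class Wrapper:
    """Stand-in base: both spaces delegate to the inner env by default."""

    def __init__(self, env):
        self._env = env

    @property
    def observation_space(self):
        return self._env.observation_space

    @property
    def action_space(self):
        return self._env.action_space


class GrayscaleObservation(Wrapper):
    @property
    def observation_space(self):
        inner = self._env.observation_space
        # Drop the trailing channel dim: (H, W, 3) -> (H, W)
        return Box(low=inner.low, high=inner.high, shape=inner.shape[:-1])


env = GrayscaleObservation(ToyImageEnv())
```

Here env.observation_space reports a shape of (84, 84), while env.action_space still comes straight from the inner environment because the wrapper never overrides it.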
Using Your Wrapper
Once written, custom wrappers integrate seamlessly with the other built-in ones. So, you can use them in the same way as mentioned in the Available Wrappers tutorial:
No custom integration needed!
Common Pitfalls
Like with all our tutorials, be wary of the following "gotchas":
- Remember to copy rng / step / done to the outer state. VecEnv's auto-reset reads state.done off the outer state, not state.env_state.done. If you don't copy the inner done flag forward on every step, the outer flag stays False forever and your episodes never auto-reset.
- Calling self._env.step(state, ...) instead of self._env.step(state.env_state, ...). This will raise a TypeError because the inner environment doesn't understand your outer state dataclass.
- Using Python scalars for episode-lifetime counters. These become static at trace time and will break your code. Always use jnp.int32(value) / jnp.float32(value) instead.
- Skipping Generic[InnerStateT]. The wrapper still works, but inner wrappers/environments lose type safety on env_state and IDE autocomplete drops back to EnvState.
- Skipping the * keyword-only marker on parameterised args. If a wrapper parameter is positional, users can accidentally pass it where env is expected, breaking the wrapper when it's used through the make() methods. Always use def __init__(self, env, *, param=...) when adding new parameters to your wrappers.
Recap
And that's a wrap! (pun intended)
To recap:
- Wrapper is for pass-through behaviour; StatefulWrapper is for when you need to remember something across steps.
- Stateful wrapper states use Generic[InnerStateT], hold an env_state field, and forward rng / step / done from the inner state.
- Use keyword-only parameters on __init__ so users can call your wrapper safely through the make() methods.
- Override observation_space / action_space only when you need to modify them.
- Always reset stateful counters on done=True so the next episode starts fresh.