Pass-through Wrappers
Stateless wrappers that transform observations, actions, or rewards without carrying any state between steps.
envrax.wrappers.jit_wrapper.JitWrapper

Bases: Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Wrap a JaxEnv so that reset and step are compiled with jax.jit on construction.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Environment to wrap. | required |
| cache_dir | Path \| str \| None | Directory for the persistent XLA compilation cache. | required |
| pre_warm | bool | Run a dummy reset + step on construction to trigger compilation. | required |
Source code in envrax/wrappers/jit_wrapper.py
compile()

Trigger XLA compilation by running a dummy reset + step.

Safe to call multiple times; subsequent calls are near-instant because JAX caches the compiled kernels in memory.
Source code in envrax/wrappers/jit_wrapper.py
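As a sketch of what pre-warming buys you (the step function and shapes below are stand-ins, not the envrax API): the first jitted call with a given shape/dtype triggers XLA compilation, and any later call with the same shapes reuses the cached kernel.

```python
import jax
import jax.numpy as jnp

# A stand-in step function; envrax compiles the real reset/step the same way.
def step(state, action):
    return state + action, state * 0.99

step_jit = jax.jit(step)  # compilation is lazy: nothing happens yet

# Pre-warm: the first call with a given shape/dtype triggers XLA compilation...
dummy_state, dummy_action = jnp.zeros(4), jnp.ones(4)
step_jit(dummy_state, dummy_action)

# ...so subsequent calls with the same shapes hit the in-memory kernel cache.
new_state, reward = step_jit(dummy_state, dummy_action)
```

Pre-warming on construction moves the one-off compilation cost out of the training loop, which matters when the first step would otherwise stall a profiled or timed run.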
envrax.wrappers.clip_reward.ClipReward

Bases: Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Clip rewards to their sign: {−1, 0, +1}.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Inner environment to wrap. | required |
Source code in envrax/wrappers/clip_reward.py
step(state, action)

Advance the environment by one step and clip the reward to {−1, 0, +1}.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateT | Current environment state | required |
| action | Array | Action to take in the environment | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Observation from the inner step |
| new_state | StateT | Updated environment state |
| reward | Array | Reward clipped to sign: {−1, 0, +1} |
| done | Array | Terminal flag from the inner step |
| info | Dict[str, Any] | Info dict from the inner step |
Source code in envrax/wrappers/clip_reward.py
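The clipping itself is just the sign function; a minimal numpy sketch (the wrapper presumably applies the same operation with jax.numpy):

```python
import numpy as np

def clip_reward(reward):
    # Map any reward to its sign: negative -> -1, zero -> 0, positive -> +1.
    return np.sign(reward)

clip_reward(np.array([-3.5, 0.0, 0.7, 120.0]))  # -> [-1., 0., 1., 1.]
```

Clipping rewards to their sign is the classic DQN trick for keeping gradient magnitudes comparable across environments with very different score scales.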
envrax.wrappers.discount.EpisodeDiscount

Bases: Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Convert the boolean done signal to a float32 episode discount.

The 4th return value of step() changes from bool to float32: 1.0 while the episode is running, 0.0 on termination.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Inner environment to wrap. | required |
Source code in envrax/wrappers/discount.py
step(state, action)

Advance the environment and return a float32 discount instead of done.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateT | Current environment state | required |
| action | Array | Action to take in the environment | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Observation from the inner step |
| new_state | StateT | Updated environment state |
| reward | Array | Reward from the inner step (unchanged) |
| discount | Array | 1.0 while the episode is running, 0.0 on termination |
| info | Dict[str, Any] | Info dict from the inner step |
Source code in envrax/wrappers/discount.py
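The conversion is a one-liner; a numpy sketch, assuming done is a boolean array:

```python
import numpy as np

def to_discount(done):
    # done=False (episode running) -> 1.0; done=True (terminated) -> 0.0
    return (1.0 - done).astype(np.float32)

to_discount(np.array([False, False, True]))  # -> [1., 1., 0.]
```

A float discount slots directly into bootstrapped targets such as r + γ · discount · V(s'), zeroing the bootstrap term at episode boundaries without a branch.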
envrax.wrappers.expand_dims.ExpandDims

Bases: Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Add a trailing size-1 dimension to reward and done.

Transforms scalar outputs from step() so that reward and done have shape (..., 1) instead of (...).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Inner environment to wrap. | required |
Source code in envrax/wrappers/expand_dims.py
step(state, action)

Advance the environment and expand reward and done.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateT | Current environment state | required |
| action | Array | Action to take in the environment | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Observation after the step (unchanged) |
| new_state | StateT | Updated environment state |
| reward | Array | Reward with a trailing size-1 dimension |
| done | Array | Terminal flag with a trailing size-1 dimension |
| info | Dict[str, Any] | Auxiliary info dict (unchanged) |
Source code in envrax/wrappers/expand_dims.py
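The reshape can be sketched with trailing-axis indexing in numpy (assumed equivalent to what the wrapper does with jax.numpy):

```python
import numpy as np

def expand(reward, done):
    # Append a trailing size-1 axis: () -> (1,), (B,) -> (B, 1).
    return reward[..., None], done[..., None]

r, d = expand(np.zeros(8, np.float32), np.zeros(8, bool))
r.shape, d.shape  # -> (8, 1), (8, 1)
```

The trailing axis makes reward and done broadcast cleanly against (B, 1)-shaped value predictions, avoiding the silent (B,) vs (B, 1) broadcasting bugs common in RL losses.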
envrax.wrappers.grayscale.GrayscaleObservation

Bases: Wrapper[Box, ActSpaceT, StateT, ConfigT]

Convert RGB observations to grayscale using the NTSC luminance formula.

Wraps any environment whose reset / step return uint8[H, W, 3] observations and converts them to uint8[H, W].
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Inner environment to wrap. Must have a uint8[H, W, 3] observation space. | required |
Source code in envrax/wrappers/grayscale.py
reset(rng)

Reset the inner environment and convert the observation to grayscale.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Grayscale observation |
| state | StateT | Inner environment state |
Source code in envrax/wrappers/grayscale.py
step(state, action)

Step the inner environment and convert the observation to grayscale.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateT | Current environment state | required |
| action | Array | Action to take in the environment | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Grayscale observation |
| new_state | StateT | Updated environment state |
| reward | Array | Reward from the inner step |
| done | Array | Terminal flag from the inner step |
| info | Dict[str, Any] | Info dict from the inner step |
Source code in envrax/wrappers/grayscale.py
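The NTSC (ITU-R BT.601) luminance formula is Y = 0.299·R + 0.587·G + 0.114·B. A numpy sketch of the channel reduction (envrax's exact rounding behaviour is not shown in this page and may differ):

```python
import numpy as np

# NTSC / ITU-R BT.601 luminance weights for R, G, B
NTSC_WEIGHTS = np.array([0.299, 0.587, 0.114])

def to_grayscale(rgb):
    # uint8[H, W, 3] -> uint8[H, W] via a weighted sum over the channel axis.
    return np.round(rgb.astype(np.float32) @ NTSC_WEIGHTS).astype(np.uint8)

to_grayscale(np.full((2, 2, 3), 255, np.uint8))  # pure white -> all 255
```

The weights sum to 1, so the output stays in the uint8 range and a uniform gray maps to itself.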
envrax.wrappers.normalize_obs.NormalizeObservation

Bases: Wrapper[Box, ActSpaceT, StateT, ConfigT]

Normalise pixel observations from uint8 [0, 255] to float32 [0, 1].

Divides observations by 255.0 and casts to float32.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Environment to wrap. Must have a uint8 observation space. | required |
Source code in envrax/wrappers/normalize_obs.py
reset(rng)

Reset and return a normalised initial observation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Normalised observation in [0, 1] |
| state | StateT | Inner environment state |
Source code in envrax/wrappers/normalize_obs.py
step(state, action)

Step and return a normalised observation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateT | Current environment state | required |
| action | Array | Action to take in the environment | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Normalised observation in [0, 1] |
| new_state | StateT | Updated environment state |
| reward | Array | Step reward |
| done | Array | Terminal flag |
| info | Dict[str, Any] | Environment metadata |
Source code in envrax/wrappers/normalize_obs.py
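The transformation is a divide-and-cast; a numpy sketch:

```python
import numpy as np

def normalize(obs):
    # uint8 in [0, 255] -> float32 in [0, 1]
    return (obs / 255.0).astype(np.float32)

normalize(np.array([0, 128, 255], np.uint8))  # -> [0., ~0.502, 1.]
```

Feeding networks values in [0, 1] rather than raw uint8 keeps first-layer activations in a range where standard weight initialisations behave well.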
envrax.wrappers.resize.ResizeObservation

Bases: Wrapper[Box, ActSpaceT, StateT, ConfigT]

Resize observations to (h, w) using bilinear interpolation.

Handles both:

- Grayscale: uint8[H, W] → uint8[h, w]
- RGB: uint8[H, W, C] → uint8[h, w, C]

The channel dimension is preserved automatically; no pre-processing step is required. For DQN-style pipelines, apply GrayscaleObservation first so the output is uint8[h, w] before stacking.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Inner environment returning uint8[H, W] or uint8[H, W, C] observations. | required |
| h | int | Output height in pixels. | required |
| w | int | Output width in pixels. | required |
Source code in envrax/wrappers/resize.py
reset(rng)

Reset the inner environment and resize the observation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Resized observation |
| state | StateT | Inner environment state |
Source code in envrax/wrappers/resize.py
step(state, action)

Step the inner environment and resize the observation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateT | Current environment state | required |
| action | Array | Action to take in the environment | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Resized observation |
| new_state | StateT | Updated environment state |
| reward | Array | Reward from the inner step |
| done | Array | Terminal flag from the inner step |
| info | Dict[str, Any] | Info dict from the inner step |
Source code in envrax/wrappers/resize.py
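In JAX the natural primitive for this is jax.image.resize(obs, (h, w), method="bilinear"); whether envrax calls it directly is an assumption. The sampling itself can be sketched in numpy:

```python
import numpy as np

def bilinear_resize(img, h, w):
    # img: uint8[H, W] or uint8[H, W, C] -> uint8[h, w(, C)]
    H, W = img.shape[:2]
    # Sample source coordinates at output pixel centres.
    ys = (np.arange(h) + 0.5) * H / h - 0.5
    xs = (np.arange(w) + 0.5) * W / w - 0.5
    y0 = np.clip(np.floor(ys), 0, H - 1).astype(int)
    x0 = np.clip(np.floor(xs), 0, W - 1).astype(int)
    y1 = np.clip(y0 + 1, 0, H - 1)
    x1 = np.clip(x0 + 1, 0, W - 1)
    # Fractional weights, clamped at the image border.
    wy = np.clip(ys - y0, 0.0, 1.0)[:, None]
    wx = np.clip(xs - x0, 0.0, 1.0)[None, :]
    img = img.astype(np.float32)
    if img.ndim == 3:  # broadcast the weights over the channel axis
        wy, wx = wy[..., None], wx[..., None]
    # Interpolate horizontally on the two bracketing rows, then vertically.
    top = img[y0[:, None], x0] * (1 - wx) + img[y0[:, None], x1] * wx
    bot = img[y1[:, None], x0] * (1 - wx) + img[y1[:, None], x1] * wx
    return np.round(top * (1 - wy) + bot * wy).astype(np.uint8)

small = bilinear_resize(np.full((210, 160, 3), 50, np.uint8), 84, 84)
small.shape  # -> (84, 84, 3): the channel axis passes through untouched
```

Because each output pixel is a convex combination of four neighbours, a constant image resizes to the same constant, and the channel dimension is carried by broadcasting rather than by a separate code path.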
envrax.wrappers.record_video.RecordVideo

Bases: Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Save episode frames to MP4 based on configurable triggers.

Not JIT/vmap-compatible. Intended for evaluation, logging, and training visualisation.

Three optional triggers control when recording is active. They are OR'd together: if any trigger returns True, that episode is recorded. When no triggers are provided, every episode is recorded.

Each completed recording is written to <output_dir>/episode_<NNNN>.mp4 via imageio.

Requires imageio with the ffmpeg plugin (pip install "imageio[ffmpeg]").
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Inner environment to wrap that has a render() method. | required |
| output_dir | str \| Path | Directory where MP4 files are saved. Created automatically if it does not exist. | required |
| fps | int | Frames per second for the saved video. | required |
| episode_trigger | Callable[[int], bool] | Called with the episode count at each reset(); the episode is recorded if it returns True. | required |
| step_trigger | Callable[[int], bool] | Called with the global step count at each step(); recording starts mid-episode when it fires. | required |
| recording_trigger | Callable[[], bool] | Zero-arg callable checked at each step(). | required |
Raises:

| Name | Type | Description |
|---|---|---|
| render_missing | TypeError | If the unwrapped environment does not implement render() |
Source code in envrax/wrappers/record_video.py
recording property

Whether the current episode is being recorded.
reset(rng)

Reset the environment and optionally begin a new recording.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | First observation |
| state | StateT | Initial environment state |
Source code in envrax/wrappers/record_video.py
step(state, action)

Advance the environment by one step and record the frame if active.

If a step_trigger is provided and fires, recording starts mid-episode and continues until the episode ends. Accumulated frames are flushed to an MP4 file when done is True.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateT | Current environment state | required |
| action | Array | Action to take in the environment | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Observation after the step |
| new_state | StateT | Updated environment state |
| reward | Array | Reward for this step |
| done | Array | Terminal flag; accumulated frames are flushed to MP4 when True |
| info | Dict[str, Any] | Pass-through info dict from the inner environment |
Source code in envrax/wrappers/record_video.py
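The OR semantics described above can be sketched as a plain-Python predicate. should_record is a hypothetical helper written for illustration, not part of the envrax API:

```python
def should_record(episode, step, episode_trigger=None, step_trigger=None,
                  recording_trigger=None):
    # No triggers supplied -> record every episode.
    if episode_trigger is None and step_trigger is None and recording_trigger is None:
        return True
    # Otherwise OR the provided triggers together.
    return any([
        episode_trigger is not None and episode_trigger(episode),
        step_trigger is not None and step_trigger(step),
        recording_trigger is not None and recording_trigger(),
    ])

# Record every 100th episode:
should_record(200, 0, episode_trigger=lambda ep: ep % 100 == 0)  # -> True
```

A typical setup passes only episode_trigger (e.g. lambda ep: ep % 100 == 0) for periodic evaluation videos, keeping disk usage bounded while still sampling training progress.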