Vectorised
A batched wrapper that runs many copies of a single environment in parallel on one accelerator via jax.vmap.
envrax.vec_env.VecEnv
Bases: BatchedEnv, Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]
Wraps any JaxEnv to operate over a batch of environments simultaneously.
Canonical BatchedEnv implementation: num_envs independent copies of
one JaxEnv stepped in parallel via jax.vmap, with per-slot auto-reset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Single-instance environment to vectorise | required |
| num_envs | int | Number of parallel environments ( | required |
Source code in envrax/vec_env.py
n_slots
property
Number of parallel slots (= num_envs).
name
property
Inner environment's name. Used as the default key by MultiVecEnv.
Picks up wrapper delegation, so VecEnv(JitWrapper(BallEnv()))
still reports "BallEnv" rather than "JitWrapper".
Returns:
| Name | Type | Description |
|---|---|---|
| name | str | The wrapped environment's |
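The delegation described above can be illustrated with a minimal sketch. The classes below are hypothetical stand-ins that only borrow the names from the example; they are not the envrax implementations, and forwarding via `__getattr__` is an assumption about how a wrapper might delegate.

```python
class BallEnv:
    """Stand-in environment (hypothetical): reports its own class name."""

    @property
    def name(self):
        return type(self).__name__


class JitWrapper:
    """Stand-in wrapper (hypothetical): forwards unknown attributes inward."""

    def __init__(self, env):
        self.env = env

    def __getattr__(self, attr):
        # Called only when normal lookup fails, so `name` falls through
        # to the wrapped environment.
        return getattr(self.env, attr)


class ToyVecEnv:
    """Stand-in batched wrapper (hypothetical): exposes the inner name."""

    def __init__(self, env, num_envs):
        self.env = env
        self.num_envs = num_envs

    @property
    def name(self):
        return self.env.name


reported = ToyVecEnv(JitWrapper(BallEnv()), num_envs=4).name
```

Because the wrapper defines no `name` of its own, the lookup resolves on the innermost environment.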
config
property
Single environment configuration.
single_observation_space
property
Observation space of a single inner environment.
single_action_space
property
Action space of a single inner environment.
observation_space
property
Batched observation space with a leading num_envs dimension.
action_space
property
Batched action space with a leading num_envs dimension.
reset(rng)
Reset all num_envs environments with independent random starts.
All returned arrays have a leading batch dimension B = num_envs,
e.g. observations of shape (B, *obs_shape).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |
Returns:
| Name | Type | Description |
|---|---|---|
| obs | Array | Stacked first observations |
| states | EnvState | Batched environment states |
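A batched reset with independent random starts can be sketched in plain JAX, assuming one PRNG key is split into one key per slot and a single-environment reset is mapped over them. `toy_reset` and `NUM_ENVS` are hypothetical names used for illustration; this is not the envrax source.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 4  # hypothetical batch size B


def toy_reset(rng):
    """Hypothetical single-environment reset returning (obs, state)."""
    pos = jax.random.uniform(rng, (2,), minval=-1.0, maxval=1.0)
    state = {"pos": pos, "t": jnp.zeros((), dtype=jnp.int32)}
    return pos, state


rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, NUM_ENVS)   # one independent key per slot
obs, states = jax.vmap(toy_reset)(keys)  # every leaf gains a leading B axis
```

Here `obs` has shape `(NUM_ENVS, 2)` and every leaf of `states` carries the same leading axis.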
Source code in envrax/vec_env.py
step(state, actions)
Advance all environments by one step simultaneously.
Each environment independently auto-resets when its episode ends.
All inputs and outputs have a leading batch dimension B = num_envs:
- Observations: (B, *obs_shape)
- Discrete actions: (B,), one int per env
- Continuous actions: (B, *action_shape), one vector per env
- Rewards / dones: (B,)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Batched environment states | required |
| actions | Array | One action per environment | required |
Returns:
| Name | Type | Description |
|---|---|---|
| obs | Array | Observations after the step |
| new_states | EnvState | Updated batched states |
| rewards | Array | Per-environment rewards |
| dones | Array | Per-environment terminal flags |
| infos | Dict[str, Any] | Batched info dict |
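Per-slot auto-reset under `jax.vmap` can be sketched as follows. This is a toy under stated assumptions, not the envrax implementation: `toy_reset`, `toy_step`, `step_with_auto_reset`, and `MAX_STEPS` are hypothetical, the state carries its own PRNG key, and a finished slot returns the first observation of its new episode (the real library may make a different choice). The info dict is omitted for brevity.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 3  # hypothetical episode length


def toy_reset(rng):
    """Hypothetical single-env reset returning (obs, state)."""
    state = {"t": jnp.zeros((), jnp.int32), "rng": rng}
    return jnp.zeros((1,)), state


def toy_step(state, action):
    """Hypothetical single-env step: counts steps, ends at MAX_STEPS."""
    t = state["t"] + 1
    obs = t[None].astype(jnp.float32)
    reward = action.astype(jnp.float32)
    done = t >= MAX_STEPS
    return obs, {"t": t, "rng": state["rng"]}, reward, done


def step_with_auto_reset(state, action):
    obs, new_state, reward, done = toy_step(state, action)
    rng, reset_rng = jax.random.split(new_state["rng"])
    new_state = {**new_state, "rng": rng}
    reset_obs, reset_state = toy_reset(reset_rng)
    # Compute both outcomes and select per leaf, so the function stays
    # free of Python control flow and can be vmapped and jitted.
    new_state = jax.tree_util.tree_map(
        lambda r, n: jnp.where(done, r, n), reset_state, new_state
    )
    obs = jnp.where(done, reset_obs, obs)
    return obs, new_state, reward, done


batched_step = jax.jit(jax.vmap(step_with_auto_reset))

keys = jax.random.split(jax.random.PRNGKey(0), 2)
_, states = jax.vmap(toy_reset)(keys)
# Put slot 1 one step before the end so that only it resets.
states = {**states, "t": jnp.array([0, MAX_STEPS - 1], dtype=jnp.int32)}
actions = jnp.array([1, 0])
obs, states, rewards, dones = batched_step(states, actions)
```

After the call, slot 0 continues its episode while slot 1 has been reset, which is the independence that the description above refers to.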
Source code in envrax/vec_env.py
slot_state(state, slot_idx)
Extract the state pytree for a single slot from the batched state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Batched environment state | required |
| slot_idx | int | Slot index in | required |
Returns:
| Name | Type | Description |
|---|---|---|
| single_state | EnvState | Unbatched state pytree for the chosen slot. |
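Extracting one slot from a batched pytree usually amounts to indexing the leading axis of every leaf. The sketch below assumes a dict-based state and a hypothetical helper `take_slot`; it shows the general technique rather than the envrax source.

```python
import jax
import jax.numpy as jnp

# Hypothetical batched state with B = 4 slots.
batched_state = {
    "pos": jnp.arange(8.0).reshape(4, 2),
    "t": jnp.array([0, 5, 2, 7]),
}


def take_slot(state, slot_idx):
    """Index the leading (batch) axis of every leaf in the pytree."""
    return jax.tree_util.tree_map(lambda leaf: leaf[slot_idx], state)


single = take_slot(batched_state, 1)
```

Each leaf of `single` loses the batch axis, so `single["pos"]` has shape `(2,)` and `single["t"]` is a scalar.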
Source code in envrax/vec_env.py
render_slot(state, slot_idx)
Render a single environment from the batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Batched environment state | required |
| slot_idx | int | Slot index in | required |
Returns:
| Name | Type | Description |
|---|---|---|
| frame | ndarray | uint8 RGB array of shape |
Source code in envrax/vec_env.py
compile(cache_dir=DEFAULT_CACHE_DIR)
Trigger XLA compilation by running dummy reset + step.
Runs once with done=False (typical path) and once with done=True
on the first slot to warm both branches of the auto-reset path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cache_dir | Path \| str \| None | XLA cache directory. Defaults to | DEFAULT_CACHE_DIR |
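The general warm-up idea, running a jitted function once on dummy inputs so that later calls reuse the compiled program, can be sketched as below. `toy_step` and `trace_log` are hypothetical, the persistent cache directory is not modelled, and in this toy both `done` values share a single compiled program; how envrax structures its auto-reset branches may differ.

```python
import jax
import jax.numpy as jnp

trace_log = []  # grows only while JAX traces the function


@jax.jit
def toy_step(x, done):
    trace_log.append("traced")  # Python side effect, runs at trace time only
    return jnp.where(done, jnp.zeros_like(x), x + 1.0)


dummy = jnp.zeros((4, 2))
not_done = jnp.zeros((4, 1), dtype=bool)
all_done = jnp.ones((4, 1), dtype=bool)

# Warm-up: exercise the step once per done value with dummy data.
toy_step(dummy, not_done).block_until_ready()
toy_step(dummy, all_done).block_until_ready()
traces_after_warmup = len(trace_log)

# A later call with the same shapes and dtypes reuses the compiled program.
out = toy_step(jnp.ones((4, 2)), not_done)
```

The trace counter stays constant after the warm-up, which is the effect a compile step is meant to achieve before the training loop starts.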
Source code in envrax/vec_env.py