Vectorised
A batched wrapper that runs many copies of a single environment in parallel on one accelerator via jax.vmap.
envrax.vec_env.VecEnv
Bases: Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]
Wraps any JaxEnv to operate over a batch of environments simultaneously.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Single-instance environment to vectorise | required |
| num_envs | int | Number of parallel environments | required |
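The core idea can be sketched in plain JAX. You write reset and step for one environment as pure functions, then let jax.vmap map them over a leading batch axis of size num_envs. This is a minimal sketch under stated assumptions: the toy random-walk functions single_reset and single_step are hypothetical and are not envrax's JaxEnv interface, and the real VecEnv may be structured differently.

```python
import jax
import jax.numpy as jnp


# Hypothetical single-instance environment written as pure functions
# (a 1-D random walk). This is NOT envrax's JaxEnv interface.
def single_reset(rng):
    pos = jax.random.uniform(rng, (), minval=-0.5, maxval=0.5)
    obs = pos[None]  # obs_shape == (1,)
    return obs, pos  # (obs, state)


def single_step(state, action):
    # action 1 moves right, anything else moves left
    pos = state + jnp.where(action == 1, 0.1, -0.1)
    reward = -jnp.abs(pos)
    done = jnp.abs(pos) > 1.0
    return pos[None], pos, reward, done


# jax.vmap maps each function over the leading axis of every input and
# stacks every output along a new leading axis of size num_envs.
batched_reset = jax.jit(jax.vmap(single_reset))
batched_step = jax.jit(jax.vmap(single_step))

num_envs = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)  # one key per env
obs, states = batched_reset(keys)
actions = jnp.array([0, 1, 1, 0])
obs, states, rewards, dones = batched_step(states, actions)
print(obs.shape, states.shape, rewards.shape, dones.shape)
# (4, 1) (4,) (4,) (4,)
```

Because the pure single-env functions are batched by vmap rather than looped over in Python, jit compiles the whole batch as one XLA computation.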
Source code in envrax/vec_env.py
config (property)
Single environment configuration.
single_observation_space (property)
Observation space of a single inner environment.
single_action_space (property)
Action space of a single inner environment.
observation_space (property)
Batched observation space with a leading num_envs dimension.
action_space (property)
Batched action space with a leading num_envs dimension.
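The relationship between the single and batched spaces is a shape rule: the batched space is the single space with num_envs prepended. The sketch below uses a hypothetical minimal Box class. envrax's actual space classes, and how VecEnv builds the batched versions, are not shown on this page, so treat this as an illustration of the shape rule only.

```python
from dataclasses import dataclass
from typing import Tuple

import jax.numpy as jnp


@dataclass(frozen=True)
class Box:
    """Hypothetical minimal space (not envrax's): bounds of shape `shape`."""
    low: jnp.ndarray
    high: jnp.ndarray

    @property
    def shape(self) -> Tuple[int, ...]:
        return tuple(self.low.shape)


def batch_box(space: Box, num_envs: int) -> Box:
    """Prepend a leading num_envs dimension by broadcasting the bounds."""
    new_shape = (num_envs, *space.shape)
    return Box(
        low=jnp.broadcast_to(space.low, new_shape),
        high=jnp.broadcast_to(space.high, new_shape),
    )


single = Box(low=jnp.zeros((84, 84, 3)), high=jnp.full((84, 84, 3), 255.0))
batched = batch_box(single, num_envs=8)
print(single.shape, batched.shape)  # (84, 84, 3) (8, 84, 84, 3)
```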
reset(rng)
Reset all num_envs environments with independent random starts.
All returned arrays have a leading batch dimension B = num_envs, e.g. observations of shape (B, *obs_shape).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |
Returns:
| Name | Type | Description |
|---|---|---|
| obs | Array | Stacked first observations |
| states | EnvState | Batched environment states |
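"Independent random starts" from a single rng key usually means splitting it into one key per environment before mapping the reset. A minimal sketch, assuming a hypothetical single_reset (not envrax's API) and that the key is split with jax.random.split. It also shows why splitting matters: handing every environment the same key produces identical copies.

```python
import jax
import jax.numpy as jnp


# Hypothetical single-env reset (not envrax's API): random start position.
def single_reset(rng):
    pos = jax.random.uniform(rng, (2,))  # obs_shape == (2,)
    return pos, {"pos": pos, "t": jnp.zeros((), jnp.int32)}


def batched_reset(rng, num_envs):
    # One independent key per environment, then vmap over the key axis.
    keys = jax.random.split(rng, num_envs)
    return jax.vmap(single_reset)(keys)


obs, states = batched_reset(jax.random.PRNGKey(42), num_envs=5)
print(obs.shape, states["t"].shape)  # (5, 2) (5,)

# Contrast: repeating one shared key gives every env the same start.
same_keys = jnp.stack([jax.random.PRNGKey(42)] * 5)
same_obs, _ = jax.vmap(single_reset)(same_keys)
print(bool(jnp.all(same_obs == same_obs[0])))  # True
```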
step(state, actions)
Advance all environments by one step simultaneously.
Each environment independently auto-resets when its episode ends.
All inputs and outputs have a leading batch dimension B = num_envs:
- Observations: (B, *obs_shape)
- Discrete actions: (B,), one int per env
- Continuous actions: (B, *action_shape), one vector per env
- Rewards / dones: (B,)
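These action shapes follow from how jax.vmap slices its inputs: it strips the leading B axis, so a discrete single-env step sees a scalar int while a continuous one sees a vector of shape action_shape. A minimal sketch, assuming hypothetical single-env step functions discrete_step and continuous_step (not part of envrax) and an action_shape of (2,) chosen for illustration.

```python
import jax
import jax.numpy as jnp

B = 4  # num_envs


# Hypothetical single-env step functions (not envrax's API); the state is a
# scalar position in both cases.
def discrete_step(pos, action):
    # action: scalar int, 0 = move left, 1 = move right
    delta = jnp.array([-1.0, 1.0])[action]
    return pos + delta


def continuous_step(pos, action):
    # action: vector of shape action_shape == (2,), clipped to [-1, 1]
    return pos + jnp.clip(action, -1.0, 1.0).sum()


pos = jnp.zeros((B,))
disc_actions = jnp.array([0, 1, 1, 0])  # shape (B,): one int per env
cont_actions = jnp.full((B, 2), 0.75)   # shape (B, 2): one vector per env

disc_pos = jax.vmap(discrete_step)(pos, disc_actions)
cont_pos = jax.vmap(continuous_step)(pos, cont_actions)
print(disc_pos.shape, cont_pos.shape)  # (4,) (4,)
```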
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Batched environment states | required |
| actions | Array | One action per environment | required |
Returns:
| Name | Type | Description |
|---|---|---|
| obs | Array | Observations after the step |
| new_states | EnvState | Updated batched states |
| rewards | Array | Per-environment rewards |
| dones | Array | Per-environment terminal flags |
| infos | Dict[str, Any] | Batched info dict |
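Per-environment auto-reset inside a vmapped step cannot use a Python if, because under vmap the done flag is a traced per-env value. One common pattern is to always compute a fresh reset and select it with jnp.where wherever the episode ended. The sketch below assumes a hypothetical clock environment (single_reset, single_step, not envrax's API). It returns the new episode's first observation when done is True; this page does not state what envrax returns in that case.

```python
import jax
import jax.numpy as jnp


# Hypothetical single env (not envrax's API): a clock that ends the episode
# after 3 steps. The state carries a PRNG key for future resets.
def single_reset(rng):
    state = {"t": jnp.zeros((), jnp.int32), "rng": rng}
    return jnp.zeros(()), state  # (obs, state)


def single_step(state, action):
    t = state["t"] + 1
    done = t >= 3
    reward = jnp.where(action == 1, 1.0, 0.0)
    stepped_obs = t.astype(jnp.float32)
    rng, reset_key = jax.random.split(state["rng"])
    stepped_state = {"t": t, "rng": rng}

    # Always compute a fresh reset, then select it leaf by leaf wherever
    # this env's episode has just ended.
    reset_obs, reset_state = single_reset(reset_key)
    new_state = jax.tree_util.tree_map(
        lambda r, s: jnp.where(done, r, s), reset_state, stepped_state
    )
    obs = jnp.where(done, reset_obs, stepped_obs)
    return obs, new_state, reward, done


keys = jax.random.split(jax.random.PRNGKey(0), 2)
obs, states = jax.vmap(single_reset)(keys)
# Desynchronise the envs: env 1 starts one tick ahead of env 0.
states = {"t": jnp.array([0, 1], jnp.int32), "rng": states["rng"]}

step = jax.jit(jax.vmap(single_step))
actions = jnp.array([1, 0])
for _ in range(2):
    obs, states, rewards, dones = step(states, actions)
# env 1 reached t == 3 on the second step and was reset to t == 0,
# while env 0 is still mid-episode at t == 2.
```

Computing the reset unconditionally costs some extra work on every step, but it keeps the control flow static, which is what vmap and jit require.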
render(state, *, index=0)
Render a single environment from the batch.
Extracts the state at index from the batched state pytree and delegates to the inner env's render().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Batched environment state | required |
| index | int | Which environment in the batch to render. Default is 0. | 0 |
Returns:
| Name | Type | Description |
|---|---|---|
| frame | ndarray | uint8 RGB array of shape |
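Extracting one environment from a batched state pytree amounts to indexing every leaf at the same position along the batch axis, which jax.tree_util.tree_map does in one call. A minimal sketch, assuming a hypothetical take_env helper and a hypothetical single_render; neither is envrax's API, and the 8x8 frame size is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def take_env(batched_state, index):
    """Hypothetical helper: slice one environment out of a batched pytree."""
    return jax.tree_util.tree_map(lambda leaf: leaf[index], batched_state)


# Hypothetical single-env renderer (not envrax's API): paints an 8x8 frame
# whose brightness reflects the agent position in [0, 1].
def single_render(state):
    level = int(np.clip(float(state["pos"]) * 255, 0, 255))
    return np.full((8, 8, 3), level, dtype=np.uint8)


batched_state = {
    "pos": jnp.array([0.0, 0.5, 1.0]),
    "t": jnp.array([3, 7, 1], dtype=jnp.int32),
}
one = take_env(batched_state, index=1)
frame = single_render(one)
print(int(one["t"]), frame.shape, frame.dtype)  # 7 (8, 8, 3) uint8
```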
compile(cache_dir=DEFAULT_CACHE_DIR)
Trigger XLA compilation by running a dummy reset and step.
Useful when construction and compilation should happen as separate phases. Safe to call multiple times.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cache_dir | Path \| str \| None | XLA cache directory. Defaults to DEFAULT_CACHE_DIR. | DEFAULT_CACHE_DIR |
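The warm-up idea can be sketched in plain JAX. The first call to a jitted function with given input shapes and dtypes traces and compiles it. Later calls with the same shapes and dtypes reuse the compiled executable, which is why repeating the warm-up is cheap. This is a minimal sketch, assuming hypothetical toy functions batched_reset, batched_step and warm_up (not envrax's implementation). The cache_dir handling is omitted because this page does not describe it.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical batched reset/step with toy dynamics (not envrax's).
@jax.jit
def batched_reset(keys):
    return jax.vmap(lambda k: jax.random.normal(k, ()))(keys)


@jax.jit
def batched_step(states, actions):
    return states + actions


def warm_up(num_envs):
    """Compile reset and step by running them once on dummy inputs."""
    keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
    states = batched_reset(keys)
    states = batched_step(states, jnp.zeros((num_envs,)))
    states.block_until_ready()  # wait for asynchronous dispatch to finish
    return states


def timed(fn):
    t0 = time.perf_counter()
    fn()
    return time.perf_counter() - t0


first = timed(lambda: warm_up(256))   # traces and compiles
second = timed(lambda: warm_up(256))  # same shapes/dtypes: cached executable
print(f"first {first:.4f}s, second {second:.4f}s")
```

Note that calling warm_up with a different num_envs would change the input shapes and trigger a fresh compilation.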