Batched Environment
Base class for any environment that produces N independent agent results per step.
envrax.batched_env.BatchedEnv
Bases: ABC
Base class for an env that produces n_slots independent agent results per step.
Implementations choose their batching strategy — jax.vmap over a single
JaxEnv, a composite multi-agent scene, threaded backends, or any other
approach. All implementations expose the same shape contract:
- observations: `(n_slots, *obs_shape)`
- actions: `(n_slots, *action_shape)`
- rewards: `(n_slots,)`
- dones: `(n_slots,)`

State is an implementation-specific pytree carried opaquely by the caller.
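
As a rough illustration of the shape contract, the sketch below batches a toy single-environment function with `jax.vmap`. The dynamics and the names `toy_reset`, `toy_step`, and `N_SLOTS` are hypothetical and not part of envrax; only the resulting shapes are meant to mirror the contract.

```python
import jax
import jax.numpy as jnp

N_SLOTS = 4  # hypothetical slot count


def toy_reset(key):
    """Single-slot reset: the position vector doubles as the observation."""
    pos = jax.random.uniform(key, (2,), minval=-1.0, maxval=1.0)
    return pos, pos  # (obs, state)


def toy_step(state, action):
    """Single-slot step: move by `action`; reward is negative distance to the origin."""
    new_state = state + action
    dist = jnp.linalg.norm(new_state)
    return new_state, new_state, -dist, dist > 2.0  # (obs, state, reward, done)


# Batch both functions over the leading slot axis.
batched_reset = jax.vmap(toy_reset)
batched_step = jax.vmap(toy_step)

keys = jax.random.split(jax.random.PRNGKey(0), N_SLOTS)
obs, state = batched_reset(keys)
actions = jnp.zeros((N_SLOTS, 2))
obs, state, rewards, dones = batched_step(state, actions)

print(obs.shape, actions.shape, rewards.shape, dones.shape)
```

The caller never inspects `state`; it only passes the value returned by one call into the next, which is what lets other implementations use a different internal layout.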
Source code in envrax/batched_env.py
name
property
Default key used by MultiVecEnv when keys aren't supplied.
Subclasses should override it with something meaningful (e.g. the
wrapped env's class name). When not overridden, it falls back to the
name of the BatchedEnv subclass.
Returns:

| Name | Type | Description |
|---|---|---|
| name | str | Short identifier for this batched env. |
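
The fallback behaviour described for `name` can be pictured with a small stand-in class. `NamedBase`, `VmapPong`, and `WrappedBreakout` are hypothetical names used only for illustration; this sketch does not reproduce envrax's actual implementation.

```python
from abc import ABC


class NamedBase(ABC):
    """Hypothetical stand-in showing a class-name fallback for `name`."""

    @property
    def name(self) -> str:
        # Fall back to the concrete subclass's class name.
        return type(self).__name__


class VmapPong(NamedBase):
    """Does not override `name`, so the class name is used."""


class WrappedBreakout(NamedBase):
    """Overrides `name` with something more meaningful."""

    @property
    def name(self) -> str:
        return "Breakout"


print(VmapPong().name)         # VmapPong
print(WrappedBreakout().name)  # Breakout
```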
single_observation_space
abstractmethod
property
Observation space of a single slot (unbatched).
single_action_space
abstractmethod
property
Action space of a single slot (unbatched).
reset(rng)
abstractmethod
Reset all n_slots independent agents.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Initial observations, shape `(n_slots, *obs_shape)` |
| state | Any | Batched state pytree |
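
One way a vmap-based implementation might satisfy `reset` is to split the incoming key into one key per slot and map a single-slot reset over them. This is a minimal sketch under that assumption; `single_reset`, `N_SLOTS`, and the state layout are hypothetical.

```python
import jax
import jax.numpy as jnp

N_SLOTS = 3  # hypothetical slot count


def single_reset(key):
    """Hypothetical single-slot reset returning (obs, state)."""
    pos = jax.random.normal(key, (2,))
    state = {"pos": pos, "t": jnp.zeros((), dtype=jnp.int32)}
    return pos, state


def reset(rng):
    """Split one key into per-slot keys so that slots start independently."""
    keys = jax.random.split(rng, N_SLOTS)
    return jax.vmap(single_reset)(keys)


obs, state = reset(jax.random.PRNGKey(42))
print(obs.shape)         # (3, 2)
print(state["t"].shape)  # (3,)
```

Because the only source of randomness is `rng`, calling `reset` twice with the same key yields identical observations, which keeps rollouts reproducible.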
step(state, actions)
abstractmethod
Step all n_slots agents independently with per-slot auto-reset.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | Any | Batched state from a previous reset or step | required |
| actions | Array | Actions, shape `(n_slots, *action_shape)` | required |

Returns:

| Name | Type | Description |
|---|---|---|
| obs | Array | Observations after the step, shape `(n_slots, *obs_shape)` |
| new_state | Any | Updated batched state |
| reward | Array | Per-slot rewards, shape `(n_slots,)` |
| done | Array | Per-slot terminal flags, shape `(n_slots,)` |
| info | Dict[str, Any] | Batched info dict |
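
Per-slot auto-reset is often written by computing both the stepped state and a freshly reset state, then selecting between them leaf by leaf with `jnp.where`. The sketch below shows that pattern on a toy counter environment; `MAX_T`, `single_reset`, `single_step`, and the state layout are hypothetical, and an actual implementation may handle auto-reset differently.

```python
import jax
import jax.numpy as jnp

MAX_T = 3  # hypothetical episode length


def single_reset(key):
    state = {"t": jnp.zeros((), jnp.int32), "key": key}
    return jnp.zeros((1,)), state  # (obs, state)


def single_step(state, action):
    """Advance one slot; if the episode ends, swap in a freshly reset state."""
    key, reset_key = jax.random.split(state["key"])
    t = state["t"] + 1
    done = t >= MAX_T
    reward = action.astype(jnp.float32)
    stepped_obs = t.astype(jnp.float32).reshape(1)
    stepped = {"t": t, "key": key}

    reset_obs, reset_state = single_reset(reset_key)
    # Select per leaf: reset values where done, stepped values otherwise.
    new_state = jax.tree_util.tree_map(
        lambda r, s: jnp.where(done, r, s), reset_state, stepped
    )
    obs = jnp.where(done, reset_obs, stepped_obs)
    return obs, new_state, reward, done, {"t": t}


step = jax.jit(jax.vmap(single_step))

keys = jax.random.split(jax.random.PRNGKey(0), 2)
_, state = jax.vmap(single_reset)(keys)
state["t"] = jnp.array([0, 2], dtype=jnp.int32)  # stagger the slots for illustration
obs, state, reward, done, info = step(state, jnp.array([1, 0]))
print(done)        # slot 1 finished, slot 0 did not
print(state["t"])  # slot 1 is back at t=0, slot 0 is at t=1
```

Only the slot whose episode ended is reset; the other slot keeps its progress, so the caller never has to call `reset` mid-rollout.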
slot_state(state, slot_idx)
abstractmethod
Extract the single-slot state pytree for one agent.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | Any | Batched state | required |
| slot_idx | int | Slot index in | required |

Returns:

| Name | Type | Description |
|---|---|---|
| single_state | Any | Pytree of the same structure as |
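
When every leaf of the batched state carries the slot axis as its leading dimension (an assumption that holds for `jax.vmap`-style batching but not necessarily for composite implementations), extracting one slot reduces to indexing each leaf. A minimal sketch with a hypothetical state layout:

```python
import jax
import jax.numpy as jnp


def slot_state(state, slot_idx):
    """Index the leading (slot) axis of every leaf in the batched pytree."""
    return jax.tree_util.tree_map(lambda leaf: leaf[slot_idx], state)


batched = {
    "pos": jnp.arange(8.0).reshape(4, 2),  # 4 slots, 2-D position each
    "t": jnp.array([0, 5, 2, 7]),
}
single = slot_state(batched, 1)
print(single["pos"])  # [2. 3.]
print(single["t"])    # 5
```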
render_slot(state, slot_idx)
abstractmethod
Render a single slot from the batched state as an RGB frame.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | Any | Batched state | required |
| slot_idx | int | Slot index in | required |

Returns:

| Name | Type | Description |
|---|---|---|
| frame | ndarray | uint8 RGB array of shape |
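
To make the return value concrete, here is a hypothetical renderer that turns one slot's grayscale grid into a uint8 RGB frame. The `grid` field and the height-width-channel layout `(H, W, 3)` are assumptions made for this sketch; the exact frame shape expected by envrax is not stated above.

```python
import numpy as np


def render_slot(state, slot_idx):
    """Hypothetical renderer: one slot's grayscale grid becomes an RGB frame."""
    grid = np.asarray(state["grid"][slot_idx])  # (H, W), floats in [0, 1]
    gray = (np.clip(grid, 0.0, 1.0) * 255).astype(np.uint8)
    return np.stack([gray, gray, gray], axis=-1)  # assumed layout (H, W, 3)


# Two slots, each holding a 4x5 grid.
state = {"grid": np.linspace(0.0, 1.0, 2 * 4 * 5).reshape(2, 4, 5)}
frame = render_slot(state, 0)
print(frame.shape, frame.dtype)  # (4, 5, 3) uint8
```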
compile(cache_dir=None)
abstractmethod
Trigger XLA compilation by running dummy reset + step calls.
Implementations should also warm any conditional branches (e.g. the reset path of auto-reset logic) so the persistent cache covers them.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| cache_dir | Path \| str \| None | XLA cache directory. Implementations may default to a stable project-relative path. | None |
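
A warm-up of this kind might look like the sketch below: point JAX's persistent compilation cache at a directory, then run the jitted reset and step once with dummy inputs of the real shapes and dtypes. The toy `reset` and `step`, `N_SLOTS`, the function name `compile_env`, and the cache path are hypothetical; the `jax_compilation_cache_dir` option assumes a reasonably recent JAX release.

```python
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp

N_SLOTS = 2  # hypothetical slot count


@jax.jit
def reset(rng):
    keys = jax.random.split(rng, N_SLOTS)
    return jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)


@jax.jit
def step(state, actions):
    return state + actions[:, None]


def compile_env(cache_dir=None):
    """Warm the jitted reset/step with dummy inputs of the real shapes and dtypes."""
    if cache_dir is not None:
        Path(cache_dir).mkdir(parents=True, exist_ok=True)
        # Point JAX's persistent compilation cache at the directory.
        jax.config.update("jax_compilation_cache_dir", str(cache_dir))
    state = reset(jax.random.PRNGKey(0))
    new_state = step(state, jnp.zeros((N_SLOTS,)))
    # Block so compilation and execution have actually finished.
    new_state.block_until_ready()
    return new_state


out = compile_env(Path(tempfile.gettempdir()) / "envrax_sketch_cache")
print(out.shape)  # (2, 3)
```

Since `jax.jit` specializes on input shapes and dtypes, the dummy inputs need to match what training will pass; otherwise the first real call triggers a fresh compilation.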