Skip to content

Multi-Environment

Classes for composing several different environments (and their vectorised variants) under a single unified handle.

MultiEnv holds a dict of JaxEnv instances (keyed by env name, auto-derived when a list is passed) and dispatches reset/step via plain Python iteration, without adding an outer jax.jit boundary. MultiVecEnv holds a dict of BatchedEnv instances (keyed the same way) and dispatches to all of them inside one jax.jit boundary.

envrax.multi_env.MultiEnv

Container for multiple JaxEnv instances keyed by env name.

Holds one inner env per key and dispatches reset/step via a Python loop. No outer jax.jit boundary is added — each inner env keeps its own compile cycle (typically via JitWrapper). Use MultiVecEnv if you need a single jitted dispatch over batched envs.

Accepts either a list (keys derived from each env's name via _auto_key) or a dict (used as-is for explicit control).

Parameters:

Name Type Description Default
envs List[JaxEnv] | Dict[str, JaxEnv]

Envs to wrap. When a list, keys are derived from env.name with suffixes on duplicates. When a dict, keys are used verbatim. Iteration order is preserved.

required
Source code in envrax/multi_env.py
Python
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
class MultiEnv:
    """
    Container for multiple `JaxEnv` instances keyed by env name.

    Holds one inner env per key and dispatches `reset`/`step` via a Python
    loop. No outer `jax.jit` boundary is added — each inner env keeps its
    own compile cycle (typically via `JitWrapper`). Use `MultiVecEnv` if
    you need a single jitted dispatch over batched envs.

    Accepts either a list (keys derived from each env's `name` via
    `_auto_key`) or a dict (used as-is for explicit control).

    Parameters
    ----------
    envs : List[JaxEnv] | Dict[str, JaxEnv]
        Envs to wrap. When a list, keys are derived from `env.name` with
        suffixes on duplicates. When a dict, keys are used verbatim.
        Iteration order is preserved.

    Raises
    ------
    no_envs : ValueError
        If `envs` is empty (including an iterable that yields nothing).
    """

    def __init__(
        self,
        envs: List[JaxEnv] | Dict[str, JaxEnv],
    ) -> None:
        if not envs:
            raise ValueError("MultiEnv requires at least one environment.")

        if isinstance(envs, dict):
            envs_dict = dict(envs)
        else:
            env_list = list(envs)
            # A lazy iterable (e.g. a generator) is always truthy, so the
            # check above cannot see that it is empty; re-check once it has
            # been materialised so an empty container is never built.
            if not env_list:
                raise ValueError("MultiEnv requires at least one environment.")
            envs_dict = _auto_key(env_list)

        self._envs: Dict[str, JaxEnv] = envs_dict
        # Key order is frozen here and drives every reset/step loop.
        self._keys: List[str] = list(self._envs.keys())

    @property
    def envs(self) -> Dict[str, JaxEnv]:
        """The inner `JaxEnv` instances keyed by env name."""
        return self._envs

    @property
    def env_keys(self) -> List[str]:
        """Ordered list of env-type keys (a fresh copy on every access)."""
        return list(self._keys)

    @property
    def n_envs(self) -> int:
        """Number of environments (`M`)."""
        return len(self._envs)

    @property
    def observation_spaces(self) -> Dict[str, Space]:
        """Per-env observation spaces."""
        return {k: e.observation_space for k, e in self._envs.items()}

    @property
    def action_spaces(self) -> Dict[str, Space]:
        """Per-env action spaces."""
        return {k: e.action_space for k, e in self._envs.items()}

    @property
    def observation_shapes(self) -> Dict[str, Tuple[int, ...]]:
        """Per-env observation shapes."""
        return {k: s.shape for k, s in self.observation_spaces.items()}

    @property
    def action_shapes(self) -> Dict[str, Tuple[int, ...]]:
        """Per-env action shapes."""
        return {k: s.shape for k, s in self.action_spaces.items()}

    @property
    def observation_sizes(self) -> Dict[str, int]:
        """Per-env flat observation element counts (`prod(shape)`)."""
        return {k: int(prod(s.shape)) for k, s in self.observation_spaces.items()}

    @property
    def action_sizes(self) -> Dict[str, int]:
        """Per-env flat action element counts (`prod(shape)`)."""
        return {k: int(prod(s.shape)) for k, s in self.action_spaces.items()}

    @property
    def observation_dtypes(self) -> Dict[str, Type]:
        """Per-env observation dtypes."""
        return {k: s.dtype for k, s in self.observation_spaces.items()}

    @property
    def action_dtypes(self) -> Dict[str, Type]:
        """Per-env action dtypes."""
        return {k: s.dtype for k, s in self.action_spaces.items()}

    def pad_dims(self) -> Tuple[int, int]:
        """
        Return `(max_action_size, max_observation_size)` across envs.

        Returns
        -------
        action : int
            Largest flat action size.
        observation : int
            Largest flat observation size.
        """
        return (
            max(self.action_sizes.values()),
            max(self.observation_sizes.values()),
        )

    def reset(
        self, rng: chex.PRNGKey
    ) -> Tuple[Dict[str, jax.Array], Dict[str, EnvState]]:
        """
        Reset all environments with independent PRNG sub-keys.

        Splits `rng` deterministically. Same master key produces the same
        per-env keys for full reproducibility.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs : Dict[str, jax.Array]
            Per-env initial observations.
        states : Dict[str, EnvState]
            Per-env initial states.
        """
        # One sub-key per env, assigned in `self._keys` order.
        keys = jax.random.split(rng, len(self._keys))
        obs: Dict[str, jax.Array] = {}
        states: Dict[str, EnvState] = {}

        for i, key in enumerate(self._keys):
            obs[key], states[key] = self._envs[key].reset(keys[i])

        return obs, states

    def _check_keys(self, label: str, mapping: Dict[str, Any]) -> None:
        """
        Raise `ValueError` unless `mapping` has exactly the keys in `env_keys`.

        Parameters
        ----------
        label : str
            Argument name used in the error message (`states` or `actions`).
        mapping : Dict[str, Any]
            The per-env dict passed to `step`.
        """
        if set(mapping.keys()) != set(self._keys):
            raise ValueError(
                f"MultiEnv.step: `{label}` keys {sorted(mapping.keys())} "
                f"do not match env keys {sorted(self._keys)}."
            )

    def step(
        self,
        states: Dict[str, EnvState],
        actions: Dict[str, jax.Array],
    ) -> Tuple[
        Dict[str, jax.Array],
        Dict[str, EnvState],
        Dict[str, jax.Array],
        Dict[str, jax.Array],
        Dict[str, Dict[str, Any]],
    ]:
        """
        Step all environments simultaneously.

        Parameters
        ----------
        states : Dict[str, EnvState]
            Per-env states from a previous reset or step.
        actions : Dict[str, jax.Array]
            Per-env actions matching each env's action space.

        Returns
        -------
        obs : Dict[str, jax.Array]
            Per-env observations after the step.
        new_states : Dict[str, EnvState]
            Per-env updated states.
        rewards : Dict[str, jax.Array]
            Per-env scalar rewards.
        dones : Dict[str, jax.Array]
            Per-env terminal flags.
        infos : Dict[str, Dict[str, Any]]
            Per-env info dicts.

        Raises
        ------
        key_mismatch : ValueError
            If `states` or `actions` keys do not match `env_keys`.
        """
        # Validate both inputs up front (states first) so a missing or extra
        # key surfaces as a clear ValueError rather than a KeyError mid-loop.
        self._check_keys("states", states)
        self._check_keys("actions", actions)

        obs: Dict[str, jax.Array] = {}
        new_states: Dict[str, EnvState] = {}
        rewards: Dict[str, jax.Array] = {}
        dones: Dict[str, jax.Array] = {}
        infos: Dict[str, Dict[str, Any]] = {}

        for key in self._keys:
            o, s, r, d, info = self._envs[key].step(states[key], actions[key])
            obs[key] = o
            new_states[key] = s
            rewards[key] = r
            dones[key] = d
            infos[key] = info

        return obs, new_states, rewards, dones, infos

    def compile(self, *, progress: bool = True) -> None:
        """
        Trigger XLA compilation for all JIT-wrapped environments.

        Calls `compile()` on each inner env that is a `JitWrapper`.
        Environments without JIT wrapping are silently skipped.

        Parameters
        ----------
        progress : bool (optional)
            Show a `tqdm` progress bar. Default is `True`.
        """
        jit_envs = [env for env in self._envs.values() if isinstance(env, JitWrapper)]
        if not jit_envs:
            # Nothing to compile; also avoids drawing an empty progress bar.
            return

        it = (
            tqdm(jit_envs, desc="Compiling envs", unit="env") if progress else jit_envs
        )
        for env in it:
            env.compile()

    def __len__(self) -> int:
        """Return the number of wrapped environments (same as `n_envs`)."""
        return self.n_envs

    def __repr__(self) -> str:
        """Return a summary listing each key with its env class name."""
        group_info = ", ".join(f"{k}={type(e).__name__}" for k, e in self._envs.items())
        return f"MultiEnv({{{group_info}}}, n_envs={self.n_envs})"

envs property

The inner JaxEnv instances keyed by env name.

env_keys property

Ordered list of env-type keys.

n_envs property

Number of environments (M).

observation_spaces property

Per-env observation spaces.

action_spaces property

Per-env action spaces.

observation_shapes property

Per-env observation shapes.

action_shapes property

Per-env action shapes.

observation_sizes property

Per-env flat observation element counts (prod(shape)).

action_sizes property

Per-env flat action element counts (prod(shape)).

observation_dtypes property

Per-env observation dtypes.

action_dtypes property

Per-env action dtypes.

pad_dims()

Return (max_action_size, max_observation_size) across envs.

Returns:

Name Type Description
action int

Largest flat action size.

observation int

Largest flat observation size.

Source code in envrax/multi_env.py
Python
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def pad_dims(self) -> Tuple[int, int]:
    """
    Compute the common padding targets for the wrapped envs.

    Returns
    -------
    action : int
        Maximum value found in `self.action_sizes`.
    observation : int
        Maximum value found in `self.observation_sizes`.
    """
    widest_action = max(self.action_sizes.values())
    widest_observation = max(self.observation_sizes.values())
    return widest_action, widest_observation

reset(rng)

Reset all environments with independent PRNG sub-keys.

Splits rng deterministically. Same master key produces the same per-env keys for full reproducibility.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required

Returns:

Name Type Description
obs Dict[str, Array]

Per-env initial observations.

states Dict[str, EnvState]

Per-env initial states.

Source code in envrax/multi_env.py
Python
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def reset(
    self, rng: chex.PRNGKey
) -> Tuple[Dict[str, jax.Array], Dict[str, EnvState]]:
    """
    Reset every wrapped environment, giving each its own PRNG sub-key.

    `rng` is split into `len(self._keys)` sub-keys; the env at position
    `i` in `self._keys` receives sub-key `i`, so the same master key
    always yields the same per-env keys.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    obs : Dict[str, jax.Array]
        Per-env initial observations.
    states : Dict[str, EnvState]
        Per-env initial states.
    """
    subkeys = jax.random.split(rng, len(self._keys))
    obs: Dict[str, jax.Array] = {}
    states: Dict[str, EnvState] = {}

    for idx, name in enumerate(self._keys):
        obs[name], states[name] = self._envs[name].reset(subkeys[idx])

    return obs, states

step(states, actions)

Step all environments simultaneously.

Parameters:

Name Type Description Default
states Dict[str, EnvState]

Per-env states from a previous reset or step.

required
actions Dict[str, Array]

Per-env actions matching each env's action space.

required

Returns:

Name Type Description
obs Dict[str, Array]

Per-env observations after the step.

new_states Dict[str, EnvState]

Per-env updated states.

rewards Dict[str, Array]

Per-env scalar rewards.

dones Dict[str, Array]

Per-env terminal flags.

infos Dict[str, Dict[str, Any]]

Per-env info dicts.

Raises:

Name Type Description
key_mismatch ValueError

If states or actions keys do not match env_keys.

Source code in envrax/multi_env.py
Python
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
def step(
    self,
    states: Dict[str, EnvState],
    actions: Dict[str, jax.Array],
) -> Tuple[
    Dict[str, jax.Array],
    Dict[str, EnvState],
    Dict[str, jax.Array],
    Dict[str, jax.Array],
    Dict[str, Dict[str, Any]],
]:
    """
    Advance every wrapped environment by one step.

    Each env is stepped in `self._keys` order with its own entry from
    `states` and `actions`.

    Parameters
    ----------
    states : Dict[str, EnvState]
        Per-env states from a previous reset or step.
    actions : Dict[str, jax.Array]
        Per-env actions matching each env's action space.

    Returns
    -------
    obs : Dict[str, jax.Array]
        Per-env observations after the step.
    new_states : Dict[str, EnvState]
        Per-env updated states.
    rewards : Dict[str, jax.Array]
        Per-env scalar rewards.
    dones : Dict[str, jax.Array]
        Per-env terminal flags.
    infos : Dict[str, Dict[str, Any]]
        Per-env info dicts.

    Raises
    ------
    key_mismatch : ValueError
        If `states` or `actions` keys do not match `env_keys`.
    """
    # Reject mismatched inputs before touching any env; `states` is checked
    # first, then `actions`.
    states_ok = set(states.keys()) == set(self._keys)
    if not states_ok:
        raise ValueError(
            f"MultiEnv.step: `states` keys {sorted(states.keys())} "
            f"do not match env keys {sorted(self._keys)}."
        )

    actions_ok = set(actions.keys()) == set(self._keys)
    if not actions_ok:
        raise ValueError(
            f"MultiEnv.step: `actions` keys {sorted(actions.keys())} "
            f"do not match env keys {sorted(self._keys)}."
        )

    obs: Dict[str, jax.Array] = {}
    new_states: Dict[str, EnvState] = {}
    rewards: Dict[str, jax.Array] = {}
    dones: Dict[str, jax.Array] = {}
    infos: Dict[str, Dict[str, Any]] = {}

    for name in self._keys:
        transition = self._envs[name].step(states[name], actions[name])
        # Unpack the 5-tuple straight into the per-env result dicts.
        obs[name], new_states[name], rewards[name], dones[name], infos[name] = (
            transition
        )

    return obs, new_states, rewards, dones, infos

compile(*, progress=True)

Trigger XLA compilation for all JIT-wrapped environments.

Calls compile() on each inner env that is a JitWrapper. Environments without JIT wrapping are silently skipped.

Parameters:

Name Type Description Default
progress bool

Show a tqdm progress bar. Default is True.

True
Source code in envrax/multi_env.py
Python
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def compile(self, *, progress: bool = True) -> None:
    """
    Compile every inner env that is wrapped in a `JitWrapper`.

    Envs that are not `JitWrapper` instances are ignored. If none
    qualify, the method returns without creating a progress bar.

    Parameters
    ----------
    progress : bool (optional)
        Show a `tqdm` progress bar. Default is `True`.
    """
    targets = [env for env in self._envs.values() if isinstance(env, JitWrapper)]
    if len(targets) == 0:
        return

    if progress:
        iterator = tqdm(targets, desc="Compiling envs", unit="env")
    else:
        iterator = targets

    for wrapped in iterator:
        wrapped.compile()

envrax.multi_vec_env.MultiVecEnv

JAX-native container for multiple BatchedEnv instances keyed by env name.

State is a dict-of-pytrees (Dict[str, chex.ArrayTree]). The cross-env-type dispatch runs as a Python loop at jax.jit trace time, producing one XLA computation per call that dispatches one inner-kernel per env type with no per-call Python overhead between them.

Accepts either a list (keys derived from each env's name via _auto_key) or a dict (used as-is for explicit control).

Parameters:

Name Type Description Default
envs List[BatchedEnv] | Dict[str, BatchedEnv]

Envs to wrap. When a list, keys are derived from env.name with suffixes on duplicates. When a dict, keys are used verbatim. Iteration order is preserved.

required
Source code in envrax/multi_vec_env.py
Python
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
class MultiVecEnv:
    """
    JAX-native container for multiple `BatchedEnv` instances keyed by env name.

    State is a dict-of-pytrees (`Dict[str, chex.ArrayTree]`). The
    cross-env-type dispatch runs as a Python loop at `jax.jit` trace time,
    producing one XLA computation per call that dispatches one inner-kernel
    per env type with no per-call Python overhead between them.

    Accepts either a list (keys derived from each env's `name` via
    `_auto_key`) or a dict (used as-is for explicit control).

    Parameters
    ----------
    envs : List[BatchedEnv] | Dict[str, BatchedEnv]
        Envs to wrap. When a list, keys are derived from `env.name` with
        suffixes on duplicates. When a dict, keys are used verbatim.
        Iteration order is preserved.

    Raises
    ------
    no_envs : ValueError
        If `envs` is empty (including an iterable that yields nothing).
    """

    def __init__(
        self,
        envs: List[BatchedEnv] | Dict[str, BatchedEnv],
    ) -> None:
        if not envs:
            raise ValueError("MultiVecEnv requires at least one 'BatchedEnv'.")

        if isinstance(envs, dict):
            envs_dict = dict(envs)
        else:
            env_list = list(envs)
            # A lazy iterable (e.g. a generator) is always truthy, so the
            # check above cannot see that it is empty; re-check once it has
            # been materialised so an empty container is never built.
            if not env_list:
                raise ValueError("MultiVecEnv requires at least one 'BatchedEnv'.")
            envs_dict = _auto_key(env_list)

        self._envs: Dict[str, BatchedEnv] = envs_dict
        # Key order is frozen here and drives every reset/step loop.
        self._keys: List[str] = list(self._envs.keys())
        # The public reset/step delegate to these jitted bound methods.
        self._jit_reset = jax.jit(self._reset_impl)
        self._jit_step = jax.jit(self._step_impl)

    @property
    def envs(self) -> Dict[str, BatchedEnv]:
        """The inner `BatchedEnv` instances keyed by env name."""
        return self._envs

    @property
    def env_keys(self) -> List[str]:
        """Ordered list of env-type keys (a fresh copy on every access)."""
        return list(self._keys)

    @property
    def n_envs(self) -> int:
        """Number of distinct env types (= number of `BatchedEnv` instances)."""
        return len(self._envs)

    @property
    def total_slots(self) -> int:
        """Total number of individual agent slots across all env types."""
        return sum(e.n_slots for e in self._envs.values())

    @property
    def slots_per_env(self) -> Dict[str, int]:
        """Per-env-type slot counts."""
        return {k: e.n_slots for k, e in self._envs.items()}

    @property
    def single_observation_spaces(self) -> Dict[str, Space]:
        """Per-env-type unbatched observation spaces."""
        return {k: e.single_observation_space for k, e in self._envs.items()}

    @property
    def single_action_spaces(self) -> Dict[str, Space]:
        """Per-env-type unbatched action spaces."""
        return {k: e.single_action_space for k, e in self._envs.items()}

    @property
    def single_observation_shapes(self) -> Dict[str, Tuple[int, ...]]:
        """Per-env-type unbatched observation shapes."""
        return {k: s.shape for k, s in self.single_observation_spaces.items()}

    @property
    def single_action_shapes(self) -> Dict[str, Tuple[int, ...]]:
        """Per-env-type unbatched action shapes."""
        return {k: s.shape for k, s in self.single_action_spaces.items()}

    @property
    def single_observation_sizes(self) -> Dict[str, int]:
        """Per-env-type flat unbatched observation element counts."""
        return {
            k: int(prod(s.shape)) for k, s in self.single_observation_spaces.items()
        }

    @property
    def single_action_sizes(self) -> Dict[str, int]:
        """Per-env-type flat unbatched action element counts."""
        return {k: int(prod(s.shape)) for k, s in self.single_action_spaces.items()}

    @property
    def single_observation_dtypes(self) -> Dict[str, Type]:
        """Per-env-type unbatched observation dtypes."""
        return {k: s.dtype for k, s in self.single_observation_spaces.items()}

    @property
    def single_action_dtypes(self) -> Dict[str, Type]:
        """Per-env-type unbatched action dtypes."""
        return {k: s.dtype for k, s in self.single_action_spaces.items()}

    def pad_dims(self) -> Tuple[int, int]:
        """
        Return `(max_action_size, max_observation_size)` across env types.

        Returns
        -------
        action : int
            Largest flat action size across all env types.
        observation : int
            Largest flat observation size across all env types.
        """
        return (
            max(self.single_action_sizes.values()),
            max(self.single_observation_sizes.values()),
        )

    def reset(
        self, rng: chex.PRNGKey
    ) -> Tuple[Dict[str, jax.Array], Dict[str, chex.ArrayTree]]:
        """
        Reset all env types with independent PRNG sub-keys.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs : Dict[str, jax.Array]
            Per-env-type batched observations.
        states : Dict[str, chex.ArrayTree]
            Per-env-type batched state pytrees.
        """
        return self._jit_reset(rng)

    def _reset_impl(
        self, rng: chex.PRNGKey
    ) -> Tuple[Dict[str, jax.Array], Dict[str, chex.ArrayTree]]:
        """
        Unjitted body of `reset`. Wrapped by `self._jit_reset` in `__init__`.

        Splits `rng` into one sub-key per env type and traces each inner
        env's `reset` into the same XLA computation at jit time.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs : Dict[str, jax.Array]
            Per-env-type batched observations.
        states : Dict[str, chex.ArrayTree]
            Per-env-type batched state pytrees.
        """
        # One sub-key per env type, assigned in `self._keys` order.
        keys = jax.random.split(rng, len(self._keys))
        obs: Dict[str, jax.Array] = {}
        states: Dict[str, chex.ArrayTree] = {}

        for i, key in enumerate(self._keys):
            obs[key], states[key] = self._envs[key].reset(keys[i])

        return obs, states

    def _check_keys(self, label: str, mapping: Dict[str, Any]) -> None:
        """
        Raise `ValueError` unless `mapping` has exactly the keys in `env_keys`.

        Parameters
        ----------
        label : str
            Argument name used in the error message (`states` or `actions`).
        mapping : Dict[str, Any]
            The per-env-type dict passed to `step`.
        """
        if set(mapping.keys()) != set(self._keys):
            raise ValueError(
                f"MultiVecEnv.step: `{label}` keys {sorted(mapping.keys())} "
                f"do not match env keys {sorted(self._keys)}."
            )

    def step(
        self,
        states: Dict[str, chex.ArrayTree],
        actions: Dict[str, jax.Array],
    ) -> Tuple[
        Dict[str, jax.Array],
        Dict[str, chex.ArrayTree],
        Dict[str, jax.Array],
        Dict[str, jax.Array],
        Dict[str, Dict[str, Any]],
    ]:
        """
        Step all env types simultaneously.

        Parameters
        ----------
        states : Dict[str, chex.ArrayTree]
            Per-env-type batched states from a previous reset or step.
        actions : Dict[str, jax.Array]
            Per-env-type batched actions.

        Returns
        -------
        obs : Dict[str, jax.Array]
            Per-env-type batched observations after the step.
        new_states : Dict[str, chex.ArrayTree]
            Per-env-type updated batched states.
        rewards : Dict[str, jax.Array]
            Per-env-type batched rewards.
        dones : Dict[str, jax.Array]
            Per-env-type batched terminal flags.
        infos : Dict[str, Dict[str, Any]]
            Per-env-type batched info dicts.

        Raises
        ------
        key_mismatch : ValueError
            If `states` or `actions` keys do not match `env_keys`.
        """
        # Key validation stays outside the jitted function so a mismatch is
        # reported as a ValueError here rather than as a KeyError from
        # `_step_impl`.
        self._check_keys("states", states)
        self._check_keys("actions", actions)

        return self._jit_step(states, actions)

    def _step_impl(
        self,
        states: Dict[str, chex.ArrayTree],
        actions: Dict[str, jax.Array],
    ) -> Tuple[
        Dict[str, jax.Array],
        Dict[str, chex.ArrayTree],
        Dict[str, jax.Array],
        Dict[str, jax.Array],
        Dict[str, Dict[str, Any]],
    ]:
        """
        Unjitted body of `step`. Wrapped by `self._jit_step` in `__init__`.

        Traces each inner env's `step` into the same XLA computation at jit
        time — the Python loop over `self._keys` unrolls at tracing, not at
        runtime.

        Parameters
        ----------
        states : Dict[str, chex.ArrayTree]
            Per-env-type batched states from a previous reset or step.
        actions : Dict[str, jax.Array]
            Per-env-type batched actions.

        Returns
        -------
        obs : Dict[str, jax.Array]
            Per-env-type batched observations after the step.
        new_states : Dict[str, chex.ArrayTree]
            Per-env-type updated batched states.
        rewards : Dict[str, jax.Array]
            Per-env-type batched rewards.
        dones : Dict[str, jax.Array]
            Per-env-type batched terminal flags.
        infos : Dict[str, Dict[str, Any]]
            Per-env-type batched info dicts.
        """
        obs: Dict[str, jax.Array] = {}
        new_states: Dict[str, chex.ArrayTree] = {}
        rewards: Dict[str, jax.Array] = {}
        dones: Dict[str, jax.Array] = {}
        infos: Dict[str, Dict[str, Any]] = {}

        for key in self._keys:
            o, s, r, d, info = self._envs[key].step(states[key], actions[key])
            obs[key] = o
            new_states[key] = s
            rewards[key] = r
            dones[key] = d
            infos[key] = info

        return obs, new_states, rewards, dones, infos

    def slot_state(
        self, states: Dict[str, chex.ArrayTree], key: str, slot_idx: int
    ) -> chex.ArrayTree:
        """
        Extract the single-slot state pytree for one agent.

        Delegates to the `slot_state` method of the env stored under `key`.

        Parameters
        ----------
        states : Dict[str, chex.ArrayTree]
            Per-env-type batched states.
        key : str
            Env-type key in `env_keys`.
        slot_idx : int
            Slot index in `[0, slots_per_env[key])`.

        Returns
        -------
        single_state : chex.ArrayTree
            Unbatched state pytree for the chosen slot.
        """
        return self._envs[key].slot_state(states[key], slot_idx)

    def render_slot(
        self, states: Dict[str, chex.ArrayTree], key: str, slot_idx: int
    ) -> np.ndarray:
        """
        Render a single slot as an RGB frame.

        Delegates to the `render_slot` method of the env stored under `key`.

        Parameters
        ----------
        states : Dict[str, chex.ArrayTree]
            Per-env-type batched states.
        key : str
            Env-type key in `env_keys`.
        slot_idx : int
            Slot index in `[0, slots_per_env[key])`.

        Returns
        -------
        frame : np.ndarray
            uint8 RGB array of shape `(H, W, 3)`.
        """
        return self._envs[key].render_slot(states[key], slot_idx)

    def compile(
        self,
        cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
        *,
        progress: bool = True,
    ) -> None:
        """
        Trigger XLA compilation for all inner envs and warm the multi-step jit.

        Calls `setup_cache(cache_dir)`, then `compile(cache_dir=cache_dir)`
        on every inner env, and finally runs one `reset` and one `step`
        with sampled dummy actions so the outer jitted functions are traced.

        Parameters
        ----------
        cache_dir : Path | str | None (optional)
            XLA cache directory. Defaults to `<cwd>/.jax_cache`.
        progress : bool (optional)
            Show a `tqdm` progress bar. Default is `True`.
        """
        setup_cache(cache_dir)

        inner_envs = self._envs.values()
        it = (
            tqdm(inner_envs, desc="Compiling batched envs", unit="env")
            if progress
            else inner_envs
        )
        for env in it:
            env.compile(cache_dir=cache_dir)

        # Warm-up pass. Split the master key first so the reset and the
        # dummy-action sampling never consume the same PRNG key: splitting
        # one key with the same count twice would hand the action sampler
        # exactly the sub-keys `_reset_impl` gives to the envs.
        reset_rng, action_rng = jax.random.split(jax.random.key(0))
        _, states = self.reset(reset_rng)

        action_keys = jax.random.split(action_rng, len(self._keys))
        dummy_actions = {
            # One sampled action per slot, batched with vmap over slot keys.
            k: jax.vmap(self._envs[k].single_action_space.sample)(
                jax.random.split(action_keys[i], self._envs[k].n_slots)
            )
            for i, k in enumerate(self._keys)
        }
        self.step(states, dummy_actions)

    def __len__(self) -> int:
        """Return the number of env types (same as `n_envs`)."""
        return self.n_envs

    def __repr__(self) -> str:
        """Return a summary listing each key with its slot count."""
        group_info = ", ".join(
            f"{k}*{e.n_slots}" for k, e in self._envs.items()
        )
        return f"MultiVecEnv({{{group_info}}}, total_slots={self.total_slots})"

envs property

The inner BatchedEnv instances keyed by env name.

env_keys property

Ordered list of env-type keys.

n_envs property

Number of distinct env types (= number of BatchedEnv instances).

total_slots property

Total number of individual agent slots across all env types.

slots_per_env property

Per-env-type slot counts.

single_observation_spaces property

Per-env-type unbatched observation spaces.

single_action_spaces property

Per-env-type unbatched action spaces.

single_observation_shapes property

Per-env-type unbatched observation shapes.

single_action_shapes property

Per-env-type unbatched action shapes.

single_observation_sizes property

Per-env-type flat unbatched observation element counts.

single_action_sizes property

Per-env-type flat unbatched action element counts.

single_observation_dtypes property

Per-env-type unbatched observation dtypes.

single_action_dtypes property

Per-env-type unbatched action dtypes.

pad_dims()

Return (max_action_size, max_observation_size) across env types.

Returns:

Name Type Description
action int

Largest flat action size across all env types.

observation int

Largest flat observation size across all env types.

Source code in envrax/multi_vec_env.py
Python
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def pad_dims(self) -> Tuple[int, int]:
    """
    Return the largest flat action and observation sizes across env types.

    Returns
    -------
    action : int
        Maximum value found in `single_action_sizes`.
    observation : int
        Maximum value found in `single_observation_sizes`.
    """
    widest_action = max(self.single_action_sizes.values())
    widest_observation = max(self.single_observation_sizes.values())
    return widest_action, widest_observation

reset(rng)

Reset all env types with independent PRNG sub-keys.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required

Returns:

Name Type Description
obs Dict[str, Array]

Per-env-type batched observations.

states Dict[str, ArrayTree]

Per-env-type batched state pytrees.

Source code in envrax/multi_vec_env.py
Python
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def reset(
    self, rng: chex.PRNGKey
) -> Tuple[Dict[str, jax.Array], Dict[str, chex.ArrayTree]]:
    """
    Reset all env types with independent PRNG sub-keys.

    The work is delegated to `self._jit_reset`; its result is returned
    unchanged.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key forwarded to `self._jit_reset`.

    Returns
    -------
    obs : Dict[str, jax.Array]
        Per-env-type batched observations.
    states : Dict[str, chex.ArrayTree]
        Per-env-type batched state pytrees.
    """
    reset_result = self._jit_reset(rng)
    return reset_result

step(states, actions)

Step all env types simultaneously.

Parameters:

Name Type Description Default
states Dict[str, ArrayTree]

Per-env-type batched states from a previous reset or step.

required
actions Dict[str, Array]

Per-env-type batched actions.

required

Returns:

Name Type Description
obs Dict[str, Array]

Per-env-type batched observations after the step.

new_states Dict[str, ArrayTree]

Per-env-type updated batched states.

rewards Dict[str, Array]

Per-env-type batched rewards.

dones Dict[str, Array]

Per-env-type batched terminal flags.

infos Dict[str, Dict[str, Any]]

Per-env-type batched info dicts.

Raises:

Name Type Description
key_mismatch ValueError

If states or actions keys do not match env_keys.

Source code in envrax/multi_vec_env.py
Python
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def step(
    self,
    states: Dict[str, chex.ArrayTree],
    actions: Dict[str, jax.Array],
) -> Tuple[
    Dict[str, jax.Array],
    Dict[str, chex.ArrayTree],
    Dict[str, jax.Array],
    Dict[str, jax.Array],
    Dict[str, Dict[str, Any]],
]:
    """
    Step all env types simultaneously.

    Both input dicts are checked against the env keys (order-insensitive)
    before the call is handed to `self._jit_step`.

    Parameters
    ----------
    states : Dict[str, chex.ArrayTree]
        Per-env-type batched states from a previous reset or step.
    actions : Dict[str, jax.Array]
        Per-env-type batched actions.

    Returns
    -------
    obs : Dict[str, jax.Array]
        Per-env-type batched observations after the step.
    new_states : Dict[str, chex.ArrayTree]
        Per-env-type updated batched states.
    rewards : Dict[str, jax.Array]
        Per-env-type batched rewards.
    dones : Dict[str, jax.Array]
        Per-env-type batched terminal flags.
    infos : Dict[str, Dict[str, Any]]
        Per-env-type batched info dicts.

    Raises
    ------
    ValueError
        If the key set of `states` or of `actions` differs from the env
        keys. `states` is checked first.
    """
    env_key_set = set(self._keys)

    states_match = set(states.keys()) == env_key_set
    if not states_match:
        raise ValueError(
            f"MultiVecEnv.step: `states` keys {sorted(states.keys())} "
            f"do not match env keys {sorted(self._keys)}."
        )

    actions_match = set(actions.keys()) == env_key_set
    if not actions_match:
        raise ValueError(
            f"MultiVecEnv.step: `actions` keys {sorted(actions.keys())} "
            f"do not match env keys {sorted(self._keys)}."
        )

    step_result = self._jit_step(states, actions)
    return step_result

slot_state(states, key, slot_idx)

Extract the single-slot state pytree for one agent.

Parameters:

Name Type Description Default
states Dict[str, ArrayTree]

Per-env-type batched states.

required
key str

Env-type key in env_keys.

required
slot_idx int

Slot index in [0, slots_per_env[key]).

required

Returns:

Name Type Description
single_state ArrayTree

Unbatched state pytree for the chosen slot.

Source code in envrax/multi_vec_env.py
Python
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
def slot_state(
    self, states: Dict[str, chex.ArrayTree], key: str, slot_idx: int
) -> chex.ArrayTree:
    """
    Extract the single-slot state pytree for one agent.

    Looks up the inner env and the batched state stored under `key`, then
    delegates to that env's own `slot_state`.

    Parameters
    ----------
    states : Dict[str, chex.ArrayTree]
        Per-env-type batched states.
    key : str
        Env-type key in `env_keys`.
    slot_idx : int
        Slot index in `[0, slots_per_env[key])`.

    Returns
    -------
    single_state : chex.ArrayTree
        Unbatched state pytree for the chosen slot.
    """
    inner_env = self._envs[key]
    batched_state = states[key]
    return inner_env.slot_state(batched_state, slot_idx)

render_slot(states, key, slot_idx)

Render a single slot as an RGB frame.

Parameters:

Name Type Description Default
states Dict[str, ArrayTree]

Per-env-type batched states.

required
key str

Env-type key in env_keys.

required
slot_idx int

Slot index in [0, slots_per_env[key]).

required

Returns:

Name Type Description
frame ndarray

uint8 RGB array of shape (H, W, 3).

Source code in envrax/multi_vec_env.py
Python
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
def render_slot(
    self, states: Dict[str, chex.ArrayTree], key: str, slot_idx: int
) -> np.ndarray:
    """
    Render a single slot as an RGB frame.

    Looks up the inner env and the batched state stored under `key`, then
    delegates to that env's own `render_slot`.

    Parameters
    ----------
    states : Dict[str, chex.ArrayTree]
        Per-env-type batched states.
    key : str
        Env-type key in `env_keys`.
    slot_idx : int
        Slot index in `[0, slots_per_env[key])`.

    Returns
    -------
    frame : np.ndarray
        uint8 RGB array of shape `(H, W, 3)`.
    """
    inner_env = self._envs[key]
    batched_state = states[key]
    return inner_env.render_slot(batched_state, slot_idx)

compile(cache_dir=DEFAULT_CACHE_DIR, *, progress=True)

Trigger XLA compilation for all inner envs and warm the multi-step jit.

Parameters:

Name Type Description Default
cache_dir Path | str | None

XLA cache directory. Defaults to <cwd>/.jax_cache.

DEFAULT_CACHE_DIR
progress bool

Show a tqdm progress bar. Default is True.

True
Source code in envrax/multi_vec_env.py
Python
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
def compile(
    self,
    cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
    *,
    progress: bool = True,
) -> None:
    """
    Trigger XLA compilation for all inner envs and warm the multi-step jit.

    Each inner env is compiled in dict order, then one `reset` followed by
    one `step` with sampled actions is run on this container.

    Parameters
    ----------
    cache_dir : Path | str | None (optional)
        XLA cache directory. Defaults to `<cwd>/.jax_cache`.
    progress : bool (optional)
        Show a `tqdm` progress bar. Default is `True`.
    """
    setup_cache(cache_dir)

    inner_envs = self._envs.items()
    if progress:
        inner_envs = tqdm(
            inner_envs, desc="Compiling batched envs", unit="env"
        )
    for _, inner_env in inner_envs:
        inner_env.compile(cache_dir=cache_dir)

    # Warm-up pass with a fixed seed; the reset/step outputs are discarded.
    seed_key = jax.random.key(0)
    _, warm_states = self.reset(seed_key)
    per_env_keys = jax.random.split(seed_key, len(self._keys))

    # One batch of sampled actions per env type, one sample per slot.
    warm_actions = {}
    for idx, name in enumerate(self._keys):
        batched_env = self._envs[name]
        sample_batch = jax.vmap(batched_env.single_action_space.sample)
        slot_keys = jax.random.split(per_env_keys[idx], batched_env.n_slots)
        warm_actions[name] = sample_batch(slot_keys)

    self.step(warm_states, warm_actions)