# Multiple Environments
As we've seen, VecEnv gives you N parallel copies of a single environment class, but what if you want to train your agent on multiple unique environments?
This is a common strategy for meta-learning, multi-task training, and evaluating an agent across multiple environments.
Envrax has built-in support for this via the MultiEnv and MultiVecEnv classes. Each manages M different environments in parallel. These could be entirely different environment classes, or the same class with different observation shapes, action spaces, and configs. The sky's the limit!
As a rule of thumb, if you want:
- N parallel copies of one environment - use VecEnv
- M different environments, one instance for each - use MultiEnv
- M different environments with N copies of each (or any mix of BatchedEnv strategies) - use MultiVecEnv
## MultiEnv
MultiEnv holds M JaxEnv instances keyed by environment name and dispatches reset/step via a Python loop. Each inner environment keeps its own compile cycle (typically via JitWrapper) — MultiEnv adds no outer jax.jit boundary. Reach for MultiVecEnv when you need a single jitted dispatch over batched envs instead.
Implementation example:
[Python code listing]
Some key things worth noting:
- Inputs and outputs are dicts keyed by environment name, not Python lists. Different environments may have different observation shapes, action shapes and configs; keying by name keeps everything explicit and easy to look up.
- Keys are inferred from each environment's name by default, with _0/_1 suffixes when duplicates appear. Wrappers like JitWrapper delegate name to the inner environment, so JitWrapper(BallEnv()).name == "BallEnv". Pass a dict directly for explicit control.
- reset(rng) takes one master key. MultiEnv splits it automatically into one sub-key per inner environment, so the same master key always produces the same per-environment starts.
- compile() is a separate step. MultiEnv doesn't pre-warm its inner environments by default. Calling multi.compile() walks the fleet and compiles each JitWrapper-wrapped environment with a progress bar, so you can measure the setup vs. training costs separately.
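The master-key behaviour described in the list above can be pictured with plain JAX. This is only a minimal sketch, not Envrax's actual implementation: the helper `split_master_key` is hypothetical, `CartEnv` is an illustrative name, and the sketch assumes sub-keys are handed out in environment-key order.

```python
import jax
import jax.numpy as jnp


def split_master_key(master_key, env_keys):
    """Hypothetical helper: derive one PRNG sub-key per environment name."""
    sub_keys = jax.random.split(master_key, len(env_keys))
    # Pair sub-keys with names in order, so the mapping is deterministic.
    return {name: sub_keys[i] for i, name in enumerate(env_keys)}


env_keys = ["BallEnv", "CartEnv"]  # illustrative names only

# The same master key always yields the same per-environment sub-keys.
first = split_master_key(jax.random.PRNGKey(0), env_keys)
second = split_master_key(jax.random.PRNGKey(0), env_keys)
```

Because each environment receives its own sub-key, the environments do not share random streams, yet a whole run stays reproducible from a single seed.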
### List vs. dict input
Just like MultiVecEnv, MultiEnv accepts its environments as either a list or a dict:
[Python code listing]
Iteration order is preserved in both forms.
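To make the list-to-key rule concrete, here is a small pure-Python sketch of one way such keys could be derived. The helper `infer_keys` is hypothetical, `CartEnv` is an illustrative name, and the exact suffix policy (every duplicate gets an index, starting at `_0`) is an assumption; check the API docs for the real rule.

```python
from collections import Counter


def infer_keys(names):
    """Hypothetical helper: turn a list of env names into unique dict keys."""
    totals = Counter(names)  # how often each name occurs overall
    seen = Counter()         # how many of each name we have emitted so far
    keys = []
    for name in names:
        if totals[name] > 1:
            # Assumption: every duplicate gets an index suffix, starting at _0.
            keys.append(f"{name}_{seen[name]}")
            seen[name] += 1
        else:
            keys.append(name)
    return keys


keys = infer_keys(["BallEnv", "CartEnv", "BallEnv"])
# keys == ["BallEnv_0", "CartEnv", "BallEnv_1"], in the original list order
```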
### MultiEnv Attributes
| Item | Description |
|---|---|
| multi.envs | The dict of inner JaxEnv instances |
| multi.env_keys | Ordered list of environment-type keys |
| multi.n_envs | The number of environments (M) |
| multi.observation_spaces | A dict of per-env observation spaces |
| multi.action_spaces | A dict of per-env action spaces |
| multi.observation_shapes | A dict of per-env observation shapes (s.shape) |
| multi.action_shapes | A dict of per-env action shapes |
| multi.observation_sizes | A dict of per-env flat observation sizes (prod(s.shape)) |
| multi.action_sizes | A dict of per-env flat action sizes |
| multi.observation_dtypes | A dict of per-env observation dtypes |
| multi.action_dtypes | A dict of per-env action dtypes |
## MultiVecEnv
MultiVecEnv is the JAX-native sibling of MultiEnv. It holds M BatchedEnv instances and steps them all together inside a single jax.jit boundary, so the cross-environment dispatch loop unrolls at trace time and there's no per-call Python overhead between groups.
Each inner BatchedEnv handles its own internal batching however it likes. VecEnv is the canonical vmap strategy, but downstream packages can add others (e.g. composite MJX scenes) and slot them into the same MultiVecEnv without any changes to envrax.
Implementation example:
[Python code listing]
Some key differences from MultiEnv:
- Inputs and outputs are dicts keyed by environment name, not Python lists. State is a proper JAX pytree: jax.tree.map, jax.tree.leaves, and friends all work directly on the returned states dict.
- One jax.jit boundary per step. The Python for loop over inner environments runs at trace time, so a single XLA computation dispatches every inner kernel with no per-call Python overhead.
- Keys are inferred from each environment's name by default, with _0/_1 suffixes when duplicates appear. Pass a dict directly for explicit control.
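The "loop runs at trace time" idea can be illustrated without Envrax at all. The sketch below uses two stand-in step functions (`step_a` and `step_b` are hypothetical, as is `step_all`) and a dict of batched arrays in place of real environment state; it shows the general JAX pattern, not the library's code.

```python
import jax
import jax.numpy as jnp


# Stand-in step functions for two environment groups with different shapes.
def step_a(state, action):
    return state + action


def step_b(state, action):
    return state * 0.5 + action


STEP_FNS = {"task_a": step_a, "task_b": step_b}


@jax.jit
def step_all(states, actions):
    # This loop runs once, while JAX traces the function. The compiled
    # computation contains both group steps, with no Python loop at run time.
    return {key: STEP_FNS[key](states[key], actions[key]) for key in STEP_FNS}


states = {"task_a": jnp.zeros((4, 3)), "task_b": jnp.ones((2, 5))}
actions = {"task_a": jnp.ones((4, 3)), "task_b": jnp.zeros((2, 5))}
new_states = step_all(states, actions)

# The result is an ordinary dict pytree, so tree utilities apply directly.
per_group_sums = jax.tree.map(jnp.sum, new_states)
n_leaves = len(jax.tree.leaves(new_states))
```

Note that the two groups have different batch sizes and shapes; a dict pytree handles that naturally, whereas stacking them into one array would not.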
### List vs. dict input
The dict keys in the example above came from VecEnv.name, which defaults to the inner JaxEnv's class name. When you supply a list, MultiVecEnv derives the keys for you:
[Python code listing]
For full control over keys (e.g. task labels like "task_a", "task_b"), pass a dict directly:
[Python code listing]
Iteration order is preserved in both forms.
### MultiVecEnv Attributes
| Item | Description |
|---|---|
| multi_vec.envs | The dict of inner BatchedEnv instances |
| multi_vec.env_keys | Ordered list of environment-type keys |
| multi_vec.n_envs | The number of distinct environment types (M) |
| multi_vec.total_slots | Total individual agent slots across all groups |
| multi_vec.slots_per_env | A dict of per-group slot counts |
| multi_vec.single_observation_spaces | A dict of per-group unbatched observation spaces |
| multi_vec.single_action_spaces | A dict of per-group unbatched action spaces |
| multi_vec.single_observation_shapes | A dict of per-group unbatched observation shapes |
| multi_vec.single_action_shapes | A dict of per-group unbatched action shapes |
| multi_vec.single_observation_sizes | A dict of per-group flat unbatched observation sizes (prod(s.shape)) |
| multi_vec.single_action_sizes | A dict of per-group flat unbatched action sizes |
| multi_vec.single_observation_dtypes | A dict of per-group unbatched observation dtypes |
| multi_vec.single_action_dtypes | A dict of per-group unbatched action dtypes |
Skip the boilerplate with make_multi / make_multi_vec
The make_multi and make_multi_vec factory methods build these fleets in one call — wrappers, JIT, and all. The class-based form here is the underlying mechanism; the factories are the sugar on top!
Use them whenever you can!
## Additional Methods
Both MultiEnv and MultiVecEnv share the same dict-keyed surface for finer-grained control.
### Per-Environment Access
Inner environments are directly accessible by key:
[Python code listing]
For MultiVecEnv, pulling out a single slot's state from a batched environment (e.g. one rollout) uses slot_state / render_slot:
[Python code listing]
MultiEnv doesn't need these — its inner environments aren't batched, so multi.envs["BallEnv"].render(state) does the same job directly.
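Conceptually, pulling one slot out of a batched state amounts to indexing the leading batch axis of every leaf in the state pytree. The sketch below shows that idea in plain JAX; `take_slot` and the state fields are hypothetical, and it assumes every leaf carries the batch dimension first. It is not necessarily how slot_state is implemented.

```python
import jax
import jax.numpy as jnp

# A batched state pytree: every leaf has a leading axis of 4 slots.
batched_state = {
    "position": jnp.arange(8.0).reshape(4, 2),
    "step_count": jnp.array([0, 1, 2, 3]),
}


def take_slot(state, index):
    """Hypothetical helper: index the leading (batch) axis of every leaf."""
    return jax.tree.map(lambda leaf: leaf[index], state)


# The result has the same structure as the batched state, minus the batch axis.
single = take_slot(batched_state, 2)
```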
### Padding sizes
Both classes expose pad_dims(), which returns the largest flat action and observation sizes across the fleet as an (action, observation) tuple:
[Python code listing]
This is useful when you need to vmap a single jitted function (or feed one shared policy network) over environments that don't share the same action or observation shapes. Sizes are computed as prod(space.shape), so multi-dim observations are handled correctly.
For MultiVecEnv the sizes come from the unbatched per-group spaces (i.e., single_*_sizes) — that's the dimension a per-environment policy normally uses.
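As a rough illustration of the arithmetic involved, the sketch below computes the largest flat sizes from dicts of shapes and zero-pads a flattened observation up to the shared width. `compute_pad_dims` and `pad_flat` are hypothetical stand-ins written for this example (not the library's pad_dims()), and the shapes are made up.

```python
import math

import numpy as np


def compute_pad_dims(action_shapes, observation_shapes):
    """Hypothetical stand-in: largest flat (action, observation) sizes."""
    max_action = max(math.prod(shape) for shape in action_shapes.values())
    max_obs = max(math.prod(shape) for shape in observation_shapes.values())
    return max_action, max_obs


def pad_flat(array, size):
    """Flatten an array and zero-pad it on the right up to `size`."""
    flat = np.asarray(array).reshape(-1)
    return np.pad(flat, (0, size - flat.size))


# Illustrative shapes: a scalar action has shape (), whose product is 1.
action_shapes = {"BallEnv": (2,), "CartEnv": ()}
observation_shapes = {"BallEnv": (4, 3), "CartEnv": (5,)}

max_action, max_obs = compute_pad_dims(action_shapes, observation_shapes)
padded = pad_flat(np.ones((5,)), max_obs)  # 5 real values, then zeros
```

With every observation padded to the same width, a single policy network can consume inputs from any environment in the fleet, typically alongside a mask or task identifier so it can ignore the padding.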
## Common Pitfalls
Using multiple environments at once can be tricky. Be mindful of the following "gotchas":
- Different action dtypes - if one environment takes int32 and another takes float32, build the actions dict element by element; don't try to jnp.stack them.
- Mismatched keys on step - the states and actions dicts must have exactly the same keys as multi.env_keys. A missing or extra key raises a ValueError before the inner step runs.
- Forgetting compile() - neither class pre-warms its inner environments. Without an explicit multi.compile(), your first step call will pay the compile cost for every inner environment sequentially.
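The first two pitfalls can be sketched in a few lines. The `check_keys` helper below is hypothetical and only mimics the kind of validation described above; the environment names, action values, and error wording are illustrative, not the library's.

```python
import numpy as np


def check_keys(env_keys, **named_dicts):
    """Hypothetical check: each dict must have exactly the expected keys."""
    expected = set(env_keys)
    for label, mapping in named_dicts.items():
        missing = expected - set(mapping)
        extra = set(mapping) - expected
        if missing or extra:
            raise ValueError(
                f"{label}: missing keys {sorted(missing)}, "
                f"extra keys {sorted(extra)}"
            )


env_keys = ["BallEnv", "CartEnv"]

# Build actions entry by entry so each keeps its own dtype and shape.
actions = {
    "BallEnv": np.array([0.1, -0.3], dtype=np.float32),
    "CartEnv": np.array(1, dtype=np.int32),
}
check_keys(env_keys, actions=actions)  # keys match, so nothing is raised

try:
    # Dropping one environment's action triggers the error.
    check_keys(env_keys, actions={"BallEnv": actions["BallEnv"]})
    error_message = None
except ValueError as err:
    error_message = str(err)
```

Stacking the two actions instead would force them into a common dtype and shape, which is exactly what the per-key dict avoids.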
## Recap
To recap:
- MultiEnv manages M heterogeneous JaxEnv instances; MultiVecEnv manages M heterogeneous BatchedEnv instances (VecEnv being the canonical one)
- Both accept either a list (auto-keyed from env.name with _0/_1 suffixes on duplicates) or a dict (keys used verbatim)
- Both return dicts keyed by environment name; MultiVecEnv's state is a proper JAX pytree
- Inner environments are accessed via multi.envs[key]; for batched slot extraction in MultiVecEnv, use slot_state / render_slot
- MultiVecEnv is fully JAX-native: its step runs inside one jax.jit boundary with no per-call Python overhead between groups
- Call .compile() explicitly, since these managers default to deferred compilation
Next up, we'll explore how Envrax's environment registry works so you can use canonical names for building environments instead of classes.
## Next Steps
- Environment Registry: Learn how to use Envrax's environment registry.