Multiple Environments¶
As we've seen, VecEnv gives you N parallel copies of a single environment class, but what if you want to train your agent on multiple unique environments?
This is a very common strategy for meta-learning tasks, multi-task training, and when evaluating an agent on multiple environments.
Envrax has built-in support for this via the MultiEnv and MultiVecEnv classes. Each manages M environments side by side, and those environments may come from different env classes. They could be entirely different environments, or the same environment with different observation shapes, action spaces, and configs. The sky's the limit!
As a rule of thumb, if you want:
- N parallel copies of one environment - use VecEnv
- M different environments, one instance for each - use MultiEnv
- M different environments with N copies of each - use MultiVecEnv
MultiEnv¶
API Docs
Implementation example:
[Python code listing, 18 lines; contents not recoverable from the source]
Some key things worth noting:
- Inputs and outputs are Python lists, not stacked arrays. Different environments may have different observation shapes, action shapes and configs. They cannot be stacked in JAX arrays.
- reset(rng) takes one master key. MultiEnv splits it automatically into M deterministic sub-keys, so the same master key always produces the same per-env starts.
- compile() is a separate step. MultiEnv doesn't pre-warm its inner environments by default. Calling multi.compile() walks the fleet and compiles each one with a progress bar, so you can measure setup and training costs separately.
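The list-based I/O and key splitting described above can be illustrated without Envrax itself. The following is a minimal sketch in plain JAX, not Envrax's implementation: split_master_key and toy_reset are hypothetical helpers, and the two observation shapes are made up for illustration.

```python
import jax
import jax.numpy as jnp


def split_master_key(rng, num_envs):
    """Derive one sub-key per environment from a single master key (hypothetical helper)."""
    return list(jax.random.split(rng, num_envs))


def toy_reset(rng, obs_shapes):
    """Stand-in for a multi-env reset: returns one observation per env as a list."""
    keys = split_master_key(rng, len(obs_shapes))
    return [jax.random.normal(k, shape) for k, shape in zip(keys, obs_shapes)]


# Two made-up environments whose observations have different shapes.
obs_shapes = [(4,), (2, 3)]
master = jax.random.PRNGKey(0)

first = toy_reset(master, obs_shapes)
second = toy_reset(master, obs_shapes)

# Same master key -> same sub-keys -> identical per-env starting observations.
same_starts = all(bool(jnp.array_equal(a, b)) for a, b in zip(first, second))

# The shapes differ, so the results cannot be stacked into a single array.
try:
    jnp.stack(first)
    stackable = True
except (ValueError, TypeError):
    stackable = False
```

Because the sub-keys are a pure function of the master key, logging that one key is enough to reproduce every environment's starting state.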
MultiEnv Attributes¶
| Item | Description |
|---|---|
| multi.envs | The list of inner JaxEnv instances |
| multi.num_envs | The number of environments (M) |
| multi.observation_spaces | A list of per-env observation spaces |
| multi.action_spaces | A list of per-env action spaces |
| multi.class_groups | A dict mapping env class name → list of indices |
MultiVecEnv¶
API Docs
Implementation example:
[Python code listing, 14 lines; contents not recoverable from the source]
This follows the same pattern as MultiEnv with a slight difference - each element of the returned lists is batched to (n_envs, ...) by its inner VecEnv.
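To make the batching concrete, here is a small sketch in plain JAX of what "a list whose elements are batched to (n_envs, ...)" looks like. It does not use Envrax: toy_vec_reset is a hypothetical stand-in for one inner VecEnv, and the shapes are invented for illustration.

```python
import jax
import jax.numpy as jnp


def toy_vec_reset(rng, n_envs, obs_shape):
    """Stand-in for one VecEnv group: vmap a single-env reset over n_envs keys."""

    def single_reset(key):
        return jax.random.uniform(key, obs_shape)

    keys = jax.random.split(rng, n_envs)
    return jax.vmap(single_reset)(keys)  # shape: (n_envs, *obs_shape)


# M = 2 hypothetical groups, each with N = 3 copies but different observation shapes.
n_envs = 3
group_obs_shapes = [(4,), (2, 3)]

master = jax.random.PRNGKey(42)
group_keys = jax.random.split(master, len(group_obs_shapes))

# The outer container stays a Python list; each element carries a leading n_envs axis.
batched_obs = [
    toy_vec_reset(k, n_envs, shape) for k, shape in zip(group_keys, group_obs_shapes)
]
shapes = [tuple(o.shape) for o in batched_obs]
total_envs = len(group_obs_shapes) * n_envs  # M x N individual environments
```

The outer list is indexed by group (M entries), while the leading axis of each array is indexed by copy (N entries).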
MultiVecEnv Attributes¶
| Item | Description |
|---|---|
| multi_vec.vec_envs | The list of inner VecEnv instances |
| multi_vec.num_envs | The number of VecEnv groups (M) |
| multi_vec.total_envs | Total individual environments across all groups (M × N) |
| multi_vec.single_observation_spaces | A list of per-group unbatched observation spaces |
| multi_vec.single_action_spaces | A list of per-group unbatched action spaces |
| multi_vec.observation_spaces | A list of per-group batched observation spaces |
| multi_vec.action_spaces | A list of per-group batched action spaces |
| multi_vec.class_groups | A dict mapping inner env class name → list of VecEnv indices |
Skip the boilerplate with make_multi / make_multi_vec
The make_multi and make_multi_vec factory methods build these fleets by name in one call - wrappers, JIT, and all. The class-based form here is the underlying mechanism; the factories are the sugar on top.
Use them whenever you can!
Additional Methods¶
MultiEnv and MultiVecEnv share the same extra-API surface for finer-grained control over your fleet of environments.
Per-Env Access¶
For targeted environment resets/steps you can use the utility methods reset_at() and step_at() to reset or step a single environment individually:
[Python code listing, 2 lines; contents not recoverable from the source]
This can be useful for situations like limiting your agents to environment lifetime budgets.
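As an illustration of the budget use case, the sketch below steps each environment until its own lifetime budget is spent, resetting only that environment at episode boundaries. ToyFleet is a hypothetical pure-Python stand-in, and its reset_at(index) and step_at(index) signatures are simplified assumptions; the real Envrax methods may take keys, states, or actions, so check the API docs.

```python
class ToyFleet:
    """Hypothetical stand-in for a multi-env manager; not the Envrax API."""

    def __init__(self, num_envs):
        self.steps = [0] * num_envs     # steps taken in the current episode
        self.lifetime = [0] * num_envs  # steps taken since construction

    def reset_at(self, index):
        """Reset only the env at `index`."""
        self.steps[index] = 0

    def step_at(self, index):
        """Advance only the env at `index` by one step."""
        self.steps[index] += 1
        self.lifetime[index] += 1


def run_with_budgets(fleet, budgets, episode_len):
    """Step each env until its own lifetime budget is spent."""
    for i, budget in enumerate(budgets):
        while fleet.lifetime[i] < budget:
            fleet.step_at(i)
            if fleet.steps[i] == episode_len:
                fleet.reset_at(i)  # the other envs are untouched
    return fleet


fleet = run_with_budgets(ToyFleet(3), budgets=[5, 2, 7], episode_len=3)
```

Each environment ends with exactly its own budget consumed, regardless of how many episodes the others went through.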
Class Groups¶
When your MultiEnv/MultiVecEnv contains repeated environment classes (e.g. two BallEnv and one CartPole), you can use class_groups to group indices by class for downstream same-shape batching:
[Python code listing, 2 lines; contents not recoverable from the source]
This is useful if you later want to stack observations for the repeated environment instances into a single batched tensor, perhaps for a policy forward pass or to compute per-task statistics.
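The mapping itself is simple to picture. Below is a rough pure-Python sketch of how a class-name → indices dict could be built and then used to gather same-class outputs. The two env classes, their observe() method, and build_class_groups are all hypothetical and only mimic the shape of the class_groups attribute described above.

```python
from collections import defaultdict


class BallEnv:
    """Hypothetical env producing 2-element observations."""

    def observe(self):
        return [0.0, 1.0]


class CartPole:
    """Hypothetical env producing 4-element observations."""

    def observe(self):
        return [0.0, 0.0, 0.0, 0.0]


def build_class_groups(envs):
    """Map each env class name to the list of indices where it appears."""
    groups = defaultdict(list)
    for index, env in enumerate(envs):
        groups[type(env).__name__].append(index)
    return dict(groups)


envs = [BallEnv(), CartPole(), BallEnv()]
class_groups = build_class_groups(envs)

# Gather per-env outputs by class; rows within a group share a shape here,
# so each group could then be stacked into one batched array.
observations = [env.observe() for env in envs]
batched = {
    name: [observations[i] for i in indices]
    for name, indices in class_groups.items()
}
```

Note that sharing a class does not by itself guarantee identical shapes if two instances are configured differently, so it is worth checking shapes before stacking a group.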
Common Pitfalls¶
Using multiple environments at once can be tricky. Be mindful of the following "gotchas":
- Different action dtypes - if env[0] takes int32 and env[1] takes float32, build the actions list element by element; don't try to jnp.stack them.
- Forgetting compile() - MultiEnv and MultiVecEnv don't pre-warm their inner environments. Without an explicit multi.compile(), your first step call will pay the compile cost for every env in the fleet sequentially.
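Both pitfalls can be reproduced in plain JAX. The sketch below is illustrative only: make_step is a hypothetical per-env step function, and the ahead-of-time lower(...).compile() loop is one general JAX way to pay compile costs up front, not a description of what Envrax's compile() does internally.

```python
import jax
import jax.numpy as jnp

# Pitfall 1: mixed action dtypes. Build the list element by element.
actions = [
    jnp.asarray(1, dtype=jnp.int32),      # e.g. a discrete action for env 0
    jnp.asarray(0.5, dtype=jnp.float32),  # e.g. a continuous action for env 1
]
dtypes = [str(a.dtype) for a in actions]  # each action keeps its own dtype

# Stacking same-shaped actions runs, but promotes everything to one dtype.
stacked_dtype = str(jnp.stack(actions).dtype)


# Pitfall 2: deferred compilation. Warm up every step function up front.
def make_step(scale):
    """Hypothetical per-env step function, wrapped in jax.jit."""

    def step(obs):
        return obs * scale + 1.0

    return jax.jit(step)


step_fns = [make_step(2.0), make_step(3.0)]
example_obs = [jnp.zeros((4,)), jnp.zeros((2, 3))]

# Ahead-of-time lowering and compilation, one env at a time.
compiled = [fn.lower(obs).compile() for fn, obs in zip(step_fns, example_obs)]

# Later calls reuse the compiled executables instead of compiling on first use.
outputs = [c(obs) for c, obs in zip(compiled, example_obs)]
```

Wrapping the compile loop in a timer gives a setup cost you can report separately from training throughput.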
Recap¶
To recap:
- MultiEnv manages M heterogeneous JaxEnv instances; MultiVecEnv manages M VecEnv groups
- Inputs and outputs are lists because observation shapes can differ across environments
- reset_at and step_at let you touch a single env without disturbing the rest
- class_groups maps class name → indices for downstream same-shape batching
- Call .compile() explicitly - these managers default to deferred compilation
Next up, we'll explore how Envrax's environment registry works so you can use canonical names for building environments instead of classes.
Next Steps¶
- Environment Registry: Learn how to use Envrax's environment registry.