Vectorising with VecEnv
Excellent work building your first environment! If you wanted to, you could stop here and start using Envrax for your own RL experiments right now, but a single environment is quite... inefficient.
Think about it - a single JaxEnv steps one environment through one episode at a time. If you needed over 1 million timesteps to train your policy, collecting them one step at a time would be painfully slow and could take weeks to finish training.
What we really need is a way to run many copies of it at once, with randomisation built in automatically. Well, that's where VecEnv comes in!
We can wrap the environment in VecEnv and it will operate on a batch of N independent environments simultaneously via jax.vmap.
No process pools, no pickling, no cross-device transfers, just pure JAX-native vectorisation for maximum performance!
Here's an example:
We create an instance of our environment, pass it into VecEnv, and tell it how many copies to create with num_envs. It then does the rest without any tweaks to the API! It's that simple!
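To make the pattern concrete, here's a minimal, self-contained sketch in plain JAX. `ToyEnv` and `ToyVecEnv` are hypothetical stand-ins written for this illustration, not Envrax's actual classes, and the `(obs, state)` / `(obs, state, reward, done)` return layouts are assumptions:

```python
import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in for a single JaxEnv: a 1-D random walk."""

    def reset(self, rng):
        # Keep one half of the key in the state, use the other for the start.
        rng, sub = jax.random.split(rng)
        pos = jax.random.uniform(sub, (), minval=-1.0, maxval=1.0)
        return pos, {"pos": pos, "rng": rng}  # (obs, state)

    def step(self, state, action):
        # action is 0 (move left) or 1 (move right).
        pos = state["pos"] + jnp.where(action == 1, 0.1, -0.1)
        new_state = {"pos": pos, "rng": state["rng"]}
        reward = -jnp.abs(pos)
        done = jnp.abs(pos) > 2.0
        return pos, new_state, reward, done


class ToyVecEnv:
    """Hypothetical stand-in for VecEnv: vmaps one env over num_envs copies."""

    def __init__(self, env, num_envs):
        self.env = env
        self.num_envs = num_envs

    def reset(self, rng):
        # One master key in, num_envs sub-keys out.
        keys = jax.random.split(rng, self.num_envs)
        return jax.vmap(self.env.reset)(keys)

    def step(self, state, actions):
        return jax.vmap(self.env.step)(state, actions)


env = ToyEnv()
vec_env = ToyVecEnv(env, num_envs=8)

obs0, states = vec_env.reset(jax.random.PRNGKey(0))
actions = jnp.ones((vec_env.num_envs,), dtype=jnp.int32)
obs1, states, rewards, dones = vec_env.step(states, actions)
```

Every returned array gains a leading axis of size 8, while the call signatures stay the same as for the single environment.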
How It Works
VecEnv is roughly 30 lines of glue around jax.vmap, and it is the core implementation of the BatchedEnv base class. Here's a quick rundown:
- `reset(rng)` - splits `rng` into `num_envs` sub-keys and vmaps the inner `env.reset` over them.
- `step(state, actions)` - vmaps `env.step` over the batched state and actions.
- Environments auto-reset - after each `step`, any env with `done=True` is automatically reset using the next rng from its own state. The selection between "step output" and "fresh reset" is done with a `jnp.where` mask inside the vmapped body, so episode boundaries don't require Python-level control flow.
The auto-reset behaviour is what makes VecEnv "training-ready": you never have to branch on done yourself when collecting rollouts; it does it all for you!
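The auto-reset step is the least obvious of the three, so here's a rough sketch of the idea in plain JAX. The toy `reset` and `step` functions and the state layout are hypothetical, and this is not Envrax's actual implementation; it only shows how a `jnp.where` mask can choose between the stepped state and a fresh reset for each environment:

```python
import jax
import jax.numpy as jnp


def reset(rng):
    # Hypothetical toy env: a step counter, a random start value, its own rng.
    rng, sub = jax.random.split(rng)
    return {"t": jnp.int32(0), "noise": jax.random.normal(sub, ()), "rng": rng}


def step(state, action):
    # The action is added to the counter; the episode ends once it reaches 3.
    t = state["t"] + action
    done = t >= 3
    return {"t": t, "noise": state["noise"], "rng": state["rng"]}, done


def step_with_auto_reset(state, action):
    stepped, done = step(state, action)
    # Always compute a fresh episode from the env's own rng...
    fresh = reset(stepped["rng"])
    # ...then pick, leaf by leaf, the fresh state where done and the
    # stepped state otherwise. No Python `if` on `done` is needed.
    new_state = jax.tree_util.tree_map(
        lambda f, s: jnp.where(done, f, s), fresh, stepped
    )
    return new_state, done


batched_step = jax.jit(jax.vmap(step_with_auto_reset))

keys = jax.random.split(jax.random.PRNGKey(0), 4)
states0 = jax.vmap(reset)(keys)
actions = jnp.array([1, 3, 1, 3], dtype=jnp.int32)
states1, dones = batched_step(states0, actions)
```

Environments 1 and 3 finish on this step, so their counters come back as 0 with a new rng, while environments 0 and 2 carry on from 1. The trade-off of this approach is that the reset is computed for every environment on every step and then masked out where it isn't needed.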
What's BatchedEnv?
BatchedEnv is the abstract base class that says "I step N independent agent results in one call." VecEnv is one strategy that satisfies it (vmap over a single JaxEnv); downstream packages can ship others (e.g. a composite MJX scene). MultiVecEnv accepts any BatchedEnv, so the underlying batching strategy stays pluggable.
Available Attributes
Just like Gymnasium, Envrax's VecEnv provides a small set of attributes and properties that may come in handy during training:
| Item | Description |
|---|---|
| `vec_env.env` | The wrapped inner `JaxEnv` |
| `vec_env.num_envs` | The number of parallel environments |
| `vec_env.n_slots` | Alias for `num_envs`, satisfies the `BatchedEnv` contract |
| `vec_env.name` | Inner env's class name, used as the default key by `MultiVecEnv` |
| `vec_env.config` | The inner environment's config for quick and easy access |
| `vec_env.single_observation_space` | The per-env observation space |
| `vec_env.single_action_space` | The per-env action space |
| `vec_env.observation_space` | The batched observation space with a leading `num_envs` dim (`B`) |
| `vec_env.action_space` | The batched action space with a leading `num_envs` dim (`B`) |
JIT Compiling
By default, JAX compiles a function lazily on its first real call. For a VecEnv, the first step kicks off XLA compilation and can take anywhere from a couple of seconds up to a minute, depending on env complexity.
This cost can be pretty annoying in the middle of a training run, so we've added a compile() method to VecEnv. With it, you can run compilation as its own setup stage before training starts, and cache the XLA-compiled kernels on disk (default: .jax_cache in the project root) too!
Here's how to use it:
With caching in place, subsequent Python processes will reuse the same compiled kernels to drastically reduce future compiling time. This is useful for test runs, when you need to stop and start a training run, or when your program unexpectedly crashes. Those precious seconds make all the difference!
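For intuition, the same "compile up front, cache on disk" idea can be sketched with plain JAX. This is an assumption about the general technique, not a description of what `vec_env.compile()` does internally; the toy `step` function is hypothetical and a temporary directory stands in for `.jax_cache`:

```python
import tempfile
import time

import jax
import jax.numpy as jnp

# Assumption: a throwaway directory stands in for the `.jax_cache` folder.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")
jax.config.update("jax_compilation_cache_dir", cache_dir)


def step(pos, action):
    # Hypothetical per-env step: move left or right by 0.1.
    return pos + jnp.where(action == 1, 0.1, -0.1)


batched_step = jax.jit(jax.vmap(step))

num_envs = 1024
pos = jnp.zeros((num_envs,), dtype=jnp.float32)
actions = jnp.ones((num_envs,), dtype=jnp.int32)

# Setup phase: trace, lower, and compile for these exact shapes and dtypes.
t0 = time.perf_counter()
compiled_step = batched_step.lower(pos, actions).compile()
compile_seconds = time.perf_counter() - t0

# Training phase: the compiled executable runs without tracing again.
new_pos = compiled_step(pos, actions)
new_pos.block_until_ready()
```

Note that JAX's persistent cache may skip programs that compile very quickly, so a toy function like this one won't necessarily leave anything in the cache directory.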
Using Wrappers
As you'll see in a future tutorial (Available Wrappers), Envrax comes with a host of environment wrappers out-of-the-box.
To use them with VecEnv, you need to apply them to your JaxEnv first, then pass the wrapped environment to VecEnv:
This order matters for two reasons:
- Wrappers transform per-env data - the `GrayscaleObservation` wrapper expects `uint8[H, W, 3]`, not `uint8[N, H, W, 3]`. Putting it outside `VecEnv` would feed it batched arrays it can't handle.
- `VecEnv` isn't a `JaxEnv` - wrappers expect a `JaxEnv` instance as their inner env, and `VecEnv` is a basic class that wraps a `JaxEnv` rather than being one.
The make_vec() factory method applies wrappers in this order automatically. We'll cover the full set of factory methods later in the Make Methods tutorial!
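The first reason is easy to see in plain JAX. In this sketch, `grayscale` and `env_observe` are hypothetical stand-ins for a per-env wrapper and an inner environment (they are not Envrax's `GrayscaleObservation`); the point is only that the per-env transform is composed first and vectorised second:

```python
import jax
import jax.numpy as jnp


def grayscale(obs):
    # Per-env transform: expects uint8[H, W, 3], returns uint8[H, W].
    assert obs.ndim == 3 and obs.shape[-1] == 3, "expects one [H, W, 3] frame"
    weights = jnp.array([0.299, 0.587, 0.114], dtype=jnp.float32)
    return jnp.round(obs.astype(jnp.float32) @ weights).astype(jnp.uint8)


def env_observe(rng):
    # Hypothetical inner env that produces one random RGB frame.
    return jax.random.randint(rng, (84, 84, 3), 0, 256).astype(jnp.uint8)


def wrapped_observe(rng):
    # The wrapper is applied to the single environment first...
    return grayscale(env_observe(rng))


# ...and only then is the wrapped function vectorised.
keys = jax.random.split(jax.random.PRNGKey(0), 16)
batched_gray = jax.vmap(wrapped_observe)(keys)
```

Inside `jax.vmap`, `grayscale` still sees a single `[H, W, 3]` frame, so its shape check passes. Calling it directly on the batched `[16, 84, 84, 3]` output would trip that check instead.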
RecordVideo is the Exception
RecordVideo writes MP4 files Python-side and is not JIT/vmap-compatible. Use it on a single environment, or render manually via vec_env.render_slot(states, slot_idx=0) and feed an external recorder.
Rendering
VecEnv exposes render_slot() and slot_state() for pulling a single environment out of the batch. render_slot extracts one slot and delegates to the inner environment's render:
This is useful for logging an episode during training without unpacking the batched state yourself, and for any downstream tooling that wants a single agent's state at a time. We'll discuss Rendering more in a future tutorial.
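Conceptually, pulling one slot out of a batched state is just indexing the leading axis of every leaf in the pytree. Here's a minimal sketch of that idea, assuming a dict-shaped state and a made-up `render` function; the helper names mirror `slot_state` and `render_slot` but are written from scratch for illustration:

```python
import jax
import jax.numpy as jnp


def slot_state(states, slot_idx):
    # Index the leading num_envs axis of every leaf in the batched pytree.
    return jax.tree_util.tree_map(lambda x: x[slot_idx], states)


def render(state):
    # Hypothetical single-env renderer: paints the agent's row white.
    frame = jnp.zeros((4, 4, 3), dtype=jnp.uint8)
    return frame.at[state["row"]].set(255)


def render_slot(states, slot_idx=0):
    # Extract one slot, then delegate to the single-env render.
    return render(slot_state(states, slot_idx))


# A batched state for 4 environments (hypothetical layout).
states = {
    "row": jnp.array([0, 1, 2, 3], dtype=jnp.int32),
    "pos": jnp.zeros((4, 2), dtype=jnp.float32),
}
single = slot_state(states, 2)
frame = render_slot(states, slot_idx=2)
```

Here `single["pos"]` has shape `(2,)` rather than `(4, 2)`, and `frame` is a single `[4, 4, 3]` image for slot 2.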
Common Pitfalls
As with EnvState, there are a few common "gotchas" to be mindful of:
- Mismatched action shape - `actions` must have shape `(num_envs, ...)` with the same dtype as the action space. For a `Discrete` action, that's `jnp.int32[num_envs]`.
- `reset` with a single key - `VecEnv.reset` takes one master key and splits it internally. Don't pre-split your keys!
- Python-side side effects inside `step` - `VecEnv` vmaps over the batch, so `print()`, file writes, etc. run at trace time on abstract tracers. They either fire once instead of on every step or fail outright.
- Forgetting `compile()` in benchmarks - the first call will always look slow because XLA is compiling. Call `compile()` before timing anything.
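The side-effect pitfall is the one that tends to surprise people, so here's a small plain-JAX sketch of it. The toy `step` function is hypothetical; the list append stands in for any Python-side effect such as `print()` or a file write:

```python
import jax
import jax.numpy as jnp

trace_log = []


def step(pos, action):
    # Python side effect: runs while JAX traces the function,
    # not each time the compiled step is executed.
    trace_log.append("traced")
    # For run-time debugging, JAX provides a traced-safe alternative:
    # jax.debug.print("pos={p}", p=pos)
    return pos + action


batched_step = jax.jit(jax.vmap(step))

num_envs = 4
pos = jnp.zeros((num_envs,), dtype=jnp.int32)
# Actions carry a leading num_envs axis and the expected integer dtype.
actions = jnp.ones((num_envs,), dtype=jnp.int32)

for _ in range(3):
    pos = batched_step(pos, actions)

side_effect_count = len(trace_log)
```

The step runs three times, but the append fires fewer times than that, because later calls reuse the already-compiled function instead of re-running the Python body.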
Recap
To recap:
- `VecEnv(env, num_envs)` uses `jax.vmap` on your environment for batched rollouts
- `VecEnv` inherits from `BatchedEnv` - it's the canonical vmap strategy for that contract
- Batched fields all gain a leading `num_envs` dimension
- Auto-reset on `done=True` is handled inside the vmapped body via a `jnp.where` mask - no Python control flow needed
- A small set of attributes (`env`, `num_envs`, `n_slots`, `name`, `config`, plus single and batched space properties) gives quick access to the wrapped environment's metadata
- Call `vec_env.compile()` to trigger XLA compilation as a separate setup phase, with cached kernels reused across runs
- Apply wrappers to your `JaxEnv` first, then pass the wrapped environment to `VecEnv`
- `vec_env.render_slot(states, slot_idx=0)` extracts one environment from the batch for visual inspection; `vec_env.slot_state(states, slot_idx=0)` gives you the raw unbatched state pytree
Next, we'll look at using some new classes to make M heterogeneous environments with ease. See you there!
Next Steps
- Multiple Environments: learn how to manage `M` heterogeneous environments with `MultiEnv` / `MultiVecEnv`.