Skip to content

Vectorised

A batched wrapper that runs many copies of a single environment in parallel on one accelerator via jax.vmap.

envrax.vec_env.VecEnv

Bases: BatchedEnv, Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Wraps any JaxEnv to operate over a batch of environments simultaneously.

Canonical BatchedEnv implementation: num_envs independent copies of one JaxEnv stepped in parallel via jax.vmap, with per-slot auto-reset.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `env` | `JaxEnv` | Single-instance environment to vectorise | *required* |
| `num_envs` | `int` | Number of parallel environments (`n_slots`) | *required* |
Source code in `envrax/vec_env.py` (lines 15–240):
class VecEnv(BatchedEnv, Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]):
    """
    Wraps any `JaxEnv` to operate over a batch of environments simultaneously.

    Canonical `BatchedEnv` implementation: `num_envs` independent copies of
    one `JaxEnv` stepped in parallel via `jax.vmap`, with per-slot auto-reset.

    Neither `reset` nor `step` applies `jax.jit` itself; whether the calls
    run as fused compiled programs depends on the inner environment (e.g. a
    `JitWrapper`) or on the caller jitting them.

    Parameters
    ----------
    env : JaxEnv
        Single-instance environment to vectorise
    num_envs : int
        Number of parallel environments (`n_slots`). Must be at least 1.

    Raises
    ------
    ValueError
        If `num_envs` is less than 1.
    """

    def __init__(
        self,
        env: JaxEnv[ObsSpaceT, ActSpaceT, StateT, ConfigT],
        num_envs: int,
    ) -> None:
        # An empty batch cannot be meaningfully reset or stepped, and
        # `compile` forces a flag on slot 0 — reject it up front.
        if num_envs < 1:
            raise ValueError(f"num_envs must be >= 1, got {num_envs}")
        self.env = env
        self.num_envs = num_envs

    @property
    def n_slots(self) -> int:
        """Number of parallel slots (= `num_envs`)."""
        return self.num_envs

    @property
    def name(self) -> str:
        """
        Inner environment's `name`. Used as the default key by `MultiVecEnv`.

        Picks up wrapper delegation, so `VecEnv(JitWrapper(BallEnv()))`
        still reports `"BallEnv"` rather than `"JitWrapper"`.

        Returns
        -------
        name : str
            The wrapped environment's `name`.
        """
        return self.env.name

    @property
    def config(self) -> ConfigT:
        """Single environment configuration."""
        return self.env.config

    def reset(self, rng: chex.PRNGKey) -> Tuple[jax.Array, StateT]:
        """
        Reset all `num_envs` environments with independent random starts.

        `rng` is split into `num_envs` keys and the inner `reset` is vmapped
        over them, so all returned arrays have a leading batch dimension
        `B = num_envs`, e.g. observations of shape `(B, *obs_shape)`.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs  : jax.Array
            Stacked first observations
        states : EnvState
            Batched environment states
        """
        rngs = jax.random.split(rng, self.num_envs)
        return jax.vmap(self.env.reset)(rngs)

    def step(
        self,
        state: StateT,
        actions: jax.Array,
    ) -> Tuple[jax.Array, StateT, jax.Array, jax.Array, Dict[str, Any]]:
        """
        Advance all environments by one step simultaneously.

        Each environment independently auto-resets when its episode ends.
        All inputs and outputs have a leading batch dimension `B = num_envs`:

        - Observations: `(B, *obs_shape)`
        - Discrete actions: `(B,)` — one int per env
        - Continuous actions: `(B, *action_shape)` — one vector per env
        - Rewards / dones: `(B,)`

        Rewards, dones and infos always describe the step just taken; only
        the observation and state of finished slots are replaced by a fresh
        reset (the terminal observation is not returned).

        Parameters
        ----------
        state : EnvState
            Batched environment states
        actions  : jax.Array
            One action per environment

        Returns
        -------
        obs  : jax.Array
            Observations after the step
        new_states : EnvState
            Updated batched states
        rewards  : jax.Array
            Per-environment rewards
        dones  : jax.Array
            Per-environment terminal flags
        infos : Dict[str, Any]
            Batched info dict
        """
        return jax.vmap(self._step_env)(state, actions)

    def _step_env(
        self,
        state: StateT,
        action: jax.Array,
    ) -> Tuple[jax.Array, StateT, jax.Array, jax.Array, Dict[str, Any]]:
        """
        Single-env step that auto-resets on episode end.

        When `done` is `True`, returns the observation and state from a fresh
        reset instead of the terminal ones. The reset key is derived by
        splitting the post-step state's `rng`.

        Parameters
        ----------
        state : EnvState
            Current environment state
        action : jax.Array
            Action to take in the environment

        Returns
        -------
        Tuple of `(obs, new_state, reward, done, info)`, where obs/new_state
        come from reset if done is `True`; reward/done/info always come from
        the step.
        """
        obs, new_state, reward, done, info = self.env.step(state, action)

        # A reset is computed unconditionally and selected per leaf: under
        # `vmap`, `done` is a traced per-slot value, so Python branching on
        # it is not possible.
        reset_rng, _ = jax.random.split(new_state.rng)
        reset_obs, reset_state = self.env.reset(reset_rng)

        def _select(reset_leaf, step_leaf):
            return jnp.where(done, reset_leaf, step_leaf)

        # `tree.map` also supports pytree (e.g. dict) observations; for a
        # plain array it reduces to the same single `jnp.where`.
        final_obs = jax.tree.map(_select, reset_obs, obs)
        final_state = jax.tree.map(_select, reset_state, new_state)

        return final_obs, final_state, reward, done, info

    def slot_state(self, state: StateT, slot_idx: int) -> StateT:
        """
        Extract the state pytree for a single slot from the batched state.

        Parameters
        ----------
        state : EnvState
            Batched environment state
        slot_idx : int
            Slot index in `[0, num_envs)`

        Returns
        -------
        single_state : EnvState
            Unbatched state pytree for the chosen slot.
        """
        return jax.tree.map(lambda x: x[slot_idx], state)

    def render_slot(self, state: StateT, slot_idx: int) -> np.ndarray:
        """
        Render a single environment from the batch.

        Parameters
        ----------
        state : EnvState
            Batched environment state
        slot_idx : int
            Slot index in `[0, num_envs)`

        Returns
        -------
        frame : np.ndarray
            uint8 RGB array of shape `(H, W, 3)`
        """
        return self.env.render(self.slot_state(state, slot_idx))

    def compile(self, cache_dir: Path | str | None = DEFAULT_CACHE_DIR) -> None:
        """
        Warm up `reset` + `step` by running them once on dummy inputs.

        Calls `setup_cache(cache_dir)`, resets all slots, samples one random
        action per slot and steps once. It then steps a second time from the
        reset state with its `done` flag forced to `True` in slot 0.

        The auto-reset in `_step_env` selects with `jnp.where`, so the reset
        and non-reset values are both computed on every step regardless of
        `done`. The second step therefore only differs if the inner
        environment's own `step` reads the incoming `state.done`.

        Parameters
        ----------
        cache_dir : Path | str | None (optional)
            XLA cache directory. Defaults to `<cwd>/.jax_cache`.
        """
        setup_cache(cache_dir)

        # Independent keys for reset and action sampling (never reuse a key).
        reset_key, action_key = jax.random.split(jax.random.key(0))
        _, warm_state = self.reset(reset_key)
        action_keys = jax.random.split(action_key, self.num_envs)
        dummy_actions = jax.vmap(self.env.action_space.sample)(action_keys)

        self.step(warm_state, dummy_actions)

        forced_done = warm_state.done.at[0].set(jnp.bool_(True))
        done_state = warm_state.__replace__(done=forced_done)
        self.step(done_state, dummy_actions)

    @property
    def single_observation_space(self) -> ObsSpaceT:
        """Observation space of a single inner environment."""
        return self.env.observation_space

    @property
    def single_action_space(self) -> ActSpaceT:
        """Action space of a single inner environment."""
        return self.env.action_space

    @property
    def observation_space(self) -> Space:
        """Batched observation space with a leading `num_envs` dimension."""
        return self.env.observation_space.batch(self.num_envs)

    @property
    def action_space(self) -> Space:
        """Batched action space with a leading `num_envs` dimension."""
        return self.env.action_space.batch(self.num_envs)

    def __repr__(self) -> str:
        return f"VecEnv<{self.env!r}, num_envs={self.num_envs}>"

n_slots property

Number of parallel slots (= num_envs).

name property

Inner environment's name. Used as the default key by MultiVecEnv.

Picks up wrapper delegation, so VecEnv(JitWrapper(BallEnv())) still reports "BallEnv" rather than "JitWrapper".

Returns:

| Name | Type | Description |
|------|------|-------------|
| `name` | `str` | The wrapped environment's `name`. |

config property

Single environment configuration.

single_observation_space property

Observation space of a single inner environment.

single_action_space property

Action space of a single inner environment.

observation_space property

Batched observation space with a leading num_envs dimension.

action_space property

Batched action space with a leading num_envs dimension.

reset(rng)

Reset all num_envs environments with independent random starts.

All returned arrays have a leading batch dimension B = num_envs, e.g. observations of shape (B, *obs_shape).

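Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `rng` | `PRNGKey` | JAX PRNG key | *required* |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `obs` | `Array` | Stacked first observations |
| `states` | `EnvState` | Batched environment states |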

Source code in `envrax/vec_env.py` (lines 63–83):
def reset(self, rng: chex.PRNGKey) -> Tuple[jax.Array, StateT]:
    """
    Reset every one of the `num_envs` environments from its own random key.

    The single `rng` is split into `num_envs` keys and the inner
    environment's `reset` is vmapped over them. Every returned array
    therefore carries a leading batch axis `B = num_envs`; for example,
    observations have shape `(B, *obs_shape)`.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    obs : jax.Array
        Stacked first observations
    states : EnvState
        Batched environment states
    """
    batched_reset = jax.vmap(self.env.reset)
    per_env_keys = jax.random.split(rng, self.num_envs)
    return batched_reset(per_env_keys)

step(state, actions)

Advance all environments by one step simultaneously.

Each environment independently auto-resets when its episode ends. All inputs and outputs have a leading batch dimension B = num_envs:

  • Observations: (B, *obs_shape)
  • Discrete actions: (B,) — one int per env
  • Continuous actions: (B, *action_shape) — one vector per env
  • Rewards / dones: (B,)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `EnvState` | Batched environment states | *required* |
| `actions` | `Array` | One action per environment | *required* |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `obs` | `Array` | Observations after the step |
| `new_states` | `EnvState` | Updated batched states |
| `rewards` | `Array` | Per-environment rewards |
| `dones` | `Array` | Per-environment terminal flags |
| `infos` | `Dict[str, Any]` | Batched info dict |

Source code in `envrax/vec_env.py` (lines 85–121):
def step(
    self,
    state: StateT,
    actions: jax.Array,
) -> Tuple[jax.Array, StateT, jax.Array, jax.Array, Dict[str, Any]]:
    """
    Step every environment in the batch once, in parallel.

    The per-slot `_step_env` (which auto-resets a slot whose episode has
    ended) is vmapped over the leading batch axis `B = num_envs` of both
    inputs. Expected/returned shapes:

    - Observations: `(B, *obs_shape)`
    - Discrete actions: `(B,)` — one int per env
    - Continuous actions: `(B, *action_shape)` — one vector per env
    - Rewards / dones: `(B,)`

    Parameters
    ----------
    state : EnvState
        Batched environment states
    actions : jax.Array
        One action per environment

    Returns
    -------
    obs : jax.Array
        Observations after the step
    new_states : EnvState
        Updated batched states
    rewards : jax.Array
        Per-environment rewards
    dones : jax.Array
        Per-environment terminal flags
    infos : Dict[str, Any]
        Batched info dict
    """
    batched_step = jax.vmap(self._step_env)
    outputs = batched_step(state, actions)
    return outputs

slot_state(state, slot_idx)

Extract the state pytree for a single slot from the batched state.

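Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `EnvState` | Batched environment state | *required* |
| `slot_idx` | `int` | Slot index in `[0, num_envs)` | *required* |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `single_state` | `EnvState` | Unbatched state pytree for the chosen slot. |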

Source code in `envrax/vec_env.py` (lines 158–174):
def slot_state(self, state: StateT, slot_idx: int) -> StateT:
    """
    Pull one slot's state out of a batched state pytree.

    Every leaf of `state` is indexed at `slot_idx` along its leading
    (batch) axis, and the pytree structure is kept.

    Parameters
    ----------
    state : EnvState
        Batched environment state
    slot_idx : int
        Slot index in `[0, num_envs)`

    Returns
    -------
    single_state : EnvState
        Unbatched state pytree for the chosen slot.
    """
    def _take_slot(leaf):
        return leaf[slot_idx]

    return jax.tree.map(_take_slot, state)

render_slot(state, slot_idx)

Render a single environment from the batch.

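Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `EnvState` | Batched environment state | *required* |
| `slot_idx` | `int` | Slot index in `[0, num_envs)` | *required* |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `frame` | `ndarray` | uint8 RGB array of shape `(H, W, 3)` |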

Source code in `envrax/vec_env.py` (lines 176–192):
def render_slot(self, state: StateT, slot_idx: int) -> np.ndarray:
    """
    Draw one environment of the batch.

    The chosen slot is first extracted with `slot_state`, and the result is
    then passed to the inner environment's `render`.

    Parameters
    ----------
    state : EnvState
        Batched environment state
    slot_idx : int
        Slot index in `[0, num_envs)`

    Returns
    -------
    frame : np.ndarray
        uint8 RGB array of shape `(H, W, 3)`
    """
    single_state = self.slot_state(state, slot_idx)
    frame = self.env.render(single_state)
    return frame

compile(cache_dir=DEFAULT_CACHE_DIR)

Trigger XLA compilation by running dummy reset + step.

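Steps once from the freshly reset state, then steps again with the incoming state's done flag forced to True in the first slot. The auto-reset itself selects with jnp.where, so the reset and non-reset values are computed on every step regardless of done. The second step only changes anything if the inner environment's own step reads the incoming done flag.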

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `cache_dir` | `Path \| str \| None` | XLA cache directory. Defaults to `<cwd>/.jax_cache`. | `DEFAULT_CACHE_DIR` |
Source code in `envrax/vec_env.py` (lines 194–217):
def compile(self, cache_dir: Path | str | None = DEFAULT_CACHE_DIR) -> None:
    """
    Warm up `reset` + `step` by running them once on dummy inputs.

    Calls `setup_cache(cache_dir)`, resets all slots, samples one random
    action per slot and steps once. It then steps a second time from the
    reset state with its `done` flag forced to `True` in slot 0.

    The per-slot auto-reset selects with `jnp.where`, so the reset and
    non-reset values are both computed on every step regardless of `done`.
    The second step therefore only differs if the inner environment's own
    `step` reads the incoming `state.done`.

    Parameters
    ----------
    cache_dir : Path | str | None (optional)
        XLA cache directory. Defaults to `<cwd>/.jax_cache`.
    """
    setup_cache(cache_dir)

    # Independent keys for reset and action sampling (never reuse a key).
    reset_key, action_key = jax.random.split(jax.random.key(0))
    _, warm_state = self.reset(reset_key)
    action_keys = jax.random.split(action_key, self.num_envs)
    dummy_actions = jax.vmap(self.env.action_space.sample)(action_keys)

    self.step(warm_state, dummy_actions)

    forced_done = warm_state.done.at[0].set(jnp.bool_(True))
    done_state = warm_state.__replace__(done=forced_done)
    self.step(done_state, dummy_actions)