Skip to content

Vectorised

A batched wrapper that runs many copies of a single environment in parallel on one accelerator via jax.vmap.

envrax.vec_env.VecEnv

Bases: Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Wraps any JaxEnv to operate over a batch of environments simultaneously.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `env` | `JaxEnv` | Single-instance environment to vectorise | required |
| `num_envs` | `int` | Number of parallel environments | required |
Source code in envrax/vec_env.py
Python
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
class VecEnv(Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]):
    """
    Wraps any `JaxEnv` to operate over a batch of environments simultaneously.

    Batching is done with `jax.vmap` over the inner env's `reset` and over a
    per-env auto-resetting step, so every returned array (and every leaf of
    the returned state pytree) carries a leading batch dimension of size
    `num_envs`.

    Parameters
    ----------
    env : JaxEnv
        Single-instance environment to vectorise
    num_envs : int
        Number of parallel environments (must be >= 1)

    Raises
    ------
    ValueError
        If `num_envs` is less than 1.
    """

    def __init__(
        self,
        env: JaxEnv[ObsSpaceT, ActSpaceT, StateT, ConfigT],
        num_envs: int,
    ) -> None:
        # Fail fast: a non-positive batch size would otherwise only fail later
        # inside `jax.random.split` / `jax.vmap`, or silently give an empty batch.
        if num_envs < 1:
            raise ValueError(f"num_envs must be >= 1, got {num_envs}")
        self.env = env
        self.num_envs = num_envs

    @property
    def config(self) -> ConfigT:
        """Single environment configuration."""
        return self.env.config

    def reset(self, rng: chex.PRNGKey) -> Tuple[chex.Array, StateT]:
        """
        Reset all `num_envs` environments with independent random starts.

        All returned arrays have a leading batch dimension `B = num_envs`,
        e.g. observations of shape `(B, *obs_shape)`.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key; it is split into `num_envs` independent keys

        Returns
        -------
        obs  : chex.Array
            Stacked first observations
        states : EnvState
            Batched environment states
        """
        rngs = jax.random.split(rng, self.num_envs)
        return jax.vmap(self.env.reset)(rngs)

    def step(
        self,
        state: StateT,
        actions: chex.Array,
    ) -> Tuple[chex.Array, StateT, chex.Array, chex.Array, Dict[str, Any]]:
        """
        Advance all environments by one step simultaneously.

        Each environment independently auto-resets when its episode ends.
        All inputs and outputs have a leading batch dimension `B = num_envs`:

        - Observations: `(B, *obs_shape)`
        - Discrete actions: `(B,)` — one int per env
        - Continuous actions: `(B, *action_shape)` — one vector per env
        - Rewards / dones: `(B,)`

        Parameters
        ----------
        state : EnvState
            Batched environment states
        actions  : chex.Array
            One action per environment

        Returns
        -------
        obs  : chex.Array
            Observations after the step (reset observations for envs that
            finished this step)
        new_states : EnvState
            Updated batched states
        rewards  : chex.Array
            Per-environment rewards
        dones  : chex.Array
            Per-environment terminal flags
        infos : Dict[str, Any]
            Batched info dict
        """
        return jax.vmap(self._step_env)(state, actions)

    def _step_env(
        self,
        state: StateT,
        action: chex.Array,
    ) -> Tuple[chex.Array, StateT, chex.Array, chex.Array, Dict[str, Any]]:
        """
        Single-env step that auto-resets on episode end.

        When `done` is `True`, returns the observation and state from a fresh
        reset instead of the terminal ones. The reward, done flag and info
        always come from the step that was just taken.

        Parameters
        ----------
        state : EnvState
            Current environment state
        action : chex.Array
            Action to take in the environment

        Returns
        -------
        Tuple of `(obs, new_state, reward, done, info)`, where obs/new_state
        come from reset if done is `True`.
        """
        obs, new_state, reward, done, info = self.env.step(state, action)

        # The reset is computed unconditionally. Under `jax.vmap` a `lax.cond`
        # with a batched predicate is lowered to a select over both branches
        # anyway, so the reset result must always exist.
        reset_rng, _ = jax.random.split(new_state.rng)
        reset_obs, reset_state = self.env.reset(reset_rng)

        # NOTE: the terminal observation is discarded here, not kept in `info`.
        final_obs = jax.lax.cond(done, lambda: reset_obs, lambda: obs)
        final_state = jax.lax.cond(done, lambda: reset_state, lambda: new_state)

        return final_obs, final_state, reward, done, info

    def render(self, state: StateT, *, index: int = 0) -> np.ndarray:
        """
        Render a single environment from the batch.

        Extracts the state at `index` from the batched state pytree and
        delegates to the inner env's `render()`.

        Parameters
        ----------
        state : EnvState
            Batched environment state
        index : int (optional)
            Which environment in the batch to render. Negative indices count
            from the end, as with normal sequence indexing. Default is `0`.

        Returns
        -------
        frame : np.ndarray
            uint8 RGB array of shape `(H, W, 3)`

        Raises
        ------
        IndexError
            If `index` is outside `[-num_envs, num_envs)`.
        """
        # Validate explicitly. JAX's out-of-bounds indexing semantics may clamp
        # rather than raise, which would silently render the wrong env.
        if not -self.num_envs <= index < self.num_envs:
            raise IndexError(
                f"index {index} is out of range for a batch of {self.num_envs} envs"
            )
        single_state = jax.tree.map(lambda x: x[index], state)
        return self.env.render(single_state)

    def compile(self, cache_dir: Path | str | None = DEFAULT_CACHE_DIR) -> None:
        """
        Trigger XLA compilation by running a dummy `reset` + `step`.

        Useful when construction and compilation should be separate phases.
        Safe to call multiple times.

        The dummy actions are zeros of shape `(num_envs, *action_shape)`.
        When the inner action space exposes `shape` / `dtype` attributes they
        are used; otherwise a scalar `int32` (discrete) action is assumed.

        Parameters
        ----------
        cache_dir : Path | str | None (optional)
            XLA cache directory. Defaults to `~/.cache/envrax/xla_cache`.
        """
        setup_cache(cache_dir)
        _key = jax.random.key(0)
        _, _state = self.reset(_key)

        # NOTE(review): assumes action spaces may expose `shape`/`dtype`.
        # Confirm against the Space API. Without them, fall back to the
        # original discrete placeholder so behaviour is unchanged.
        space = self.single_action_space
        act_shape = getattr(space, "shape", None)
        act_shape = () if act_shape is None else tuple(act_shape)
        act_dtype = getattr(space, "dtype", None)
        if act_dtype is None:  # explicit None check: np.dtype objects are falsy
            act_dtype = jnp.int32
        dummy_actions = jnp.zeros((self.num_envs, *act_shape), dtype=act_dtype)
        self.step(_state, dummy_actions)

    @property
    def single_observation_space(self) -> ObsSpaceT:
        """Observation space of a single inner environment."""
        return self.env.observation_space

    @property
    def single_action_space(self) -> ActSpaceT:
        """Action space of a single inner environment."""
        return self.env.action_space

    @property
    def observation_space(self) -> Space:
        """Batched observation space with a leading `num_envs` dimension."""
        return self.env.observation_space.batch(self.num_envs)

    @property
    def action_space(self) -> Space:
        """Batched action space with a leading `num_envs` dimension."""
        return self.env.action_space.batch(self.num_envs)

    def __repr__(self) -> str:
        return f"VecEnv<{self.env!r}, num_envs={self.num_envs}>"

config property

Single environment configuration.

single_observation_space property

Observation space of a single inner environment.

single_action_space property

Action space of a single inner environment.

observation_space property

Batched observation space with a leading num_envs dimension.

action_space property

Batched action space with a leading num_envs dimension.

reset(rng)

Reset all num_envs environments with independent random starts.

All returned arrays have a leading batch dimension B = num_envs, e.g. observations of shape (B, *obs_shape).

Parameters:

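| Name | Type | Description | Default |
|------|------|-------------|---------|
| `rng` | `PRNGKey` | JAX PRNG key | required |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `obs` | `Array` | Stacked first observations |
| `states` | `EnvState` | Batched environment states |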

Source code in envrax/vec_env.py
Python
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def reset(self, rng: chex.PRNGKey) -> Tuple[chex.Array, StateT]:
    """
    Start a fresh episode in every one of the `num_envs` environments.

    The incoming key is split into one independent key per environment, and
    the inner env's `reset` is mapped over them. Each output therefore has a
    leading batch axis `B = num_envs`, so observations are `(B, *obs_shape)`.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key to derive all per-environment keys from

    Returns
    -------
    obs  : chex.Array
        Initial observations, stacked along the batch axis
    states : EnvState
        Initial environment states, batched along the leading axis
    """
    per_env_keys = jax.random.split(rng, self.num_envs)
    batched_reset = jax.vmap(self.env.reset)
    return batched_reset(per_env_keys)

step(state, actions)

Advance all environments by one step simultaneously.

Each environment independently auto-resets when its episode ends. All inputs and outputs have a leading batch dimension B = num_envs:

  • Observations: (B, *obs_shape)
  • Discrete actions: (B,) — one int per env
  • Continuous actions: (B, *action_shape) — one vector per env
  • Rewards / dones: (B,)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `EnvState` | Batched environment states | required |
| `actions` | `Array` | One action per environment | required |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `obs` | `Array` | Observations after the step |
| `new_states` | `EnvState` | Updated batched states |
| `rewards` | `Array` | Per-environment rewards |
| `dones` | `Array` | Per-environment terminal flags |
| `infos` | `Dict[str, Any]` | Batched info dict |

Source code in envrax/vec_env.py
Python
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def step(
    self,
    state: StateT,
    actions: chex.Array,
) -> Tuple[chex.Array, StateT, chex.Array, chex.Array, Dict[str, Any]]:
    """
    Take one step in every environment of the batch at once.

    The per-env step (which auto-resets any env whose episode just ended) is
    mapped over the leading batch axis `B = num_envs`. Expected layouts:

    - Observations: `(B, *obs_shape)`
    - Discrete actions: `(B,)` — a single integer per env
    - Continuous actions: `(B, *action_shape)` — a single vector per env
    - Rewards / dones: `(B,)`

    Parameters
    ----------
    state : EnvState
        Current batched environment states
    actions  : chex.Array
        Actions to apply, one per environment

    Returns
    -------
    obs  : chex.Array
        Post-step observations
    new_states : EnvState
        Batched states after the step
    rewards  : chex.Array
        Reward received by each environment
    dones  : chex.Array
        Episode-termination flag for each environment
    infos : Dict[str, Any]
        Info dict with batched values
    """
    batched_step = jax.vmap(self._step_env)
    return batched_step(state, actions)

render(state, *, index=0)

Render a single environment from the batch.

Extracts the state at index from the batched state pytree and delegates to the inner env's render().

Parameters:

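| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `EnvState` | Batched environment state | required |
| `index` | `int` | Which environment in the batch to render. Default is 0. | `0` |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `frame` | `ndarray` | uint8 RGB array of shape (H, W, 3) |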

Source code in envrax/vec_env.py
Python
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def render(self, state: StateT, *, index: int = 0) -> np.ndarray:
    """
    Render a single environment from the batch.

    Extracts the state at `index` from the batched state pytree and
    delegates to the inner env's `render()`.

    Parameters
    ----------
    state : EnvState
        Batched environment state
    index : int (optional)
        Which environment in the batch to render. Negative indices count
        from the end, as with normal sequence indexing. Default is `0`.

    Returns
    -------
    frame : np.ndarray
        uint8 RGB array of shape `(H, W, 3)`

    Raises
    ------
    IndexError
        If `index` is outside `[-num_envs, num_envs)`.
    """
    # Validate explicitly. JAX's out-of-bounds indexing semantics may clamp
    # rather than raise, which would silently render the wrong env.
    if not -self.num_envs <= index < self.num_envs:
        raise IndexError(
            f"index {index} is out of range for a batch of {self.num_envs} envs"
        )
    single_state = jax.tree.map(lambda x: x[index], state)
    return self.env.render(single_state)

compile(cache_dir=DEFAULT_CACHE_DIR)

Trigger XLA compilation by running a dummy reset + step.

Useful when construction and compilation should be separate phases. Safe to call multiple times.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `cache_dir` | `Path \| str \| None` | XLA cache directory. Defaults to `~/.cache/envrax/xla_cache`. | `DEFAULT_CACHE_DIR` |
Source code in envrax/vec_env.py
Python
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def compile(self, cache_dir: Path | str | None = DEFAULT_CACHE_DIR) -> None:
    """
    Trigger XLA compilation by running a dummy `reset` + `step`.

    Useful when construction and compilation should be separate phases.
    Safe to call multiple times.

    The dummy actions are zeros of shape `(num_envs, *action_shape)`.
    When the inner action space exposes `shape` / `dtype` attributes they
    are used; otherwise a scalar `int32` (discrete) action is assumed.

    Parameters
    ----------
    cache_dir : Path | str | None (optional)
        XLA cache directory. Defaults to `~/.cache/envrax/xla_cache`.
    """
    setup_cache(cache_dir)
    _key = jax.random.key(0)
    _, _state = self.reset(_key)

    # NOTE(review): assumes action spaces may expose `shape`/`dtype`.
    # Confirm against the Space API. Without them, fall back to the original
    # discrete placeholder so behaviour is unchanged for such spaces.
    space = self.single_action_space
    act_shape = getattr(space, "shape", None)
    act_shape = () if act_shape is None else tuple(act_shape)
    act_dtype = getattr(space, "dtype", None)
    if act_dtype is None:  # explicit None check: np.dtype objects are falsy
        act_dtype = jnp.int32
    dummy_actions = jnp.zeros((self.num_envs, *act_shape), dtype=act_dtype)
    self.step(_state, dummy_actions)