Skip to content

Multi-Environment

Classes for composing several different environments (and their vectorised variants) under a single unified handle.

envrax.multi_env.MultiEnv

Manages M heterogeneous JaxEnv instances as a single unit. The environments may differ in class, configuration, and shapes.

Use .class_groups to identify which indices share a class for downstream batching of same-shape observations.

Parameters:

Name Type Description Default
envs List[JaxEnv]

List of already-constructed environments.

required
Source code in envrax/multi_env.py
Python
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
class MultiEnv:
    """
    Manages `M` heterogeneous `JaxEnv` instances as a single unit.
    The environments may differ in class, config, and shapes.

    Use `.class_groups` to identify which indices share a class for
    downstream batching of same-shape observations.

    Parameters
    ----------
    envs : List[JaxEnv]
        List of already-constructed environments. The list is copied on
        construction, so mutating the caller's list afterwards does not
        affect this instance.

    Raises
    ------
    empty_envs : ValueError
        If `envs` is empty.
    """

    def __init__(self, envs: List[JaxEnv]) -> None:
        if not envs:
            raise ValueError("MultiEnv requires at least one environment.")

        # Defensive copy: `_class_groups` is computed once from the envs
        # below, so sharing the caller's list would let `num_envs` and the
        # stored group indices drift apart if that list were mutated later.
        self._envs = list(envs)
        self._class_groups = _build_class_groups(self._envs)

    @property
    def num_envs(self) -> int:
        """Number of environments (`M`)."""
        return len(self._envs)

    @property
    def envs(self) -> List[JaxEnv]:
        """The inner environment instances."""
        return self._envs

    @property
    def observation_spaces(self) -> List[Space]:
        """Per-env observation spaces."""
        return [env.observation_space for env in self._envs]

    @property
    def action_spaces(self) -> List[Space]:
        """Per-env action spaces."""
        return [env.action_space for env in self._envs]

    @property
    def class_groups(self) -> Dict[str, List[int]]:
        """
        Env class name → list of indices.

        Useful for downstream code that wants to batch operations across
        envs of the same class (e.g. stacking observations with matching
        shapes).
        """
        return self._class_groups

    def reset(self, rng: chex.PRNGKey) -> Tuple[List[chex.Array], List[EnvState]]:
        """
        Reset all `M` environments with independent PRNG keys.

        Splits `rng` into `M` sub-keys with `jax.random.split`; the env at
        index `i` is reset with sub-key `i`, so the same master key always
        yields the same per-env keys.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        observations : List[chex.Array]
            Per-env initial observations
        states : List[EnvState]
            Per-env initial states
        """
        rngs = jax.random.split(rng, self.num_envs)
        obs_list: List[chex.Array] = []
        state_list: List[EnvState] = []

        for i, env in enumerate(self._envs):
            obs, state = env.reset(rngs[i])
            obs_list.append(obs)
            state_list.append(state)

        return obs_list, state_list

    def step(
        self,
        states: List[EnvState],
        actions: List[chex.Array],
    ) -> Tuple[
        List[chex.Array],
        List[EnvState],
        List[chex.Array],
        List[chex.Array],
        List[Dict[str, Any]],
    ]:
        """
        Step all `M` environments, one after another, in index order.

        Parameters
        ----------
        states : List[EnvState]
            Per-env states from a previous reset or step
        actions : List[chex.Array]
            Per-env actions matching each env's action space

        Returns
        -------
        observations : List[chex.Array]
            Per-env observations after the step
        new_states : List[EnvState]
            Per-env updated states
        rewards : List[chex.Array]
            Per-env scalar rewards
        dones : List[chex.Array]
            Per-env terminal flags
        infos : List[Dict[str, Any]]
            Per-env info dicts

        Raises
        ------
        length_mismatch : ValueError
            If `len(states)` or `len(actions)` does not match `num_envs`.
        """
        # Validate up front: `zip` below would otherwise silently truncate
        # to the shortest input.
        if len(states) != self.num_envs or len(actions) != self.num_envs:
            raise ValueError(
                f"MultiEnv.step: expected {self.num_envs} states and actions, "
                f"got {len(states)} states and {len(actions)} actions."
            )

        results = [
            env.step(state, action)
            for env, state, action in zip(self._envs, states, actions)
        ]
        # Transpose the per-env 5-tuples into five per-field lists.
        return (
            [r[0] for r in results],
            [r[1] for r in results],
            [r[2] for r in results],
            [r[3] for r in results],
            [r[4] for r in results],
        )

    def reset_at(self, idx: int, rng: chex.PRNGKey) -> Tuple[chex.Array, EnvState]:
        """
        Reset a single environment by index.

        Parameters
        ----------
        idx : int
            Index of the environment to reset
        rng : chex.PRNGKey
            JAX PRNG key for the reset

        Returns
        -------
        obs : chex.Array
            Initial observation
        state : EnvState
            Initial state
        """
        return self._envs[idx].reset(rng)

    def step_at(
        self,
        idx: int,
        state: EnvState,
        action: chex.Array,
    ) -> Tuple[chex.Array, EnvState, chex.Array, chex.Array, Dict[str, Any]]:
        """
        Step a single environment by index.

        Parameters
        ----------
        idx : int
            Index of the environment to step
        state : EnvState
            Current state of the environment
        action : chex.Array
            Action to take

        Returns
        -------
        obs : chex.Array
            Observation after the step
        new_state : EnvState
            Updated state
        reward : chex.Array
            Scalar reward
        done : chex.Array
            Terminal flag
        info : Dict[str, Any]
            Info dict
        """
        return self._envs[idx].step(state, action)

    def compile(self, *, progress: bool = True) -> None:
        """
        Trigger XLA compilation for all JIT-wrapped environments.

        Calls `compile()` on each inner env that is a `JitWrapper`.
        Environments without JIT wrapping are silently skipped.

        Parameters
        ----------
        progress : bool (optional)
            Show a `tqdm` progress bar. Default is `True`.
        """
        jit_envs = [env for env in self._envs if isinstance(env, JitWrapper)]
        if not jit_envs:
            # Nothing to compile; also avoids drawing an empty progress bar.
            return

        it = tqdm(jit_envs, desc="Compiling envs", unit="env") if progress else jit_envs
        for env in it:
            env.compile()

    def __len__(self) -> int:
        """Return the number of environments (same as `num_envs`)."""
        return len(self._envs)

    def __repr__(self) -> str:
        """Return a summary listing each inner env's class name and the env count."""
        env_info = ", ".join(type(e).__name__ for e in self._envs)
        return f"MultiEnv([{env_info}], num_envs={self.num_envs})"

num_envs property

Number of environments (M).

envs property

The inner environment instances.

observation_spaces property

Per-env observation spaces.

action_spaces property

Per-env action spaces.

class_groups property

Env class name → list of indices.

Useful for downstream code that wants to batch operations across envs of the same class (e.g. stacking observations with matching shapes).

reset(rng)

Reset all M environments with independent PRNG keys.

Splits rng into M sub-keys deterministically. The same master key always produces the same per-env keys, so resets are fully reproducible.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required

Returns:

Name Type Description
observations List[Array]

Per-env initial observations

states List[EnvState]

Per-env initial states

Source code in envrax/multi_env.py
Python
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def reset(self, rng: chex.PRNGKey) -> Tuple[List[chex.Array], List[EnvState]]:
    """
    Reset every one of the `M` environments, each with its own PRNG key.

    `rng` is split into `M` sub-keys via `jax.random.split`; the env at
    index `i` is reset with sub-key `i`, so a given master key always
    maps to the same per-env keys.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    observations : List[chex.Array]
        Per-env initial observations
    states : List[EnvState]
        Per-env initial states
    """
    sub_keys = jax.random.split(rng, self.num_envs)

    # Reset in index order; every call yields an `(obs, state)` pair.
    pairs = [env.reset(sub_keys[idx]) for idx, env in enumerate(self._envs)]

    # Un-zip the pairs into the two parallel output lists.
    observations: List[chex.Array] = [obs for obs, _ in pairs]
    initial_states: List[EnvState] = [state for _, state in pairs]
    return observations, initial_states

step(states, actions)

Step all M environments simultaneously.

Parameters:

Name Type Description Default
states List[EnvState]

Per-env states from a previous reset or step

required
actions List[Array]

Per-env actions matching each env's action space

required

Returns:

Name Type Description
observations List[Array]

Per-env observations after the step

new_states List[EnvState]

Per-env updated states

rewards List[Array]

Per-env scalar rewards

dones List[Array]

Per-env terminal flags

infos List[Dict[str, Any]]

Per-env info dicts

Raises:

Name Type Description
length_mismatch ValueError

If len(states) or len(actions) does not match num_envs.

Source code in envrax/multi_env.py
Python
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def step(
    self,
    states: List[EnvState],
    actions: List[chex.Array],
) -> Tuple[
    List[chex.Array],
    List[EnvState],
    List[chex.Array],
    List[chex.Array],
    List[Dict[str, Any]],
]:
    """
    Advance all `M` environments by one step, in index order.

    Parameters
    ----------
    states : List[EnvState]
        Per-env states from a previous reset or step
    actions : List[chex.Array]
        Per-env actions matching each env's action space

    Returns
    -------
    observations : List[chex.Array]
        Per-env observations after the step
    new_states : List[EnvState]
        Per-env updated states
    rewards : List[chex.Array]
        Per-env scalar rewards
    dones : List[chex.Array]
        Per-env terminal flags
    infos : List[Dict[str, Any]]
        Per-env info dicts

    Raises
    ------
    length_mismatch : ValueError
        If `len(states)` or `len(actions)` does not match `num_envs`.
    """
    # Reject mismatched inputs before stepping; `zip` would otherwise
    # silently stop at the shortest sequence.
    lengths_ok = len(states) == self.num_envs and len(actions) == self.num_envs
    if not lengths_ok:
        raise ValueError(
            f"MultiEnv.step: expected {self.num_envs} states and actions, "
            f"got {len(states)} states and {len(actions)} actions."
        )

    stepped = [
        env.step(env_state, env_action)
        for env, env_state, env_action in zip(self._envs, states, actions)
    ]

    # Fields 0..4 of each result are obs, new state, reward, done, info;
    # regroup them into one list per field.
    return tuple([outcome[field] for outcome in stepped] for field in range(5))

reset_at(idx, rng)

Reset a single environment by index.

Parameters:

Name Type Description Default
idx int

Index of the environment to reset

required
rng PRNGKey

JAX PRNG key for the reset

required

Returns:

Name Type Description
obs Array

Initial observation

state EnvState

Initial state

Source code in envrax/multi_env.py
Python
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def reset_at(self, idx: int, rng: chex.PRNGKey) -> Tuple[chex.Array, EnvState]:
    """
    Reset only the environment stored at position `idx`.

    The call is forwarded unchanged to that environment's `reset`.

    Parameters
    ----------
    idx : int
        Index of the environment to reset
    rng : chex.PRNGKey
        JAX PRNG key for the reset

    Returns
    -------
    obs : chex.Array
        Initial observation
    state : EnvState
        Initial state
    """
    target_env = self._envs[idx]
    return target_env.reset(rng)

step_at(idx, state, action)

Step a single environment by index.

Parameters:

Name Type Description Default
idx int

Index of the environment to step

required
state EnvState

Current state of the environment

required
action Array

Action to take

required

Returns:

Name Type Description
obs Array

Observation after the step

new_state EnvState

Updated state

reward Array

Scalar reward

done Array

Terminal flag

info Dict[str, Any]

Info dict

Source code in envrax/multi_env.py
Python
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def step_at(
    self,
    idx: int,
    state: EnvState,
    action: chex.Array,
) -> Tuple[chex.Array, EnvState, chex.Array, chex.Array, Dict[str, Any]]:
    """
    Step only the environment stored at position `idx`.

    The call is forwarded unchanged to that environment's `step`.

    Parameters
    ----------
    idx : int
        Index of the environment to step
    state : EnvState
        Current state of the environment
    action : chex.Array
        Action to take

    Returns
    -------
    obs : chex.Array
        Observation after the step
    new_state : EnvState
        Updated state
    reward : chex.Array
        Scalar reward
    done : chex.Array
        Terminal flag
    info : Dict[str, Any]
        Info dict
    """
    target_env = self._envs[idx]
    return target_env.step(state, action)

compile(*, progress=True)

Trigger XLA compilation for all JIT-wrapped environments.

Calls compile() on each inner env that is a JitWrapper. Environments without JIT wrapping are silently skipped.

Parameters:

Name Type Description Default
progress bool

Show a tqdm progress bar. Default is True.

True
Source code in envrax/multi_env.py
Python
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def compile(self, *, progress: bool = True) -> None:
    """
    Trigger XLA compilation for every JIT-wrapped environment.

    Only inner envs that are `JitWrapper` instances have their
    `compile()` called; all other envs are skipped without any message.

    Parameters
    ----------
    progress : bool (optional)
        Show a `tqdm` progress bar. Default is `True`.
    """
    targets = [env for env in self._envs if isinstance(env, JitWrapper)]
    if len(targets) == 0:
        # Nothing is JIT-wrapped, so no progress bar is created either.
        return

    if progress:
        targets = tqdm(targets, desc="Compiling envs", unit="env")

    for jit_env in targets:
        jit_env.compile()

envrax.multi_vec_env.MultiVecEnv

Manages M heterogeneous VecEnv instances as a single unit. The vectorised environments may differ in class, configuration, and shapes.

Use .class_groups to identify which indices share an inner env class for downstream batching.

Parameters:

Name Type Description Default
vec_envs List[VecEnv]

List of already-constructed vectorised environments.

required
Source code in envrax/multi_vec_env.py
Python
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
class MultiVecEnv:
    """
    Manages `M` heterogeneous `VecEnv` instances as a single unit.
    The vectorised environments may differ in class, config, and shapes.

    Use `.class_groups` to identify which indices share an inner env
    class for downstream batching.

    Parameters
    ----------
    vec_envs : List[VecEnv]
        List of already-constructed vectorised environments. The list is
        copied on construction, so mutating the caller's list afterwards
        does not affect this instance.

    Raises
    ------
    empty_vec_envs : ValueError
        If `vec_envs` is empty.
    """

    def __init__(self, vec_envs: List[VecEnv]) -> None:
        if not vec_envs:
            raise ValueError("MultiVecEnv requires at least one VecEnv.")

        # Defensive copy: `_class_groups` is computed once from the groups
        # below, so sharing the caller's list would let `num_envs` and the
        # stored group indices drift apart if that list were mutated later.
        self._vec_envs = list(vec_envs)
        self._class_groups = _build_class_groups(self._vec_envs)

    @property
    def num_envs(self) -> int:
        """Number of VecEnv groups (`M`)."""
        return len(self._vec_envs)

    @property
    def total_envs(self) -> int:
        """Total number of individual environments across all groups."""
        return sum(v.num_envs for v in self._vec_envs)

    @property
    def vec_envs(self) -> List[VecEnv]:
        """The inner VecEnv instances."""
        return self._vec_envs

    @property
    def observation_spaces(self) -> List[Space]:
        """Per-group batched observation spaces."""
        return [v.observation_space for v in self._vec_envs]

    @property
    def action_spaces(self) -> List[Space]:
        """Per-group batched action spaces."""
        return [v.action_space for v in self._vec_envs]

    @property
    def single_observation_spaces(self) -> List[Space]:
        """Per-group unbatched observation spaces."""
        return [v.single_observation_space for v in self._vec_envs]

    @property
    def single_action_spaces(self) -> List[Space]:
        """Per-group unbatched action spaces."""
        return [v.single_action_space for v in self._vec_envs]

    @property
    def class_groups(self) -> Dict[str, List[int]]:
        """
        Inner env class name → list of VecEnv indices.

        Useful for downstream code that wants to batch operations across
        groups sharing the same inner env class.
        """
        return self._class_groups

    def reset(self, rng: chex.PRNGKey) -> Tuple[List[chex.Array], List[EnvState]]:
        """
        Reset all `M` vectorised environment groups with independent PRNG keys.

        `rng` is split into `M` sub-keys with `jax.random.split`; the group
        at index `i` receives sub-key `i`. Each `VecEnv` receives one
        sub-key and splits it internally across its own parallel copies.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        observations : List[chex.Array]
            Per-group batched observations
        states : List[EnvState]
            Per-group batched states
        """
        rngs = jax.random.split(rng, self.num_envs)
        obs_list: List[chex.Array] = []
        state_list: List[EnvState] = []

        for i, vec in enumerate(self._vec_envs):
            obs, state = vec.reset(rngs[i])
            obs_list.append(obs)
            state_list.append(state)

        return obs_list, state_list

    def step(
        self,
        states: List[EnvState],
        actions: List[chex.Array],
    ) -> Tuple[
        List[chex.Array],
        List[EnvState],
        List[chex.Array],
        List[chex.Array],
        List[Dict[str, Any]],
    ]:
        """
        Step all `M` vectorised environment groups, one after another,
        in index order.

        Parameters
        ----------
        states : List[EnvState]
            Per-group batched states from a previous reset or step
        actions : List[chex.Array]
            Per-group batched actions

        Returns
        -------
        observations : List[chex.Array]
            Per-group batched observations after the step
        new_states : List[EnvState]
            Per-group updated batched states
        rewards : List[chex.Array]
            Per-group batched rewards
        dones : List[chex.Array]
            Per-group batched terminal flags
        infos : List[Dict[str, Any]]
            Per-group batched info dicts

        Raises
        ------
        length_mismatch : ValueError
            If `len(states)` or `len(actions)` does not match `num_envs`.
        """
        # Validate up front: `zip` below would otherwise silently truncate
        # to the shortest input.
        if len(states) != self.num_envs or len(actions) != self.num_envs:
            raise ValueError(
                f"MultiVecEnv.step: expected {self.num_envs} states and actions, "
                f"got {len(states)} states and {len(actions)} actions."
            )

        results = [
            vec.step(state, action)
            for vec, state, action in zip(self._vec_envs, states, actions)
        ]
        # Transpose the per-group 5-tuples into five per-field lists.
        return (
            [r[0] for r in results],
            [r[1] for r in results],
            [r[2] for r in results],
            [r[3] for r in results],
            [r[4] for r in results],
        )

    def reset_at(self, idx: int, rng: chex.PRNGKey) -> Tuple[chex.Array, EnvState]:
        """
        Reset a single `VecEnv` group by index.

        Parameters
        ----------
        idx : int
            Index of the `VecEnv` group to reset
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs : chex.Array
            Batched initial observations for this group
        state : EnvState
            Batched initial state for this group
        """
        return self._vec_envs[idx].reset(rng)

    def step_at(
        self,
        idx: int,
        state: EnvState,
        action: chex.Array,
    ) -> Tuple[chex.Array, EnvState, chex.Array, chex.Array, Dict[str, Any]]:
        """
        Step a single `VecEnv` group by index.

        Parameters
        ----------
        idx : int
            Index of the `VecEnv` group to step
        state : EnvState
            Batched state for this group
        action : chex.Array
            Batched action for this group

        Returns
        -------
        obs : chex.Array
            Batched observations after the step
        new_state : EnvState
            Updated batched state
        reward : chex.Array
            Batched rewards
        done : chex.Array
            Batched terminal flags
        info : Dict[str, Any]
            Batched info dict
        """
        return self._vec_envs[idx].step(state, action)

    def compile(self, *, progress: bool = True) -> None:
        """
        Trigger XLA compilation for all inner `VecEnv` instances.

        Calls `compile()` on each `VecEnv`, which runs a dummy
        `reset` + `step` to populate the XLA cache.

        Parameters
        ----------
        progress : bool (optional)
            Show a `tqdm` progress bar. Default is `True`.
        """
        it = (
            tqdm(self._vec_envs, desc="Compiling vec envs", unit="env")
            if progress
            else self._vec_envs
        )
        for vec in it:
            vec.compile()

    def __len__(self) -> int:
        """Return the number of `VecEnv` groups (same as `num_envs`)."""
        return len(self._vec_envs)

    def __repr__(self) -> str:
        """Return a summary of each group's inner env class and size, plus the total."""
        group_info = ", ".join(
            f"{type(v.env).__name__}×{v.num_envs}" for v in self._vec_envs
        )
        return f"MultiVecEnv([{group_info}], total={self.total_envs})"

num_envs property

Number of VecEnv groups (M).

total_envs property

Total number of individual environments across all groups.

vec_envs property

The inner VecEnv instances.

observation_spaces property

Per-group batched observation spaces.

action_spaces property

Per-group batched action spaces.

single_observation_spaces property

Per-group unbatched observation spaces.

single_action_spaces property

Per-group unbatched action spaces.

class_groups property

Inner env class name → list of VecEnv indices.

Useful for downstream code that wants to batch operations across groups sharing the same inner env class.

reset(rng)

Reset all M vectorised environment groups with independent PRNG keys.

Each VecEnv receives one sub-key and splits it internally across its own parallel copies.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required

Returns:

Name Type Description
observations List[Array]

Per-group batched observations

states List[EnvState]

Per-group batched states

Source code in envrax/multi_vec_env.py
Python
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def reset(self, rng: chex.PRNGKey) -> Tuple[List[chex.Array], List[EnvState]]:
    """
    Reset every one of the `M` vectorised groups, each with its own PRNG key.

    `rng` is split into `M` sub-keys via `jax.random.split`; the group at
    index `i` is reset with sub-key `i`. Each `VecEnv` receives one
    sub-key and splits it internally across its own parallel copies.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    observations : List[chex.Array]
        Per-group batched observations
    states : List[EnvState]
        Per-group batched states
    """
    group_keys = jax.random.split(rng, self.num_envs)

    # Reset in index order; every call yields an `(obs, state)` pair.
    pairs = [
        group.reset(group_keys[idx]) for idx, group in enumerate(self._vec_envs)
    ]

    # Un-zip the pairs into the two parallel output lists.
    observations: List[chex.Array] = [obs for obs, _ in pairs]
    initial_states: List[EnvState] = [state for _, state in pairs]
    return observations, initial_states

step(states, actions)

Step all M vectorised environment groups simultaneously.

Parameters:

Name Type Description Default
states List[EnvState]

Per-group batched states from a previous reset or step

required
actions List[Array]

Per-group batched actions

required

Returns:

Name Type Description
observations List[Array]

Per-group batched observations after the step

new_states List[EnvState]

Per-group updated batched states

rewards List[Array]

Per-group batched rewards

dones List[Array]

Per-group batched terminal flags

infos List[Dict[str, Any]]

Per-group batched info dicts

Raises:

Name Type Description
length_mismatch ValueError

If len(states) or len(actions) does not match num_envs.

Source code in envrax/multi_vec_env.py
Python
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def step(
    self,
    states: List[EnvState],
    actions: List[chex.Array],
) -> Tuple[
    List[chex.Array],
    List[EnvState],
    List[chex.Array],
    List[chex.Array],
    List[Dict[str, Any]],
]:
    """
    Advance all `M` vectorised environment groups by one step, in index order.

    Parameters
    ----------
    states : List[EnvState]
        Per-group batched states from a previous reset or step
    actions : List[chex.Array]
        Per-group batched actions

    Returns
    -------
    observations : List[chex.Array]
        Per-group batched observations after the step
    new_states : List[EnvState]
        Per-group updated batched states
    rewards : List[chex.Array]
        Per-group batched rewards
    dones : List[chex.Array]
        Per-group batched terminal flags
    infos : List[Dict[str, Any]]
        Per-group batched info dicts

    Raises
    ------
    length_mismatch : ValueError
        If `len(states)` or `len(actions)` does not match `num_envs`.
    """
    # Reject mismatched inputs before stepping; `zip` would otherwise
    # silently stop at the shortest sequence.
    lengths_ok = len(states) == self.num_envs and len(actions) == self.num_envs
    if not lengths_ok:
        raise ValueError(
            f"MultiVecEnv.step: expected {self.num_envs} states and actions, "
            f"got {len(states)} states and {len(actions)} actions."
        )

    stepped = [
        group.step(group_state, group_action)
        for group, group_state, group_action in zip(self._vec_envs, states, actions)
    ]

    # Fields 0..4 of each result are obs, new state, reward, done, info;
    # regroup them into one list per field.
    return tuple([outcome[field] for outcome in stepped] for field in range(5))

reset_at(idx, rng)

Reset a single VecEnv group by index.

Parameters:

Name Type Description Default
idx int

Index of the VecEnv group to reset

required
rng PRNGKey

JAX PRNG key

required

Returns:

Name Type Description
obs Array

Batched initial observations for this group

state EnvState

Batched initial state for this group

Source code in envrax/multi_vec_env.py
Python
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def reset_at(self, idx: int, rng: chex.PRNGKey) -> Tuple[chex.Array, EnvState]:
    """
    Reset only the `VecEnv` group stored at position `idx`.

    The call is forwarded unchanged to that group's `reset`.

    Parameters
    ----------
    idx : int
        Index of the `VecEnv` group to reset
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    obs : chex.Array
        Batched initial observations for this group
    state : EnvState
        Batched initial state for this group
    """
    target_group = self._vec_envs[idx]
    return target_group.reset(rng)

step_at(idx, state, action)

Step a single VecEnv group by index.

Parameters:

Name Type Description Default
idx int

Index of the VecEnv group to step

required
state EnvState

Batched state for this group

required
action Array

Batched action for this group

required

Returns:

Name Type Description
obs Array

Batched observations after the step

new_state EnvState

Updated batched state

reward Array

Batched rewards

done Array

Batched terminal flags

info Dict[str, Any]

Batched info dict

Source code in envrax/multi_vec_env.py
Python
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def step_at(
    self,
    idx: int,
    state: EnvState,
    action: chex.Array,
) -> Tuple[chex.Array, EnvState, chex.Array, chex.Array, Dict[str, Any]]:
    """
    Step only the `VecEnv` group stored at position `idx`.

    The call is forwarded unchanged to that group's `step`.

    Parameters
    ----------
    idx : int
        Index of the `VecEnv` group to step
    state : EnvState
        Batched state for this group
    action : chex.Array
        Batched action for this group

    Returns
    -------
    obs : chex.Array
        Batched observations after the step
    new_state : EnvState
        Updated batched state
    reward : chex.Array
        Batched rewards
    done : chex.Array
        Batched terminal flags
    info : Dict[str, Any]
        Batched info dict
    """
    target_group = self._vec_envs[idx]
    return target_group.step(state, action)

compile(*, progress=True)

Trigger XLA compilation for all inner VecEnv instances.

Calls compile() on each VecEnv, which runs a dummy reset + step to populate the XLA cache.

Parameters:

Name Type Description Default
progress bool

Show a tqdm progress bar. Default is True.

True
Source code in envrax/multi_vec_env.py
Python
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def compile(self, *, progress: bool = True) -> None:
    """
    Trigger XLA compilation for every inner `VecEnv`.

    Each group's own `compile()` is invoked in index order; per the
    `VecEnv` contract that runs a dummy `reset` + `step` to populate
    the XLA cache.

    Parameters
    ----------
    progress : bool (optional)
        Show a `tqdm` progress bar. Default is `True`.
    """
    groups = self._vec_envs
    if progress:
        # Wrap the iteration only; the stored list itself is untouched.
        groups = tqdm(groups, desc="Compiling vec envs", unit="env")

    for group in groups:
        group.compile()