Batched Environment

Base class for any environment that produces N independent agent results per step.

envrax.batched_env.BatchedEnv

Bases: ABC

Base class for an env that produces n_slots independent agent results per step.

Implementations choose their batching strategy — jax.vmap over a single JaxEnv, a composite multi-agent scene, threaded backends, or any other approach. All implementations expose the same shape contract:

  • observations (n_slots, *obs_shape)
  • actions (n_slots, *action_shape)
  • rewards (n_slots,)
  • dones (n_slots,)

State is an implementation-specific pytree carried opaquely by the caller.
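As a rough illustration of the shape contract above, here is a minimal sketch of the `jax.vmap` strategy applied to a toy single-agent environment. The names `single_reset`, `single_step`, `N_SLOTS`, `OBS_SHAPE`, and `MAX_STEPS` are hypothetical and not part of envrax; the sketch splits the PRNG key by hand, whereas a real `BatchedEnv.reset` receives a single key.

```python
import jax
import jax.numpy as jnp

N_SLOTS = 4        # hypothetical number of slots
OBS_SHAPE = (3,)   # hypothetical per-slot observation shape
MAX_STEPS = 5      # hypothetical episode length


def single_reset(rng):
    # Per-slot state is a small pytree (a dict of arrays).
    state = {
        "pos": jax.random.uniform(rng, OBS_SHAPE),
        "t": jnp.zeros((), jnp.int32),
    }
    return state["pos"], state


def single_step(state, action):
    pos = state["pos"] + action
    t = state["t"] + 1
    new_state = {"pos": pos, "t": t}
    reward = -jnp.sum(pos ** 2)   # scalar reward per slot
    done = t >= MAX_STEPS         # scalar boolean per slot
    return pos, new_state, reward, done


# vmap adds the leading n_slots axis to every input and output.
batched_reset = jax.vmap(single_reset)
batched_step = jax.vmap(single_step)

rngs = jax.random.split(jax.random.PRNGKey(0), N_SLOTS)
obs, state = batched_reset(rngs)
actions = jnp.zeros((N_SLOTS, *OBS_SHAPE))
obs, state, rewards, dones = batched_step(state, actions)

print(obs.shape, rewards.shape, dones.shape)
```

Every leaf of `state` also gains the leading slot axis here, which is what lets the caller carry it around without inspecting it.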

Source code in envrax/batched_env.py
Python
class BatchedEnv(ABC):
    """
    Base class for an env that produces `n_slots` independent agent results per step.

    Implementations choose their batching strategy — `jax.vmap` over a single
    `JaxEnv`, a composite multi-agent scene, threaded backends, or any other
    approach. All implementations expose the same shape contract:

      * observations `(n_slots, *obs_shape)`
      * actions      `(n_slots, *action_shape)`
      * rewards      `(n_slots,)`
      * dones        `(n_slots,)`

    State is an implementation-specific pytree carried opaquely by the caller.
    """

    n_slots: int

    @property
    def name(self) -> str:
        """
        Default key used by `MultiVecEnv` when keys aren't supplied.

        Subclasses should override with something meaningful (e.g. the
        wrapped env's class name). Falls back to this `BatchedEnv` subclass
        name when not overridden.

        Returns
        -------
        name : str
            Short identifier for this batched env.
        """
        return type(self).__name__

    @property
    @abstractmethod
    def single_observation_space(self) -> Space:
        """Observation space of a single slot (unbatched)."""

    @property
    @abstractmethod
    def single_action_space(self) -> Space:
        """Action space of a single slot (unbatched)."""

    @abstractmethod
    def reset(self, rng: chex.PRNGKey) -> Tuple[jax.Array, Any]:
        """
        Reset all `n_slots` independent agents.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs : jax.Array
            Initial observations, shape `(n_slots, *obs_shape)`
        state : Any
            Batched state pytree
        """

    @abstractmethod
    def step(
        self,
        state: Any,
        actions: jax.Array,
    ) -> Tuple[jax.Array, Any, jax.Array, jax.Array, Dict[str, Any]]:
        """
        Step all `n_slots` agents independently with per-slot auto-reset.

        Parameters
        ----------
        state : Any
            Batched state from a previous reset or step
        actions : jax.Array
            Actions, shape `(n_slots, *action_shape)`

        Returns
        -------
        obs : jax.Array
            Observations after the step, shape `(n_slots, *obs_shape)`
        new_state : Any
            Updated batched state
        reward : jax.Array
            Per-slot rewards, shape `(n_slots,)`
        done : jax.Array
            Per-slot terminal flags, shape `(n_slots,)`
        info : Dict[str, Any]
            Batched info dict
        """

    @abstractmethod
    def slot_state(self, state: Any, slot_idx: int) -> Any:
        """
        Extract the single-slot state pytree for one agent.

        Parameters
        ----------
        state : Any
            Batched state
        slot_idx : int
            Slot index in `[0, n_slots)`

        Returns
        -------
        single_state : Any
            Pytree of the same structure as `state` but with leading slot
            dimension removed.
        """

    @abstractmethod
    def render_slot(self, state: Any, slot_idx: int) -> np.ndarray:
        """
        Render a single slot from the batched state as an RGB frame.

        Parameters
        ----------
        state : Any
            Batched state
        slot_idx : int
            Slot index in `[0, n_slots)`

        Returns
        -------
        frame : np.ndarray
            uint8 RGB array of shape `(H, W, 3)`
        """

    @abstractmethod
    def compile(self, cache_dir: Path | str | None = None) -> None:
        """
        Trigger XLA compilation by running dummy `reset` + `step` calls.

        Implementations should also warm any conditional branches (e.g. the
        reset path of auto-reset logic) so the persistent cache covers them.

        Parameters
        ----------
        cache_dir : Path | str | None (optional)
            XLA cache directory. Implementations may default to a stable
            project-relative path.
        """

name property

Default key used by MultiVecEnv when keys aren't supplied.

Subclasses should override with something meaningful (e.g. the wrapped env's class name). Falls back to this BatchedEnv subclass name when not overridden.

Returns:

Name    Type    Description
name    str     Short identifier for this batched env.
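The fallback and override behaviour can be seen in a small sketch. `BatchedEnvSketch` is a hypothetical stand-in that copies only the `name` property from the source listing, so the example runs without envrax installed; `VmapPong` and `WrappedEnv` are likewise invented for illustration.

```python
from abc import ABC


class BatchedEnvSketch(ABC):
    """Hypothetical stand-in reproducing only the `name` property."""

    @property
    def name(self) -> str:
        # Same fallback as in the source: the concrete subclass name.
        return type(self).__name__


class VmapPong(BatchedEnvSketch):
    """Does not override `name`, so the class name is used."""


class WrappedEnv(BatchedEnvSketch):
    """Overrides `name` with something more meaningful."""

    def __init__(self, inner_name: str):
        self._inner_name = inner_name

    @property
    def name(self) -> str:
        return self._inner_name


default_name = VmapPong().name
custom_name = WrappedEnv("Breakout").name
print(default_name, custom_name)
```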

single_observation_space abstractmethod property

Observation space of a single slot (unbatched).

single_action_space abstractmethod property

Action space of a single slot (unbatched).

reset(rng) abstractmethod

Reset all n_slots independent agents.

Parameters:

Name    Type       Description     Default
rng     PRNGKey    JAX PRNG key    required

Returns:

Name     Type     Description
obs      Array    Initial observations, shape (n_slots, *obs_shape)
state    Any      Batched state pytree

Source code in envrax/batched_env.py
Python
@abstractmethod
def reset(self, rng: chex.PRNGKey) -> Tuple[jax.Array, Any]:
    """
    Reset all `n_slots` independent agents.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    obs : jax.Array
        Initial observations, shape `(n_slots, *obs_shape)`
    state : Any
        Batched state pytree
    """

step(state, actions) abstractmethod

Step all n_slots agents independently with per-slot auto-reset.

Parameters:

Name       Type     Description                                    Default
state      Any      Batched state from a previous reset or step    required
actions    Array    Actions, shape (n_slots, *action_shape)        required

Returns:

Name         Type              Description
obs          Array             Observations after the step, shape (n_slots, *obs_shape)
new_state    Any               Updated batched state
reward       Array             Per-slot rewards, shape (n_slots,)
done         Array             Per-slot terminal flags, shape (n_slots,)
info         Dict[str, Any]    Batched info dict
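Per-slot auto-reset can be implemented in several ways, and the source does not say which one envrax uses. The sketch below shows one common approach under that caveat: step the slot, build a fresh episode, and select between the two leaf by leaf with `jnp.where` on the `done` flag. All names (`single_reset`, `single_step`, `MAX_STEPS`) are hypothetical, and the `info` dict is omitted for brevity.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 3  # hypothetical episode length


def single_reset(rng):
    rng, sub = jax.random.split(rng)
    state = {
        "rng": rng,  # kept in the state so later resets get fresh randomness
        "x": jax.random.uniform(sub, ()),
        "t": jnp.zeros((), jnp.int32),
    }
    return state["x"], state


def single_step(state, action):
    stepped = {"rng": state["rng"], "x": state["x"] + action, "t": state["t"] + 1}
    reward = -jnp.abs(stepped["x"])
    done = stepped["t"] >= MAX_STEPS
    # Auto-reset: compute a fresh episode, then pick it wherever done is set.
    _, fresh = single_reset(state["rng"])
    new_state = jax.tree_util.tree_map(
        lambda f, s: jnp.where(done, f, s), fresh, stepped
    )
    return new_state["x"], new_state, reward, done


step = jax.jit(jax.vmap(single_step))

rngs = jax.random.split(jax.random.PRNGKey(0), 2)
obs, state = jax.vmap(single_reset)(rngs)
actions = jnp.ones((2,))
for _ in range(MAX_STEPS):
    obs, state, reward, done = step(state, actions)

print(done, state["t"])
```

After `MAX_STEPS` steps both slots report `done`, and their step counters are already back at zero, so the caller never has to call `reset` between episodes.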

Source code in envrax/batched_env.py
Python
@abstractmethod
def step(
    self,
    state: Any,
    actions: jax.Array,
) -> Tuple[jax.Array, Any, jax.Array, jax.Array, Dict[str, Any]]:
    """
    Step all `n_slots` agents independently with per-slot auto-reset.

    Parameters
    ----------
    state : Any
        Batched state from a previous reset or step
    actions : jax.Array
        Actions, shape `(n_slots, *action_shape)`

    Returns
    -------
    obs : jax.Array
        Observations after the step, shape `(n_slots, *obs_shape)`
    new_state : Any
        Updated batched state
    reward : jax.Array
        Per-slot rewards, shape `(n_slots,)`
    done : jax.Array
        Per-slot terminal flags, shape `(n_slots,)`
    info : Dict[str, Any]
        Batched info dict
    """

slot_state(state, slot_idx) abstractmethod

Extract the single-slot state pytree for one agent.

Parameters:

Name        Type    Description                   Default
state       Any     Batched state                 required
slot_idx    int     Slot index in [0, n_slots)    required

Returns:

Name            Type    Description
single_state    Any     Pytree of the same structure as state but with leading slot dimension removed.
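For a state whose leaves all carry the leading slot axis (as a `jax.vmap`-based implementation would produce), slot extraction might look like the sketch below. `slot_state_sketch` is a hypothetical free function, not the envrax method, and a composite or threaded backend may need different logic.

```python
import jax
import jax.numpy as jnp


def slot_state_sketch(state, slot_idx):
    # Index every leaf along its leading (slot) axis.
    return jax.tree_util.tree_map(lambda leaf: leaf[slot_idx], state)


batched = {
    "pos": jnp.arange(12.0).reshape(4, 3),  # (n_slots=4, 3)
    "t": jnp.array([0, 5, 2, 7]),           # (n_slots=4,)
}
single = slot_state_sketch(batched, 1)
print(single["pos"].shape, single["t"].shape)
```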

Source code in envrax/batched_env.py
Python
@abstractmethod
def slot_state(self, state: Any, slot_idx: int) -> Any:
    """
    Extract the single-slot state pytree for one agent.

    Parameters
    ----------
    state : Any
        Batched state
    slot_idx : int
        Slot index in `[0, n_slots)`

    Returns
    -------
    single_state : Any
        Pytree of the same structure as `state` but with leading slot
        dimension removed.
    """

render_slot(state, slot_idx) abstractmethod

Render a single slot from the batched state as an RGB frame.

Parameters:

Name        Type    Description                   Default
state       Any     Batched state                 required
slot_idx    int     Slot index in [0, n_slots)    required

Returns:

Name     Type       Description
frame    ndarray    uint8 RGB array of shape (H, W, 3)
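To make the return contract concrete, the following sketch renders one slot of a toy state as a uint8 frame. The state layout (a `"pos"` entry holding normalized row/column coordinates), the frame size, and the function name are assumptions made for illustration only.

```python
import numpy as np

H, W = 8, 8  # hypothetical frame size


def render_slot_sketch(state, slot_idx):
    # Pull the chosen slot to host memory as a plain NumPy array.
    pos = np.asarray(state["pos"][slot_idx])  # (2,), values in [0, 1)
    frame = np.zeros((H, W, 3), dtype=np.uint8)
    row = int(np.clip(pos[0] * H, 0, H - 1))
    col = int(np.clip(pos[1] * W, 0, W - 1))
    frame[row, col] = (255, 255, 255)  # draw the agent as one white pixel
    return frame


state = {"pos": np.array([[0.1, 0.9], [0.5, 0.5]])}
frame = render_slot_sketch(state, 1)
print(frame.shape, frame.dtype)
```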

Source code in envrax/batched_env.py
Python
@abstractmethod
def render_slot(self, state: Any, slot_idx: int) -> np.ndarray:
    """
    Render a single slot from the batched state as an RGB frame.

    Parameters
    ----------
    state : Any
        Batched state
    slot_idx : int
        Slot index in `[0, n_slots)`

    Returns
    -------
    frame : np.ndarray
        uint8 RGB array of shape `(H, W, 3)`
    """

compile(cache_dir=None) abstractmethod

Trigger XLA compilation by running dummy reset + step calls.

Implementations should also warm any conditional branches (e.g. the reset path of auto-reset logic) so the persistent cache covers them.

Parameters:

Name         Type                 Description                                                                            Default
cache_dir    Path | str | None    XLA cache directory. Implementations may default to a stable project-relative path.    None

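A warm-up routine along these lines is one way to satisfy this contract. The sketch assumes a recent JAX release in which the persistent compilation cache is enabled through the `jax_compilation_cache_dir` config option; `compile_sketch`, `toy_reset`, and `toy_step` are hypothetical names, and how envrax implementations actually configure the cache is not shown in the source.

```python
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp


def compile_sketch(reset_fn, step_fn, dummy_actions, cache_dir=None):
    # Assumption: persistent cache is configured via this JAX config option.
    if cache_dir is not None:
        jax.config.update("jax_compilation_cache_dir", str(Path(cache_dir)))
    # Dummy reset + step calls trigger tracing and XLA compilation.
    _, state = reset_fn(jax.random.PRNGKey(0))
    out = step_fn(state, dummy_actions)
    jax.block_until_ready(out)  # wait until the compiled call has finished


@jax.jit
def toy_reset(rng):
    x = jax.random.uniform(rng, (4,))
    return x, {"x": x}


@jax.jit
def toy_step(state, actions):
    x = state["x"] + actions
    return x, {"x": x}, -jnp.abs(x), x > 10.0, {}


tmp_cache = tempfile.mkdtemp()
result = compile_sketch(toy_reset, toy_step, jnp.zeros((4,)), cache_dir=tmp_cache)
```

If auto-reset is written with `jnp.where`, as opposed to Python-level branching, both paths live in the same compiled program, so a single dummy step already covers the reset path.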
Source code in envrax/batched_env.py
Python
@abstractmethod
def compile(self, cache_dir: Path | str | None = None) -> None:
    """
    Trigger XLA compilation by running dummy `reset` + `step` calls.

    Implementations should also warm any conditional branches (e.g. the
    reset path of auto-reset logic) so the persistent cache covers them.

    Parameters
    ----------
    cache_dir : Path | str | None (optional)
        XLA cache directory. Implementations may default to a stable
        project-relative path.
    """