
Single Environment

The foundational classes for building a single JAX-native RL environment from scratch.

Types

envrax.env.ObsSpaceT = TypeVar('ObsSpaceT', bound=Space) module-attribute

Observation space generic type.

envrax.env.ActSpaceT = TypeVar('ActSpaceT', bound=Space) module-attribute

Action space generic type.

envrax.env.StateT = TypeVar('StateT', bound='EnvState') module-attribute

Environment state generic type.

envrax.env.ConfigT = TypeVar('ConfigT', bound='EnvConfig') module-attribute

Environment config generic type.

envrax.env.EnvState

Base environment state. Every environment extends this with its own fields.

All fields must be JAX arrays or Python scalars (for static shape info). No Python objects, no lists, no dicts.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required
step Array

Current timestep within the episode

required
done Array

bool scalar — episode termination flag

required
Source code in envrax/env.py
Python
@chex.dataclass
class EnvState:
    """
    Base environment state. Every environment extends this with its own fields.

    All fields must be JAX arrays or Python scalars (for static shape info).
    No Python objects, no lists, no dicts.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key
    step : chex.Array
        Current timestep within the episode
    done : chex.Array
        bool scalar — episode termination flag
    """

    rng: chex.PRNGKey
    step: chex.Array
    done: chex.Array

envrax.env.EnvConfig

Static environment configuration. Set once at construction, never changed. Controls things like max steps, reward scaling, difficulty, etc.

Parameters:

Name Type Description Default
max_steps int

Maximum number of steps per episode. Default is 1000.

1000
Source code in envrax/env.py
Python
@chex.dataclass
class EnvConfig:
    """
    Static environment configuration. Set once at construction, never changed.
    Controls things like max steps, reward scaling, difficulty, etc.

    Parameters
    ----------
    max_steps : int
        Maximum number of steps per episode. Default is 1000.
    """

    max_steps: int = 1000

envrax.env.JaxEnv

Bases: ABC, Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Base class for all JAX-native environments.

Generic over the observation space, action space, state, and config types so that subclasses and wrappers get accurate type info without runtime casts. Subclasses pin all four:

Python
class BallEnv(JaxEnv[Box, Discrete, BallState, BallConfig]): ...

Parameters:

Name Type Description Default
config ConfigT | None

Static environment configuration. Defaults to ConfigT().

None
Source code in envrax/env.py
Python
class JaxEnv(ABC, Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]):
    """
    Base class for all JAX-native environments.

    Generic over the observation space, action space, state, and config
    types so that subclasses and wrappers get accurate type info without
    runtime casts. Subclasses pin all four:

        class BallEnv(JaxEnv[Box, Discrete, BallState, BallConfig]): ...

    Parameters
    ----------
    config : ConfigT (optional)
        Static environment configuration. Defaults to `ConfigT()`.
    """

    config: ConfigT

    def __init__(self, config: ConfigT | None = None) -> None:
        if config is None:
            config_cls = self._resolve_config_cls()
            config = config_cls()

        self.config = config  # type: ignore

    @property
    @abstractmethod
    def observation_space(self) -> ObsSpaceT:
        """Returns the observation space."""
        ...

    @property
    @abstractmethod
    def action_space(self) -> ActSpaceT:
        """Returns the action space."""
        ...

    @abstractmethod
    def reset(self, rng: chex.PRNGKey) -> Tuple[chex.Array, StateT]:
        """
        Set the environment to a starting state.

        Implementations should split `rng` so one half is consumed for
        initialisation and the other half is stored on the returned state's
        `rng` field for `step` to use.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs : chex.Array
            Initial observation
        state : StateT
            Initial environment state with `rng` embedded
        """
        ...

    @abstractmethod
    def step(
        self,
        state: StateT,
        action: chex.Array,
    ) -> Tuple[chex.Array, StateT, chex.Array, chex.Array, Dict[str, Any]]:
        """
        Take an action through the environment.

        Implementations should split `state.rng` for any per-step randomness
        and store the remaining key on `new_state.rng` so randomness threads
        through the episode.

        Parameters
        ----------
        state : StateT
            Current environment state
        action : chex.Array
            Action to take in the environment

        Returns
        -------
        obs : chex.Array
            Observation after the step
        new_state : StateT
            Updated environment state
        reward : chex.Array
            Scalar reward
        done : chex.Array
            bool scalar — `True` when the episode has ended, `False` otherwise
        info : Dict[str, Any]
            Auxiliary diagnostic information
        """
        ...

    def render(self, state: StateT) -> np.ndarray:
        """
        Render the environment state as an RGB frame.

        Parameters
        ----------
        state : StateT
            Current environment state to render

        Returns
        -------
        frame : np.ndarray
            uint8 RGB array of shape `(H, W, 3)`
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement render(). "
            "Override render(state) to return a uint8 (H, W, 3) RGB frame."
        )

    @classmethod
    def _resolve_config_cls(cls) -> Type:
        """
        Return the concrete `EnvConfig` subclass pinned via `JaxEnv[..., ConfigT]`.

        Returns
        -------
        config_cls : Type
            The class pinned to `ConfigT` for this subclass
        """
        return resolve_generic_arg(cls, JaxEnv, position=3)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}"

observation_space abstractmethod property

Returns the observation space.

action_space abstractmethod property

Returns the action space.

reset(rng) abstractmethod

Set the environment to a starting state.

Implementations should split rng so one half is consumed for initialisation and the other half is stored on the returned state's rng field for step to use.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required

Returns:

Name Type Description
obs Array

Initial observation

state StateT

Initial environment state with rng embedded

Source code in envrax/env.py
Python
@abstractmethod
def reset(self, rng: chex.PRNGKey) -> Tuple[chex.Array, StateT]:
    """
    Set the environment to a starting state.

    Implementations should split `rng` so one half is consumed for
    initialisation and the other half is stored on the returned state's
    `rng` field for `step` to use.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    obs : chex.Array
        Initial observation
    state : StateT
        Initial environment state with `rng` embedded
    """
    ...

step(state, action) abstractmethod

Take an action through the environment.

Implementations should split state.rng for any per-step randomness and store the remaining key on new_state.rng so randomness threads through the episode.

Parameters:

Name Type Description Default
state StateT

Current environment state

required
action Array

Action to take in the environment

required

Returns:

Name Type Description
obs Array

Observation after the step

new_state StateT

Updated environment state

reward Array

Scalar reward

done Array

bool scalar — True when the episode has ended, False otherwise

info Dict[str, Any]

Auxiliary diagnostic information

Source code in envrax/env.py
Python
@abstractmethod
def step(
    self,
    state: StateT,
    action: chex.Array,
) -> Tuple[chex.Array, StateT, chex.Array, chex.Array, Dict[str, Any]]:
    """
    Take an action through the environment.

    Implementations should split `state.rng` for any per-step randomness
    and store the remaining key on `new_state.rng` so randomness threads
    through the episode.

    Parameters
    ----------
    state : StateT
        Current environment state
    action : chex.Array
        Action to take in the environment

    Returns
    -------
    obs : chex.Array
        Observation after the step
    new_state : StateT
        Updated environment state
    reward : chex.Array
        Scalar reward
    done : chex.Array
        bool scalar — `True` when the episode has ended, `False` otherwise
    info : Dict[str, Any]
        Auxiliary diagnostic information
    """
    ...

render(state)

Render the environment state as an RGB frame.

Parameters:

Name Type Description Default
state StateT

Current environment state to render

required

Returns:

Name Type Description
frame ndarray

uint8 RGB array of shape (H, W, 3)

Source code in envrax/env.py
Python
def render(self, state: StateT) -> np.ndarray:
    """
    Render the environment state as an RGB frame.

    Parameters
    ----------
    state : StateT
        Current environment state to render

    Returns
    -------
    frame : np.ndarray
        uint8 RGB array of shape `(H, W, 3)`
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement render(). "
        "Override render(state) to return a uint8 (H, W, 3) RGB frame."
    )