
Stateful Wrappers

Wrappers that maintain state of their own across steps, carried alongside the state of the inner environment.

envrax.wrappers.frame_stack.FrameStackObservation

Bases: StatefulWrapper[Box, ActSpaceT, FrameStackState[InnerStateT], ConfigT, InnerStateT]

Maintain a sliding window of the last n_stack observations.

Expects the inner environment to produce uint8[H, W] observations. The stacked observation has shape uint8[H, W, n_stack].

Parameters:

    env : JaxEnv (required)
        Inner environment returning 2-D uint8 observations.
    n_stack : int (default: 4)
        Number of frames to stack.
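
For orientation, here is a minimal usage sketch. It is not taken from the library's examples: base_env stands in for any already-constructed JaxEnv that emits 2-D uint8 frames, and how that environment is built is project-specific.

Python

import jax

from envrax.wrappers.frame_stack import FrameStackObservation

# base_env is a placeholder for any JaxEnv producing uint8[H, W] observations
# (e.g. a grayscale image environment); its construction is not shown here.
env = FrameStackObservation(base_env, n_stack=4)

obs, state = env.reset(jax.random.PRNGKey(0))
# The wrapper reports a 3-D Box whose last axis is the frame stack.
assert obs.shape == env.observation_space.shape  # (H, W, 4)
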
Source code in envrax/wrappers/frame_stack.py
Python
class FrameStackObservation(
    StatefulWrapper[
        Box,
        ActSpaceT,
        FrameStackState[InnerStateT],
        ConfigT,
        InnerStateT,
    ]
):
    """
    Maintain a sliding window of the last `n_stack` observations.

    Expects the inner environment to produce `uint8[H, W]` observations.
    The stacked observation has shape `uint8[H, W, n_stack]`.

    Parameters
    ----------
    env : JaxEnv
        Inner environment returning 2-D `uint8` observations.
    n_stack : int, optional
        Number of frames to stack. Default is `4`.
    """

    def __init__(
        self,
        env: JaxEnv[Box, ActSpaceT, InnerStateT, ConfigT],
        *,
        n_stack: int = 4,
    ) -> None:
        super().__init__(env)
        require_box(env, type(self).__name__, rank=2, dtype=jnp.uint8)
        self._n_stack = n_stack

    def reset(
        self, rng: chex.PRNGKey
    ) -> Tuple[chex.Array, FrameStackState[InnerStateT]]:
        """
        Reset the inner environment and initialise the frame stack.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs  : chex.Array
            Initial stacked observation
        state : FrameStackState
            Wrapper state containing the inner state and the stack
        """
        obs, env_state = self._env.reset(rng)
        stack = jnp.stack([obs] * self._n_stack, axis=-1)
        wrapped = FrameStackState(
            rng=env_state.rng,
            step=env_state.step,
            done=env_state.done,
            env_state=env_state,
            obs_stack=stack,
        )
        return stack, wrapped

    def step(
        self,
        state: FrameStackState[InnerStateT],
        action: chex.Array,
    ) -> Tuple[
        chex.Array,
        FrameStackState[InnerStateT],
        chex.Array,
        chex.Array,
        Dict[str, Any],
    ]:
        """
        Step the inner environment and roll the frame stack.

        Parameters
        ----------
        state : FrameStackState
            Current wrapper state
        action : chex.Array
            Action to take in the environment

        Returns
        -------
        obs  : chex.Array
            Updated stacked observation
        new_state : FrameStackState
            Updated wrapper state
        reward  : chex.Array
            Reward from the inner step
        done  : chex.Array
            Terminal flag from the inner step
        info : Dict[str, Any]
            Info dict from the inner step
        """
        obs, env_state, reward, done, info = self._env.step(state.env_state, action)
        new_stack = jnp.concatenate(
            [state.obs_stack[..., 1:], jnp.expand_dims(obs, -1)],
            axis=-1,
        )
        new_state = FrameStackState(
            rng=env_state.rng,
            step=env_state.step,
            done=env_state.done,
            env_state=env_state,
            obs_stack=new_stack,
        )
        return new_stack, new_state, reward, done, info

    @property
    def observation_space(self) -> Box:
        inner = self._env.observation_space
        h, w = inner.shape[0], inner.shape[1]
        return Box(
            low=inner.low,
            high=inner.high,
            shape=(h, w, self._n_stack),
            dtype=inner.dtype,
        )

reset(rng)

Reset the inner environment and initialise the frame stack.

Parameters:

    rng : PRNGKey (required)
        JAX PRNG key

Returns:

    obs : Array
        Initial stacked observation
    state : FrameStackState
        Wrapper state containing the inner state and the stack
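
A small illustrative check of what reset returns, assuming env is the FrameStackObservation instance from the sketch above: every slot of the returned stack holds the initial frame, and the inner environment's state is nested inside the wrapper state.

Python

import jax
import jax.numpy as jnp

obs, state = env.reset(jax.random.PRNGKey(0))

# Right after reset, all n_stack slices along the last axis are identical copies
# of the first frame, and the returned observation is exactly the stored stack.
assert bool(jnp.all(obs[..., 0:1] == obs))
assert bool(jnp.all(state.obs_stack == obs))
# The unwrapped environment state is carried in state.env_state.
inner_state = state.env_state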

Source code in envrax/wrappers/frame_stack.py
Python
def reset(
    self, rng: chex.PRNGKey
) -> Tuple[chex.Array, FrameStackState[InnerStateT]]:
    """
    Reset the inner environment and initialise the frame stack.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    obs  : chex.Array
        Initial stacked observation
    state : FrameStackState
        Wrapper state containing the inner state and the stack
    """
    obs, env_state = self._env.reset(rng)
    stack = jnp.stack([obs] * self._n_stack, axis=-1)
    wrapped = FrameStackState(
        rng=env_state.rng,
        step=env_state.step,
        done=env_state.done,
        env_state=env_state,
        obs_stack=stack,
    )
    return stack, wrapped

step(state, action)

Step the inner environment and roll the frame stack.

Parameters:

    state : FrameStackState (required)
        Current wrapper state
    action : Array (required)
        Action to take in the environment

Returns:

    obs : Array
        Updated stacked observation
    new_state : FrameStackState
        Updated wrapper state
    reward : Array
        Reward from the inner step
    done : Array
        Terminal flag from the inner step
    info : Dict[str, Any]
        Info dict from the inner step
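
A brief sketch of the rolling behaviour, continuing from the reset sketch above and assuming jnp.int32(0) is a valid action for the inner environment:

Python

import jax.numpy as jnp

next_obs, state, reward, done, info = env.step(state, jnp.int32(0))

# The oldest frame is dropped and the new frame appended on the last axis, so the
# first n_stack - 1 slices of the new stack equal the last slices of the old one.
assert bool(jnp.all(next_obs[..., :-1] == obs[..., 1:]))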

Source code in envrax/wrappers/frame_stack.py
Python
def step(
    self,
    state: FrameStackState[InnerStateT],
    action: chex.Array,
) -> Tuple[
    chex.Array,
    FrameStackState[InnerStateT],
    chex.Array,
    chex.Array,
    Dict[str, Any],
]:
    """
    Step the inner environment and roll the frame stack.

    Parameters
    ----------
    state : FrameStackState
        Current wrapper state
    action : chex.Array
        Action to take in the environment

    Returns
    -------
    obs  : chex.Array
        Updated stacked observation
    new_state : FrameStackState
        Updated wrapper state
    reward  : chex.Array
        Reward from the inner step
    done  : chex.Array
        Terminal flag from the inner step
    info : Dict[str, Any]
        Info dict from the inner step
    """
    obs, env_state, reward, done, info = self._env.step(state.env_state, action)
    new_stack = jnp.concatenate(
        [state.obs_stack[..., 1:], jnp.expand_dims(obs, -1)],
        axis=-1,
    )
    new_state = FrameStackState(
        rng=env_state.rng,
        step=env_state.step,
        done=env_state.done,
        env_state=env_state,
        obs_stack=new_stack,
    )
    return new_stack, new_state, reward, done, info

envrax.wrappers.record_episode_statistics.RecordEpisodeStatistics

Bases: StatefulWrapper[ObsSpaceT, ActSpaceT, EpisodeStatisticsState[InnerStateT], ConfigT, InnerStateT]

Records episode return and length.

Accumulates reward and step count in EpisodeStatisticsState. Episode statistics are written to info["episode"] on every step() call; values are non-zero only when done=True.

Parameters:

    env : JaxEnv (required)
        Environment to wrap.
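
A minimal usage sketch, not taken from the library's examples. Here base_env again stands in for any already-constructed JaxEnv, and action for any valid action of that environment.

Python

import jax

from envrax.wrappers.record_episode_statistics import RecordEpisodeStatistics

env = RecordEpisodeStatistics(base_env)

obs, state = env.reset(jax.random.PRNGKey(0))
obs, state, reward, done, info = env.step(state, action)

# info["episode"]["r"] and info["episode"]["l"] are zero on non-terminal steps and
# hold the full episode return / length on the step where done becomes True.
episode_return = info["episode"]["r"]
episode_length = info["episode"]["l"]
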
Source code in envrax/wrappers/record_episode_statistics.py
Python
class RecordEpisodeStatistics(
    StatefulWrapper[
        ObsSpaceT,
        ActSpaceT,
        EpisodeStatisticsState[InnerStateT],
        ConfigT,
        InnerStateT,
    ]
):
    """
    Records episode return and length.

    Accumulates reward and step count in `EpisodeStatisticsState`.
    Episode statistics are written to `info["episode"]` on every `step()`
    call; values are non-zero only when `done=True`.

    Parameters
    ----------
    env : JaxEnv
        Environment to wrap.
    """

    def __init__(self, env: JaxEnv[ObsSpaceT, ActSpaceT, InnerStateT, ConfigT]) -> None:
        super().__init__(env)

    def reset(
        self, rng: chex.PRNGKey
    ) -> Tuple[chex.Array, EpisodeStatisticsState[InnerStateT]]:
        """
        Reset the environment and episode accumulators.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs  : chex.Array
            Initial observation
        state : EpisodeStatisticsState
            Initial state with zeroed accumulators
        """
        obs, env_state = self._env.reset(rng)
        state = EpisodeStatisticsState(
            rng=env_state.rng,
            step=env_state.step,
            done=env_state.done,
            env_state=env_state,
            episode_return=jnp.float32(0.0),
            episode_length=jnp.int32(0),
        )
        return obs, state

    def step(
        self,
        state: EpisodeStatisticsState[InnerStateT],
        action: chex.Array,
    ) -> Tuple[
        chex.Array,
        EpisodeStatisticsState[InnerStateT],
        chex.Array,
        chex.Array,
        Dict[str, Any],
    ]:
        """
        Step the environment and update episode accumulators.

        Parameters
        ----------
        state : EpisodeStatisticsState
            Current state
        action : chex.Array
            Action to take in the environment

        Returns
        -------
        obs  : chex.Array
            Next observation
        new_state : EpisodeStatisticsState
            Updated state
        reward  : chex.Array
            Step reward
        done  : chex.Array
            Episode terminal flag
        info : Dict[str, Any]
            Environment metadata extended with `"episode"`:
            `{"r": float32, "l": int32}` — non-zero only when `done=True`
        """
        obs, env_state, reward, done, info = self._env.step(state.env_state, action)

        episode_return = state.episode_return + reward.astype(jnp.float32)
        episode_length = state.episode_length + jnp.int32(1)

        info["episode"] = {
            "r": jnp.where(done, episode_return, jnp.float32(0.0)),
            "l": jnp.where(done, episode_length, jnp.int32(0)),
        }

        new_state = EpisodeStatisticsState(
            rng=env_state.rng,
            step=env_state.step,
            done=env_state.done,
            env_state=env_state,
            episode_return=jnp.where(done, jnp.float32(0.0), episode_return),  # pyright: ignore[reportArgumentType]
            episode_length=jnp.where(done, jnp.int32(0), episode_length),
        )
        return obs, new_state, reward, done, info

reset(rng)

Reset the environment and episode accumulators.

Parameters:

    rng : PRNGKey (required)
        JAX PRNG key

Returns:

    obs : Array
        Initial observation
    state : EpisodeStatisticsState
        Initial state with zeroed accumulators
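
As a quick illustration, assuming env is the RecordEpisodeStatistics instance from the sketch above, the freshly reset state carries zeroed accumulators:

Python

import jax

obs, state = env.reset(jax.random.PRNGKey(0))

# Both accumulators start at zero and live inside the wrapper state.
assert float(state.episode_return) == 0.0
assert int(state.episode_length) == 0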

Source code in envrax/wrappers/record_episode_statistics.py
Python
def reset(
    self, rng: chex.PRNGKey
) -> Tuple[chex.Array, EpisodeStatisticsState[InnerStateT]]:
    """
    Reset the environment and episode accumulators.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    obs  : chex.Array
        Initial observation
    state : EpisodeStatisticsState
        Initial state with zeroed accumulators
    """
    obs, env_state = self._env.reset(rng)
    state = EpisodeStatisticsState(
        rng=env_state.rng,
        step=env_state.step,
        done=env_state.done,
        env_state=env_state,
        episode_return=jnp.float32(0.0),
        episode_length=jnp.int32(0),
    )
    return obs, state

step(state, action)

Step the environment and update episode accumulators.

Parameters:

    state : EpisodeStatisticsState (required)
        Current state
    action : Array (required)
        Action to take in the environment

Returns:

    obs : Array
        Next observation
    new_state : EpisodeStatisticsState
        Updated state
    reward : Array
        Step reward
    done : Array
        Episode terminal flag
    info : Dict[str, Any]
        Environment metadata extended with "episode":
        {"r": float32, "l": int32}; values are non-zero only when done=True
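
One common way to consume these statistics is to collect info["episode"] across a jax.lax.scan rollout and mask by the length entry, which is non-zero only on terminal steps. The sketch below is an illustration rather than library code: select_action is a placeholder for any policy, and it assumes the inner environment resets itself after done (consistent with the wrapper zeroing its accumulators once an episode ends).

Python

import jax
import jax.numpy as jnp

def rollout_mean_return(env, rng, num_steps, select_action):
    # select_action: placeholder mapping a PRNG key to a valid action.
    _, state = env.reset(rng)

    def body(carry, _):
        state, rng = carry
        rng, key = jax.random.split(rng)
        obs, state, reward, done, info = env.step(state, select_action(key))
        # Only the per-step "episode" statistics are collected.
        return (state, rng), info["episode"]

    _, episodes = jax.lax.scan(body, (state, rng), xs=None, length=num_steps)

    # "l" is non-zero only on steps where an episode finished, so it doubles as a mask.
    finished = episodes["l"] > 0
    total = jnp.sum(jnp.where(finished, episodes["r"], 0.0))
    count = jnp.maximum(jnp.sum(finished), 1)
    return total / count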

Source code in envrax/wrappers/record_episode_statistics.py
Python
def step(
    self,
    state: EpisodeStatisticsState[InnerStateT],
    action: chex.Array,
) -> Tuple[
    chex.Array,
    EpisodeStatisticsState[InnerStateT],
    chex.Array,
    chex.Array,
    Dict[str, Any],
]:
    """
    Step the environment and update episode accumulators.

    Parameters
    ----------
    state : EpisodeStatisticsState
        Current state
    action : chex.Array
        Action to take in the environment

    Returns
    -------
    obs  : chex.Array
        Next observation
    new_state : EpisodeStatisticsState
        Updated state
    reward  : chex.Array
        Step reward
    done  : chex.Array
        Episode terminal flag
    info : Dict[str, Any]
        Environment metadata extended with `"episode"`:
        `{"r": float32, "l": int32}` — non-zero only when `done=True`
    """
    obs, env_state, reward, done, info = self._env.step(state.env_state, action)

    episode_return = state.episode_return + reward.astype(jnp.float32)
    episode_length = state.episode_length + jnp.int32(1)

    info["episode"] = {
        "r": jnp.where(done, episode_return, jnp.float32(0.0)),
        "l": jnp.where(done, episode_length, jnp.int32(0)),
    }

    new_state = EpisodeStatisticsState(
        rng=env_state.rng,
        step=env_state.step,
        done=env_state.done,
        env_state=env_state,
        episode_return=jnp.where(done, jnp.float32(0.0), episode_return),  # pyright: ignore[reportArgumentType]
        episode_length=jnp.where(done, jnp.int32(0), episode_length),
    )
    return obs, new_state, reward, done, info