
Base Wrappers

The abstract base classes every wrapper inherits from. Inherit from these to create your own!

envrax.wrappers.base.Wrapper

Bases: JaxEnv[ObsSpaceT, ActSpaceT, StateT, ConfigT]

Abstract base class for pass-through JaxEnv wrappers.

Pass-through wrappers preserve the inner env's state type unchanged. They declare four TypeVars:

class ClipReward(Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT]): ...
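The sketch below illustrates the pass-through pattern without depending on envrax or JAX. `ToyEnv`, `ToyState` and `PassThrough` are hypothetical stand-ins (`PassThrough` mimics `Wrapper`'s delegation to `self._env`), plain floats replace `chex.Array`, and the step tuple is assumed to be ordered `(obs, state, reward, done, info)`. That order fits the types in the `step` signature but is not spelled out on this page. Notice that `ClipReward.step` hands back the inner state object untouched, which is what keeps the state type unchanged.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class ToyState:
    # Hypothetical inner state: just a step counter.
    t: int


class ToyEnv:
    """Hypothetical inner env whose raw reward is deliberately unbounded."""

    def reset(self, rng: int) -> Tuple[float, ToyState]:
        return 0.0, ToyState(t=0)

    def step(
        self, state: ToyState, action: float
    ) -> Tuple[float, ToyState, float, bool, Dict[str, Any]]:
        nxt = ToyState(t=state.t + 1)
        # Assumed tuple order: (obs, state, reward, done, info).
        return float(nxt.t), nxt, 5.0 * action, nxt.t >= 3, {}


class PassThrough:
    """Stand-in for envrax's Wrapper: stores the inner env and delegates to it."""

    def __init__(self, env: Any) -> None:
        self._env = env

    def reset(self, rng: int):
        return self._env.reset(rng)

    def step(self, state, action):
        return self._env.step(state, action)


class ClipReward(PassThrough):
    """Clips rewards to [-1, 1]; the inner state is returned unchanged."""

    def step(self, state, action):
        obs, state, reward, done, info = self._env.step(state, action)
        return obs, state, max(-1.0, min(1.0, reward)), done, info


env = ClipReward(ToyEnv())
obs, state = env.reset(rng=0)
obs, state, reward, done, info = env.step(state, action=2.0)
print(reward, type(state).__name__)  # 1.0 ToyState
```

In a real JAX environment you would clip with `jnp.clip` rather than Python's `min`/`max`, so the wrapper stays traceable under `jax.jit`.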

For wrappers that introduce their own outer state type wrapping the inner state, use StatefulWrapper instead.

The observation_space and action_space properties delegate to the inner environment by default and may be overridden when the wrapper changes the observation shape or action set.
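As a sketch of when to override these properties, the hypothetical `FlattenObservation` below flattens a 2x3 observation and reports the new shape through its own `observation_space`. `Box` and `GridEnv` are illustrative stand-ins, not envrax classes, and only `reset` is shown; a complete wrapper would transform the observation returned by `step` the same way.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple


@dataclass(frozen=True)
class Box:
    # Hypothetical space type: only the shape matters for this sketch.
    shape: Tuple[int, ...]


class GridEnv:
    """Hypothetical inner env emitting a 2x3 nested-list observation."""

    @property
    def observation_space(self) -> Box:
        return Box(shape=(2, 3))

    def reset(self, rng: int):
        return [[0, 1, 2], [3, 4, 5]], None


class FlattenObservation:
    def __init__(self, env) -> None:
        self._env = env

    @property
    def observation_space(self) -> Box:
        # Override instead of delegating: the wrapper changes the observation
        # shape, so it must report the flattened shape.
        return Box(shape=(prod(self._env.observation_space.shape),))

    def reset(self, rng: int):
        obs, state = self._env.reset(rng)
        return [x for row in obs for x in row], state


env = FlattenObservation(GridEnv())
print(env.observation_space.shape)  # (6,)
print(env.reset(0)[0])              # [0, 1, 2, 3, 4, 5]
```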

Parameterised wrappers support a factory mode: calling the class without an env (using only keyword arguments) returns a _WrapperFactory rather than a live wrapper.

Parameters:

| Name | Type   | Description                | Default  |
|------|--------|----------------------------|----------|
| env  | JaxEnv | Inner environment to wrap. | required |
Source code in envrax/wrappers/base.py
class Wrapper(JaxEnv[ObsSpaceT, ActSpaceT, StateT, ConfigT]):
    """
    Abstract base class for pass-through JaxEnv wrappers.

    Pass-through wrappers preserve the inner env's state type unchanged.
    They declare four TypeVars:

        class ClipReward(Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT]): ...

    For wrappers that introduce their own outer state type wrapping the
    inner state, use `StatefulWrapper` instead.

    The `observation_space` and `action_space` properties delegate to the
    inner environment by default and may be overridden when the wrapper
    changes the observation shape or action set.

    Parameterised wrappers support a **factory mode**: calling the class
    without an `env` (using only keyword arguments) returns a
    `_WrapperFactory` rather than a live wrapper.

    Parameters
    ----------
    env : JaxEnv
        Inner environment to wrap.
    """

    @overload
    def __new__(cls, env: None = ..., **kwargs) -> "_WrapperFactory": ...

    @overload
    def __new__(cls, env: JaxEnv, **kwargs) -> Self: ...

    def __new__(cls, env=None, **kwargs):
        if env is None:
            factory = object.__new__(_WrapperFactory)
            _WrapperFactory.__init__(factory, cls, **kwargs)
            return factory
        return super().__new__(cls)

    def __init__(self, env: JaxEnv[ObsSpaceT, ActSpaceT, StateT, ConfigT]) -> None:
        super().__init__(env.config)
        self._env: JaxEnv[ObsSpaceT, ActSpaceT, StateT, ConfigT] = env

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}<{self._env!r}>"

    @abstractmethod
    def reset(self, rng: chex.PRNGKey) -> Tuple[chex.Array, StateT]:
        """Reset the environment and return the initial observation and state."""
        raise NotImplementedError()

    @abstractmethod
    def step(
        self,
        state: StateT,
        action: chex.Array,
    ) -> Tuple[chex.Array, StateT, chex.Array, chex.Array, Dict[str, Any]]:
        """Advance the environment by one step."""
        raise NotImplementedError()

    @property
    def unwrapped(self) -> JaxEnv:
        """Return the innermost `JaxEnv` by delegating through the wrapper chain."""
        return self._env.unwrapped if isinstance(self._env, Wrapper) else self._env

    def render(self, state: StateT, **kwargs: Any) -> np.ndarray:
        """Forward render to the inner environment."""
        return self._env.render(state, **kwargs)

    @property
    def observation_space(self) -> ObsSpaceT:
        """Observation space of the inner environment."""
        return self._env.observation_space

    @property
    def action_space(self) -> ActSpaceT:
        """Action space of the inner environment."""
        return self._env.action_space

unwrapped property

Return the innermost JaxEnv by delegating through the wrapper chain.
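Because `unwrapped` recurses only while the inner object is itself a `Wrapper`, it reaches the base environment through any depth of nesting. A minimal standalone reproduction of the rule, using hypothetical `Env`, `A` and `B` classes:

```python
class Env:
    """Hypothetical innermost environment."""


class Wrapper(Env):
    def __init__(self, env: Env) -> None:
        self._env = env

    @property
    def unwrapped(self) -> Env:
        # Same rule as the source: recurse while the inner object is a Wrapper.
        return self._env.unwrapped if isinstance(self._env, Wrapper) else self._env


class A(Wrapper): ...


class B(Wrapper): ...


base = Env()
chain = B(A(base))
print(chain.unwrapped is base)  # True
```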

observation_space property

Observation space of the inner environment.

action_space property

Action space of the inner environment.

reset(rng) abstractmethod

Reset the environment and return the initial observation and state.


step(state, action) abstractmethod

Advance the environment by one step.
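The `reset` and `step` signatures appear to take and return the state explicitly rather than keeping it on the environment object, so a caller threads the state from one call into the next. The loop below sketches that pattern with a hypothetical `CounterEnv` and an assumed tuple order of `(obs, state, reward, done, info)`; inside `jax.jit` the Python loop would usually be replaced by `jax.lax.scan`.

```python
from typing import Any, Dict, List, Tuple


class CounterEnv:
    """Hypothetical env: the state is an int and the episode ends after 3 steps."""

    def reset(self, rng: int) -> Tuple[int, int]:
        return 0, 0

    def step(self, state: int, action: int) -> Tuple[int, int, float, bool, Dict[str, Any]]:
        new_state = state + 1
        # Assumed tuple order: (obs, state, reward, done, info).
        return new_state, new_state, float(action), new_state >= 3, {}


def rollout(env: Any, rng: int, actions: List[int]) -> List[float]:
    obs, state = env.reset(rng)
    rewards: List[float] = []
    for action in actions:
        # Nothing is mutated in place: each step returns a new state
        # that must be passed to the next call.
        obs, state, reward, done, info = env.step(state, action)
        rewards.append(reward)
        if done:
            break
    return rewards


print(rollout(CounterEnv(), rng=0, actions=[1, 1, 1, 1, 1]))  # [1.0, 1.0, 1.0]
```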


render(state, **kwargs)

Forward render to the inner environment.


envrax.wrappers.base.StatefulWrapper

Bases: Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT], Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT, InnerStateT]

Abstract base class for stateful JaxEnv wrappers.

Stateful wrappers introduce their own outer state type that wraps the inner env's state. They declare five TypeVars — pinning StateT to their wrapper-specific class and leaving InnerStateT parametric:

class FrameStackObservation(
    StatefulWrapper[Box, ActSpaceT, FrameStackState[InnerStateT], ConfigT, InnerStateT]
): ...
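Repeating all four TypeVars inside `Generic[...]` is required by Python's typing rules: once `Generic[...]` appears among a class's bases, it must list every TypeVar used by the other generic bases, and it may add new ones such as `InnerStateT`. The standalone check below uses hypothetical `Base`, `Stateful` and `Frames` classes in place of `Wrapper`, `StatefulWrapper` and `FrameStackState`.

```python
from typing import Generic, TypeVar

ObsT = TypeVar("ObsT")
StateT = TypeVar("StateT")
InnerT = TypeVar("InnerT")


class Base(Generic[ObsT, StateT]):
    """Plays the role of Wrapper: generic in the outer state only."""


class Stateful(Base[ObsT, StateT], Generic[ObsT, StateT, InnerT]):
    """Re-lists the base's TypeVars in Generic[...] and adds InnerT."""


print(Stateful.__parameters__)  # (~ObsT, ~StateT, ~InnerT)


class Frames(Generic[InnerT]):
    """Hypothetical outer state that wraps an inner state."""


# Pin the outer state type while leaving the inner state parametric.
Pinned = Stateful[list, Frames[InnerT], InnerT]
print(Pinned.__parameters__)  # (~InnerT,)
```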

For wrappers that preserve the inner state unchanged, use Wrapper instead.
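The sketch below shows what a stateful wrapper has to do: keep the inner state inside its own outer state, unwrap it before calling the inner env, and re-wrap the result. `CountEnv`, `StackState` and `FrameStack` are hypothetical toy classes (the fields of the real `FrameStackObservation` and `FrameStackState` are not shown on this page), and the step tuple is assumed to be ordered `(obs, state, reward, done, info)`.

```python
from dataclasses import dataclass
from typing import Any, Dict, Generic, Tuple, TypeVar

InnerT = TypeVar("InnerT")


@dataclass(frozen=True)
class StackState(Generic[InnerT]):
    # Outer state: the inner env's state plus the wrapper's own frame buffer.
    inner: InnerT
    frames: Tuple[int, ...]


class CountEnv:
    """Hypothetical inner env: the observation equals its integer state."""

    def reset(self, rng: int) -> Tuple[int, int]:
        return 0, 0

    def step(self, state: int, action: int) -> Tuple[int, int, float, bool, Dict[str, Any]]:
        return state + action, state + action, 0.0, False, {}


class FrameStack:
    def __init__(self, env: Any, k: int) -> None:
        self._env, self._k = env, k

    def reset(self, rng: int):
        obs, inner = self._env.reset(rng)
        frames = (obs,) * self._k  # pad the buffer with the first frame
        return frames, StackState(inner=inner, frames=frames)

    def step(self, state: StackState, action: int):
        # Unwrap: the inner env only understands its own state type.
        obs, inner, reward, done, info = self._env.step(state.inner, action)
        frames = state.frames[1:] + (obs,)
        # Re-wrap the new inner state together with the updated buffer.
        return frames, StackState(inner=inner, frames=frames), reward, done, info


env = FrameStack(CountEnv(), k=3)
obs, s = env.reset(0)
obs, s, *_ = env.step(s, 2)
obs, s, *_ = env.step(s, 5)
print(obs, s.inner)  # (0, 2, 7) 7
```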

Parameters:

| Name | Type   | Description                | Default  |
|------|--------|----------------------------|----------|
| env  | JaxEnv | Inner environment to wrap. | required |
Source code in envrax/wrappers/base.py
class StatefulWrapper(
    Wrapper[ObsSpaceT, ActSpaceT, StateT, ConfigT],
    Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT, InnerStateT],
):
    """
    Abstract base class for stateful JaxEnv wrappers.

    Stateful wrappers introduce their own outer state type that wraps the
    inner env's state. They declare five TypeVars — pinning `StateT` to
    their wrapper-specific class and leaving `InnerStateT` parametric:

        class FrameStackObservation(
            StatefulWrapper[Box, ActSpaceT, FrameStackState[InnerStateT], ConfigT, InnerStateT]
        ): ...

    For wrappers that preserve the inner state unchanged, use `Wrapper`
    instead.

    Parameters
    ----------
    env : JaxEnv
        Inner environment to wrap.
    """

    def __init__(self, env: JaxEnv[ObsSpaceT, ActSpaceT, InnerStateT, ConfigT]) -> None:
        JaxEnv.__init__(self, env.config)
        self._env: JaxEnv[ObsSpaceT, ActSpaceT, InnerStateT, ConfigT] = env  # type: ignore[assignment]