
Factory Helpers

Internal factory utilities used to construct wrappers with consistent type inference.

envrax.wrappers.base._WrapperFactory

Deferred wrapper returned by `Wrapper.__new__` when called without an `env`.

Calling the factory with an `env` creates the intended wrapper with the pre-bound keyword arguments.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cls` | `type` | Concrete `Wrapper` subclass to instantiate. | *required* |
| `**kwargs` | | Keyword arguments forwarded to `cls.__init__` when the factory is called. | `{}` |
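Since a factory only needs an `env` to finish construction, several factories can be prepared ahead of time and applied in order. The sketch below relies only on the documented constructor and `__call__`; the `Tag` class and the `apply_all` helper are hypothetical illustrations, not envrax APIs.

```python
from functools import reduce
from typing import Any, Callable, Sequence


class _WrapperFactory:
    # Same shape as the documented class: store the class and kwargs, apply later.
    __slots__ = ("_cls", "_kwargs")

    def __init__(self, cls: type, **kwargs: Any) -> None:
        self._cls = cls
        self._kwargs = kwargs

    def __call__(self, env: Any) -> Any:
        return self._cls(env, **self._kwargs)


class Tag:
    """Hypothetical stand-in for a Wrapper subclass: records its inner env and a label."""

    def __init__(self, env: Any, *, label: str) -> None:
        self.env = env
        self.label = label


def apply_all(env: Any, factories: Sequence[Callable[[Any], Any]]) -> Any:
    """Hypothetical helper: wrap env with each factory in order (first = innermost)."""
    return reduce(lambda inner, make: make(inner), factories, env)


pipeline = [_WrapperFactory(Tag, label="inner"), _WrapperFactory(Tag, label="outer")]
wrapped = apply_all("base_env", pipeline)
print(wrapped.label, wrapped.env.label, wrapped.env.env)  # prints: outer inner base_env
```

The first factory in the list becomes the innermost wrapper. Because the factories hold no reference to any env, the same `pipeline` list can be reused on other environments.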
Source code in `envrax/wrappers/base.py`

```python
class _WrapperFactory:
    """
    Deferred wrapper returned by `Wrapper.__new__` when called without an `env`.

    Calling the factory with an `env` creates the intended wrapper with the
    pre-bound keyword arguments.

    Parameters
    ----------
    cls : type
        Concrete `Wrapper` subclass to instantiate.
    **kwargs
        Keyword arguments forwarded to `cls.__init__` when the factory is called.
    """

    __slots__ = ("_cls", "_kwargs")

    def __init__(self, cls: type, **kwargs) -> None:
        self._cls = cls
        self._kwargs = kwargs

    def __call__(self, env: JaxEnv) -> "Wrapper":
        """
        Wrap `env` using the stored class and keyword arguments.

        Parameters
        ----------
        env : JaxEnv
            Environment to wrap.

        Returns
        -------
        wrapper : Wrapper
            Configured wrapper instance.
        """
        return self._cls(env, **self._kwargs)
```

`__call__(env)`

Wrap `env` using the stored class and keyword arguments.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `env` | `JaxEnv` | Environment to wrap. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `wrapper` | `Wrapper` | Configured wrapper instance. |
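Because `__call__` simply re-invokes `self._cls(env, **self._kwargs)`, every call builds a new wrapper, while the keyword values themselves are passed through unchanged (the same objects, not copies). The sketch below illustrates both points with a copy of the documented class and a hypothetical `Recorder` wrapper that is not part of envrax.

```python
from typing import Any, List


class _WrapperFactory:
    # Copy of the documented class, trimmed to the parts exercised here.
    __slots__ = ("_cls", "_kwargs")

    def __init__(self, cls: type, **kwargs: Any) -> None:
        self._cls = cls
        self._kwargs = kwargs

    def __call__(self, env: Any) -> Any:
        # A new instance is built on every call; nothing is cached.
        return self._cls(env, **self._kwargs)


class Recorder:
    """Hypothetical wrapper that keeps a log list passed as a keyword argument."""

    def __init__(self, env: Any, *, log: List[str]) -> None:
        self.env = env
        self.log = log


shared_log: List[str] = []
factory = _WrapperFactory(Recorder, log=shared_log)

a = factory("env_a")
b = factory("env_b")

# Each call returns a distinct wrapper around its own env...
print(a is b, a.env, b.env)  # prints: False env_a env_b
# ...but `**self._kwargs` passes the same value objects, so a mutable
# argument is shared by every wrapper this factory creates.
print(a.log is b.log is shared_log)  # prints: True
```

If wrappers built from one factory must not share mutable state, the wrapper's own `__init__` would need to copy such arguments.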
