Skip to content

Factory Functions

The canonical entry points for instantiating registered environments as single, vectorised, or multi-environment objects.

We recommend using these functions whenever possible rather than constructing environment objects by hand.

envrax.make.make(name, *, config=None, wrappers=None, jit_compile=True, pre_warm=True, cache_dir=DEFAULT_CACHE_DIR)

Create a single JaxEnv, optionally with wrappers applied.

Parameters:

Name Type Description Default
name str

Registered environment name

required
config EnvConfig

Environment configuration. Defaults to the registered default config.

None
wrappers List[WrapperType]

Wrapper classes or pre-configured factories applied innermost-first around the base env.

None
jit_compile bool

Wrap the env in JitWrapper. Default is True.

True
pre_warm bool

When jit_compile=True, run a dummy reset + step immediately to trigger XLA compilation. Set to False to defer compilation to the first real call or an explicit compile(). Default is True.

True
cache_dir Path | str | None

Directory for the persistent XLA compilation cache. Defaults to <cwd>/.jax_cache (override with the ENVRAX_CACHE_DIR environment variable). Pass None to disable.

DEFAULT_CACHE_DIR

Returns:

Name Type Description
env JaxEnv

Configured environment, wrapped in JitWrapper when jit_compile=True.

Raises:

Name Type Description
unknown_env ValueError

If name is not registered.

Source code in envrax/make.py (Python, lines 15–72):
def make(
    name: str,
    *,
    config: EnvConfig | None = None,
    wrappers: List[WrapperType] | None = None,
    jit_compile: bool = True,
    pre_warm: bool = True,
    cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
) -> JaxEnv:
    """
    Create a single `JaxEnv`, optionally with wrappers applied.

    The registered spec for `name` supplies the env class and, when `config`
    is omitted, the configuration. Each entry of `wrappers` is then called on
    the env built so far, in list order, so the first entry ends up innermost.
    Finally the result is wrapped in `JitWrapper` if `jit_compile` is set.

    Parameters
    ----------
    name : str
        Registered environment name.
    config : EnvConfig (optional)
        Environment configuration. When `None`, the spec's `default_config`
        is used.
    wrappers : List[WrapperType] (optional)
        Wrapper classes or pre-configured factories. Each one is called with
        the current env; the first entry is applied innermost.
    jit_compile : bool (optional)
        Wrap the final env in `JitWrapper`. Default is `True`.
    pre_warm : bool (optional)
        Forwarded to `JitWrapper`. When `True`, run a dummy `reset` + `step`
        immediately to trigger XLA compilation. When `False`, compilation is
        deferred to the first real call or an explicit `compile()`. Ignored
        when `jit_compile=False`. Default is `True`.
    cache_dir : Path | str | None (optional)
        Directory for the persistent XLA compilation cache, forwarded to
        `JitWrapper`. Defaults to `<cwd>/.jax_cache` (override with the
        `ENVRAX_CACHE_DIR` environment variable); pass `None` to disable.
        Ignored when `jit_compile=False`.

    Returns
    -------
    env : JaxEnv
        The constructed (and possibly wrapped) environment. It is a
        `JitWrapper` instance when `jit_compile=True`.

    Raises
    ------
    unknown_env : ValueError
        If `name` is not registered. The message lists all registered names
        in sorted order.
    """
    if name not in _REGISTRY:
        available = sorted(_REGISTRY)
        raise ValueError(f"Unknown environment: {name!r}. Available: {available}")

    spec = _REGISTRY[name]
    env_config = spec.default_config if config is None else config
    env: JaxEnv = spec.env_class(config=env_config)

    # Apply wrappers in list order: the first entry wraps the bare env and
    # therefore sits innermost. `None` and `[]` both mean "no wrappers".
    for wrap in wrappers or ():
        env = wrap(env)

    if not jit_compile:
        return env

    return JitWrapper(env, cache_dir=cache_dir, pre_warm=pre_warm)

envrax.make.make_vec(name, n_envs, *, config=None, wrappers=None, jit_compile=True, pre_warm=True, cache_dir=DEFAULT_CACHE_DIR)

Create a VecEnv with n_envs parallel environments.

Parameters:

Name Type Description Default
name str

Registered environment name

required
n_envs int

Number of parallel environments

required
config EnvConfig

Environment configuration. Defaults to the registered default config.

None
wrappers List[WrapperType]

Wrapper classes or pre-configured factories, applied innermost-first to the single environment before vectorisation.

None
jit_compile bool

Enable the persistent XLA compilation cache; pre_warm only takes effect when this is True. Default is True.

True
pre_warm bool

When jit_compile=True, run a dummy reset + step immediately. Set to False to defer to an explicit vec_env.compile() call. Default is True.

True
cache_dir Path | str | None

Directory for the persistent XLA compilation cache. Only used when jit_compile=True.

DEFAULT_CACHE_DIR

Returns:

Name Type Description
vec_env VecEnv

Vectorised environment.

Source code in envrax/make.py (Python, lines 75–128):
def make_vec(
    name: str,
    n_envs: int,
    *,
    config: EnvConfig | None = None,
    wrappers: List[WrapperType] | None = None,
    jit_compile: bool = True,
    pre_warm: bool = True,
    cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
) -> VecEnv:
    """
    Create a `VecEnv` with `n_envs` parallel environments.

    A single environment is built with `make` (wrappers applied, but without
    `JitWrapper`) and handed to `VecEnv`. Compilation is then driven through
    `VecEnv.compile` rather than per-env JIT wrapping.

    Parameters
    ----------
    name : str
        Registered environment name.
    n_envs : int
        Number of parallel environments.
    config : EnvConfig (optional)
        Environment configuration. Defaults to the registered default config.
    wrappers : List[WrapperType] (optional)
        Wrapper classes or pre-configured factories, applied innermost-first
        to the single env before vectorisation.
    jit_compile : bool (optional)
        When `True`, call `setup_cache(cache_dir)` to enable the persistent
        XLA compilation cache, and allow `pre_warm` to take effect.
        Default is `True`.
    pre_warm : bool (optional)
        When `jit_compile=True`, call `vec_env.compile()` immediately (a dummy
        `reset` + `step`). Set to `False` to defer to an explicit
        `vec_env.compile()` call. Default is `True`.
    cache_dir : Path | str | None (optional)
        Directory for the persistent XLA compilation cache. Forwarded to
        `setup_cache` and `VecEnv.compile`; only used when `jit_compile=True`.

    Returns
    -------
    vec_env : VecEnv
        Vectorised environment.

    Raises
    ------
    unknown_env : ValueError
        Propagated from `make` if `name` is not registered.
    """
    # Configure the persistent compilation cache up front, before the env
    # is constructed.
    if jit_compile:
        setup_cache(cache_dir)

    # Build the single env without `JitWrapper`; compilation is handled by
    # the `VecEnv` below instead.
    single_env = make(
        name,
        config=config,
        wrappers=wrappers,
        jit_compile=False,
        cache_dir=None,
    )
    vec_env = VecEnv(single_env, n_envs)

    if not (jit_compile and pre_warm):
        return vec_env

    vec_env.compile(cache_dir=cache_dir)
    return vec_env

envrax.make.make_multi(envs, *, pre_warm=False)

Wrap JaxEnv instances into a MultiEnv.

Parameters:

Name Type Description Default
envs List[JaxEnv] | Dict[str, JaxEnv]

Envs to wrap. List form: keys derived from env.name with suffixes on duplicates. Dict form: keys used verbatim for explicit control.

required
pre_warm bool

Trigger compilation of any JitWrapper-wrapped inner envs immediately. Default is False — call multi_env.compile() later.

False

Returns:

Name Type Description
multi_env MultiEnv

Manager holding all JaxEnv instances.

Source code in envrax/make.py (Python, lines 131–158):
def make_multi(
    envs: List[JaxEnv] | Dict[str, JaxEnv],
    *,
    pre_warm: bool = False,
) -> MultiEnv:
    """
    Wrap `JaxEnv` instances into a `MultiEnv`.

    Parameters
    ----------
    envs : List[JaxEnv] | Dict[str, JaxEnv]
        Envs to wrap, passed straight to `MultiEnv`. List form: keys derived
        from `env.name` with suffixes on duplicates. Dict form: keys used
        verbatim for explicit control.
    pre_warm : bool (optional)
        When `True`, call `multi_env.compile()` before returning, which
        triggers compilation of any `JitWrapper`-wrapped inner envs.
        Default is `False`; call `multi_env.compile()` later yourself.

    Returns
    -------
    multi_env : MultiEnv
        Manager holding all `JaxEnv` instances.
    """
    manager = MultiEnv(envs)
    if not pre_warm:
        return manager

    # Eager compilation was requested, so do it now instead of leaving it
    # to the caller.
    manager.compile()
    return manager

envrax.make.make_multi_vec(envs, *, jit_compile=True, pre_warm=False, cache_dir=DEFAULT_CACHE_DIR)

Wrap BatchedEnv instances into a MultiVecEnv.

Parameters:

Name Type Description Default
envs List[BatchedEnv] | Dict[str, BatchedEnv]

Envs to wrap. List form: keys derived from env.name with suffixes on duplicates. Dict form: keys used verbatim for explicit control.

required
jit_compile bool

Enable the persistent XLA compilation cache; pre_warm only takes effect when this is True. Default is True.

True
pre_warm bool

Compile each inner env and the multi-step jit immediately. Default is False — call multi_vec_env.compile() later.

False
cache_dir Path | str | None

Directory for the persistent XLA compilation cache. Only used when jit_compile=True.

DEFAULT_CACHE_DIR

Returns:

Name Type Description
multi_vec_env MultiVecEnv

Manager holding all BatchedEnv instances.

Source code in envrax/make.py (Python, lines 161–197):
def make_multi_vec(
    envs: List[BatchedEnv] | Dict[str, BatchedEnv],
    *,
    jit_compile: bool = True,
    pre_warm: bool = False,
    cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
) -> MultiVecEnv:
    """
    Wrap `BatchedEnv` instances into a `MultiVecEnv`.

    Parameters
    ----------
    envs : List[BatchedEnv] | Dict[str, BatchedEnv]
        Envs to wrap, passed straight to `MultiVecEnv`. List form: keys
        derived from `env.name` with suffixes on duplicates. Dict form: keys
        used verbatim for explicit control.
    jit_compile : bool (optional)
        When `True`, call `setup_cache(cache_dir)` to enable the persistent
        XLA compilation cache, and allow `pre_warm` to take effect.
        Default is `True`.
    pre_warm : bool (optional)
        When `jit_compile=True`, call `multi_vec_env.compile()` immediately,
        compiling each inner env and the multi-step jit. Default is `False`;
        call `multi_vec_env.compile()` later yourself.
    cache_dir : Path | str | None (optional)
        Directory for the persistent XLA compilation cache. Forwarded to
        `setup_cache` and `MultiVecEnv.compile`; only used when
        `jit_compile=True`.

    Returns
    -------
    multi_vec_env : MultiVecEnv
        Manager holding all `BatchedEnv` instances.
    """
    # Configure the persistent compilation cache before constructing the
    # manager.
    if jit_compile:
        setup_cache(cache_dir)

    manager = MultiVecEnv(envs)

    # Eager compilation needs both the cache opt-in and an explicit request.
    if not (jit_compile and pre_warm):
        return manager

    manager.compile(cache_dir=cache_dir)
    return manager