Skip to content

Factory Functions

The canonical entry points for instantiating registered environments as single, vectorised, or multi-environment objects.

We recommend using these factories instead of instantiating environment classes directly whenever possible.

envrax.make.make(name, *, config=None, wrappers=None, jit_compile=True, pre_warm=True, cache_dir=DEFAULT_CACHE_DIR)

Create a single JaxEnv, optionally with wrappers applied.

Parameters:

Name Type Description Default
name str

Registered environment name

required
config EnvConfig

Environment configuration. Defaults to the registered default config.

None
wrappers List[WrapperType]

Wrapper classes or pre-configured factories applied innermost-first around the base env.

None
jit_compile bool

Wrap the env in JitWrapper. Default is True.

True
pre_warm bool

When jit_compile=True, run a dummy reset + step immediately to trigger XLA compilation. Set to False to defer compilation to the first real call or an explicit compile(). Default is True.

True
cache_dir Path | str | None

Directory for the persistent XLA compilation cache, forwarded to JitWrapper. Defaults to ~/.cache/envrax/xla_cache. Pass None to disable. Ignored when jit_compile=False.

DEFAULT_CACHE_DIR

Returns:

Name Type Description
env JaxEnv

Configured environment, wrapped in JitWrapper when jit_compile=True.

Raises:

Name Type Description
unknown_env ValueError

If name is not registered.

Source code in envrax/make.py
Python
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def make(
    name: str,
    *,
    config: EnvConfig | None = None,
    wrappers: List[WrapperType] | None = None,
    jit_compile: bool = True,
    pre_warm: bool = True,
    cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
) -> JaxEnv:
    """
    Build one registered `JaxEnv`, optionally wrapped and JIT-compiled.

    Parameters
    ----------
    name : str
        Name the environment was registered under.
    config : EnvConfig (optional)
        Configuration passed to the environment class. When `None`, the
        registered spec's `default_config` is used.
    wrappers : List[WrapperType] (optional)
        Wrapper classes or pre-configured factories, applied in list order
        to the env built so far; the first entry therefore ends up innermost.
    jit_compile : bool (optional)
        If `True` (default), the returned env is wrapped in `JitWrapper`.
    pre_warm : bool (optional)
        Forwarded to `JitWrapper`. When `True` (default), a dummy `reset` +
        `step` triggers XLA compilation immediately; otherwise compilation
        is deferred to the first real call or an explicit `compile()`.
        Ignored when `jit_compile=False`.
    cache_dir : Path | str | None (optional)
        Persistent XLA compilation cache directory forwarded to `JitWrapper`.
        Defaults to `~/.cache/envrax/xla_cache`; pass `None` to disable.
        Ignored when `jit_compile=False`.

    Returns
    -------
    env : JaxEnv
        The constructed environment; a `JitWrapper` when `jit_compile=True`.

    Raises
    ------
    unknown_env : ValueError
        If `name` is not registered; the message lists the registered names.
    """
    # Guard clause: fail early and show the caller what *is* registered.
    if name not in _REGISTRY:
        available = sorted(_REGISTRY.keys())
        raise ValueError(f"Unknown environment: {name!r}. Available: {available}")

    spec = _REGISTRY[name]
    env_config = spec.default_config if config is None else config
    env: JaxEnv = spec.env_class(config=env_config)

    # `None` and an empty list both mean "no wrappers".
    for wrap in wrappers or []:
        env = wrap(env)

    if not jit_compile:
        return env
    return JitWrapper(env, cache_dir=cache_dir, pre_warm=pre_warm)

envrax.make.make_vec(name, n_envs, *, config=None, wrappers=None, jit_compile=True, pre_warm=True, cache_dir=DEFAULT_CACHE_DIR)

Create a VecEnv with n_envs parallel environments.

Parameters:

Name Type Description Default
name str

Registered environment name

required
n_envs int

Number of parallel environments

required
config EnvConfig

Environment configuration. Defaults to the registered default config.

None
wrappers List[WrapperType]

Wrapper classes applied innermost-first. Applied before vectorisation.

None
jit_compile bool

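Together with pre_warm, compile the VecEnv before returning (using the persistent cache in cache_dir). Compilation only happens when both jit_compile and pre_warm are True. Default is True.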

True
pre_warm bool

When jit_compile=True, run a dummy reset + step immediately. Set to False to defer to an explicit vec_env.compile() call. Default is True.

True
cache_dir Path | str | None

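Directory for the persistent XLA compilation cache, passed to vec_env.compile(). Only used when jit_compile and pre_warm are both True; if compilation is deferred, pass the directory to your own compile() call.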

DEFAULT_CACHE_DIR

Returns:

Name Type Description
vec_env VecEnv

Vectorised environment.

Source code in envrax/make.py
Python
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def make_vec(
    name: str,
    n_envs: int,
    *,
    config: EnvConfig | None = None,
    wrappers: List[WrapperType] | None = None,
    jit_compile: bool = True,
    pre_warm: bool = True,
    cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
) -> VecEnv:
    """
    Create a `VecEnv` with `n_envs` parallel environments.

    The single environment is built with `make(..., jit_compile=False)`, so
    no `JitWrapper` is added to it. It is then handed to `VecEnv`.

    Parameters
    ----------
    name : str
        Registered environment name
    n_envs : int
        Number of parallel environments; must be at least 1.
    config : EnvConfig (optional)
        Environment configuration. Defaults to the registered default config.
    wrappers : List[WrapperType] (optional)
        Wrapper classes applied innermost-first. Applied before vectorisation.
    jit_compile : bool (optional)
        Together with `pre_warm`, decides whether `vec_env.compile()` is
        called before returning. Compilation only happens when both are
        `True`. Default is `True`.
    pre_warm : bool (optional)
        When `jit_compile=True`, call `vec_env.compile(cache_dir=cache_dir)`
        immediately. Set to `False` to defer to an explicit
        `vec_env.compile()` call. Default is `True`.
    cache_dir : Path | str | None (optional)
        Directory for the persistent XLA compilation cache, passed to
        `vec_env.compile()`. Only used when `jit_compile` and `pre_warm` are
        both `True`. If compilation is deferred, pass the directory to your
        own `compile(cache_dir=...)` call instead.

    Returns
    -------
    vec_env : VecEnv
        Vectorised environment.

    Raises
    ------
    invalid_n_envs : ValueError
        If `n_envs` is less than 1.
    unknown_env : ValueError
        If `name` is not registered (raised by `make`).
    """
    # Validate before building anything, so a bad batch size fails fast
    # instead of surfacing later from inside the vectorised env.
    if n_envs < 1:
        raise ValueError(f"n_envs must be a positive integer, got {n_envs!r}")

    # JIT is deliberately not applied to the single env here. Compilation
    # (if any) is done at the VecEnv level below, and the inner env's
    # cache_dir is therefore irrelevant.
    inner_env = make(
        name,
        config=config,
        wrappers=wrappers,
        jit_compile=False,
        cache_dir=None,
    )

    vec_env = VecEnv(inner_env, n_envs)

    if jit_compile and pre_warm:
        vec_env.compile(cache_dir=cache_dir)

    return vec_env

envrax.make.make_multi(names, *, wrappers=None, jit_compile=True, pre_warm=False, cache_dir=DEFAULT_CACHE_DIR)

Create a MultiEnv managing M heterogeneous environments.

Each environment is constructed with its registered default config. For per-environment config overrides, register the variants ahead of time or compose manually with MultiEnv([make(name, config=...), ...]).

By default, pre_warm=False so environments are JIT-wrapped but not compiled immediately. Call multi_env.compile() to trigger compilation as a separate setup phase.

Parameters:

Name Type Description Default
names List[str]

Registered environment names

required
wrappers List[WrapperType]

Wrapper pipeline applied to every environment. Must be compatible with the observation and action spaces of every environment used.

None
jit_compile bool

Wrap each environment in JitWrapper. Default is True.

True
pre_warm bool

When jit_compile=True, compile each environment immediately on creation. Default is False — call multi_env.compile() later instead.

False
cache_dir Path | str | None

Directory for the persistent XLA compilation cache, forwarded to each environment's JitWrapper. Ignored when jit_compile=False.

DEFAULT_CACHE_DIR

Returns:

Name Type Description
multi_env MultiEnv

Manager holding all M environments

Source code in envrax/make.py
Python
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def make_multi(
    names: List[str],
    *,
    wrappers: List[WrapperType] | None = None,
    jit_compile: bool = True,
    pre_warm: bool = False,
    cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
) -> MultiEnv:
    """
    Build a `MultiEnv` over several (possibly different) registered environments.

    Every name is passed through `make` with no `config`, so each environment
    uses its registered default configuration. To override configs per
    environment, either register the variants up front or assemble the
    manager yourself: `MultiEnv([make(name, config=...), ...])`.

    `pre_warm` defaults to `False` here. Environments are therefore JIT-wrapped
    but not compiled on creation; call `multi_env.compile()` as a separate
    setup step.

    Parameters
    ----------
    names : List[str]
        Registered environment names, in the order they appear in the manager.
    wrappers : List[WrapperType] (optional)
        The same wrapper pipeline, applied to every environment. It must fit
        the observation and action spaces of all of them.
    jit_compile : bool (optional)
        Wrap each environment in `JitWrapper`. Default is `True`.
    pre_warm : bool (optional)
        With `jit_compile=True`, compile each environment as it is created.
        Default is `False`; call `multi_env.compile()` later instead.
    cache_dir : Path | str | None (optional)
        Persistent XLA compilation cache directory forwarded to each
        `JitWrapper`. Ignored when `jit_compile=False`.

    Returns
    -------
    multi_env : MultiEnv
        Manager holding all M environments

    Raises
    ------
    unknown_env : ValueError
        If any entry of `names` is not registered (raised by `make`).
    """
    # One `make` call per name; all of them share the wrapper/JIT settings.
    return MultiEnv(
        [
            make(
                env_name,
                wrappers=wrappers,
                jit_compile=jit_compile,
                pre_warm=pre_warm,
                cache_dir=cache_dir,
            )
            for env_name in names
        ]
    )

envrax.make.make_multi_vec(names, n_envs, *, wrappers=None, jit_compile=True, pre_warm=False, cache_dir=DEFAULT_CACHE_DIR)

Create a MultiVecEnv managing M heterogeneous vectorised environments.

Each environment is constructed with its registered default config. For per-environment config overrides, register the variants ahead of time or compose manually with MultiVecEnv([VecEnv(make(name, config=...), n), ...]).

By default, pre_warm=False so VecEnv instances are created but not compiled immediately. Call multi_vec_env.compile() to trigger compilation as a separate setup phase.

Parameters:

Name Type Description Default
names List[str]

Registered environment names

required
n_envs int

Number of parallel copies per environment

required
wrappers List[WrapperType]

Wrapper pipeline applied to every inner environment before vectorisation. Must be compatible with the observation and action spaces of every environment used.

None
jit_compile bool

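Together with pre_warm, compile each VecEnv before returning (using the persistent cache in cache_dir). Compilation only happens when both jit_compile and pre_warm are True. Default is True.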

True
pre_warm bool

When jit_compile=True, compile each VecEnv immediately. Default is False — call multi_vec_env.compile() later instead.

False
cache_dir Path | str | None

Directory for the persistent XLA compilation cache, passed to each VecEnv's compile(). Only used when jit_compile and pre_warm are both True.

DEFAULT_CACHE_DIR

Returns:

Name Type Description
multi_vec_env MultiVecEnv

Manager holding all M vectorised environments

Source code in envrax/make.py
Python
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def make_multi_vec(
    names: List[str],
    n_envs: int,
    *,
    wrappers: List[WrapperType] | None = None,
    jit_compile: bool = True,
    pre_warm: bool = False,
    cache_dir: Path | str | None = DEFAULT_CACHE_DIR,
) -> MultiVecEnv:
    """
    Create a `MultiVecEnv` managing M heterogeneous vectorised environments.

    Each `VecEnv` is built by `make_vec`, so a single entry here behaves
    exactly like a direct `make_vec(name, n_envs, ...)` call. Each environment
    is constructed with its registered default config. For per-environment
    config overrides, register the variants ahead of time or compose
    manually with `MultiVecEnv([VecEnv(make(name, config=...), n), ...])`.

    By default, `pre_warm=False` so VecEnv instances are created but not
    compiled immediately. Call `multi_vec_env.compile()` to trigger
    compilation as a separate setup phase.

    Parameters
    ----------
    names : List[str]
        Registered environment names
    n_envs : int
        Number of parallel copies per environment
    wrappers : List[WrapperType] (optional)
        Wrapper pipeline applied to every inner environment before vectorisation.
        Must be compatible with the observation and action spaces of every
        environment used.
    jit_compile : bool (optional)
        Together with `pre_warm`, decides whether each `VecEnv` is compiled
        before returning. Compilation only happens when both are `True`.
        Default is `True`.
    pre_warm : bool (optional)
        When `jit_compile=True`, compile each VecEnv immediately.
        Default is `False`; call `multi_vec_env.compile()` later instead.
    cache_dir : Path | str | None (optional)
        Directory for the persistent XLA compilation cache, passed to each
        `VecEnv.compile()`. Only used when `jit_compile` and `pre_warm` are
        both `True`.

    Returns
    -------
    multi_vec_env : MultiVecEnv
        Manager holding all M vectorised environments

    Raises
    ------
    unknown_env : ValueError
        If any entry of `names` is not registered.
    """
    # Delegate to `make_vec` instead of duplicating its build/vectorise/compile
    # steps, so the two factories cannot drift apart. Each VecEnv is still
    # built (and, if requested, compiled) in `names` order.
    vec_envs = [
        make_vec(
            name,
            n_envs,
            wrappers=wrappers,
            jit_compile=jit_compile,
            pre_warm=pre_warm,
            cache_dir=cache_dir,
        )
        for name in names
    ]

    return MultiVecEnv(vec_envs)