Suite

Types for grouping related environments into reusable, versioned collections.

envrax.suite.EnvSpec dataclass

Specification for a single environment — the unit of registration.

Holds everything needed to instantiate a registered environment: its name, the class to construct, and the default config to pass. Used both as a definition-time artifact (inside EnvSuite.specs) and as the runtime value stored in the registry.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Short name within an EnvSuite (e.g. "cartpole") at definition time, or canonical ID (e.g. "mjx/cartpole-v0") once registered. | required |
| env_class | Type[JaxEnv] | Environment class to instantiate. | required |
| default_config | EnvConfig | Default configuration passed to env_class(config=...). | required |
| suite | str | Suite category tag (e.g. "MuJoCo Playground"). Populated by register_suite from the parent EnvSuite.category. | "" |

Source code in envrax/suite.py
@dataclass(frozen=True)
class EnvSpec:
    """
    Specification for a single environment — the unit of registration.

    Holds everything needed to instantiate a registered environment: its
    name, the class to construct, and the default config to pass. Used both
    as a definition-time artifact (inside `EnvSuite.specs`) and as the
    runtime value stored in the registry.

    Parameters
    ----------
    name : str
        Short name within an `EnvSuite` (e.g. `"cartpole"`) at definition
        time, or canonical ID (e.g. `"mjx/cartpole-v0"`) once registered.
    env_class : Type[JaxEnv]
        Environment class to instantiate.
    default_config : EnvConfig
        Default configuration passed to `env_class(config=...)`.
    suite : str
        Suite category tag (e.g. `"MuJoCo Playground"`). Populated by
        `register_suite` from the parent `EnvSuite.category`.
    """

    name: str
    env_class: Type[JaxEnv]
    default_config: EnvConfig
    suite: str = ""
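
Because EnvSpec is declared with frozen=True, a spec cannot be changed in place. Turning a definition-time spec (short name, empty suite tag) into a registry value therefore means building a new instance. The sketch below shows one way to do that with dataclasses.replace. It is self-contained: JaxEnv, EnvConfig and CartPole are hypothetical stand-ins. Whether register_suite actually uses replace is an assumption, because its source is not shown on this page.

```python
from dataclasses import FrozenInstanceError, dataclass, replace
from typing import Type


class EnvConfig:  # hypothetical stand-in for envrax's EnvConfig
    pass


class JaxEnv:  # hypothetical stand-in for envrax's JaxEnv
    def __init__(self, config: EnvConfig) -> None:
        self.config = config


@dataclass(frozen=True)
class EnvSpec:  # same fields as the source above
    name: str
    env_class: Type[JaxEnv]
    default_config: EnvConfig
    suite: str = ""


class CartPole(JaxEnv):  # hypothetical environment class
    pass


# Definition time: short name, suite tag not yet populated
short = EnvSpec(name="cartpole", env_class=CartPole, default_config=EnvConfig())

# frozen=True: assigning to a field raises FrozenInstanceError
try:
    short.name = "other"
except FrozenInstanceError:
    print("EnvSpec is immutable")

# Assumed registration step: derive a new spec instead of mutating the old one
registered = replace(short, name="mjx/cartpole-v0", suite="MuJoCo Playground")

# Instantiate exactly as the docs describe: env_class(config=default_config)
env = registered.env_class(config=registered.default_config)
```

The original spec is left untouched. Both specs still share the same default_config object, because replace copies field values rather than deep-copying them.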

envrax.suite.EnvSuite dataclass

A named, versioned collection of environment specs from one suite.

The specs list is the single source of truth for which environments the suite ships. The envs property derives short names from specs so that iteration, slicing, and display work without a parallel list.

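Subclasses pin prefix, category, version, and required_packages, and provide their specs via default_factory. The inherited get_name already builds canonical IDs of the form "{prefix}/{name}-{version}" (e.g. "mjx/cartpole-v0"). Subclasses override it only when they need a different naming scheme.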

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prefix | str | Namespace prefix for environment names (e.g. "mjx"). | "" |
| category | str | Human-readable category label (e.g. "MuJoCo Playground"). | "" |
| version | str | Version suffix applied by get_name (e.g. "v0"). | "v0" |
| required_packages | List[str] | Python packages that must be importable for this suite to work. | [] |
| specs | List[EnvSpec] | Environment specifications shipped by this suite. | [] |

Source code in envrax/suite.py
@dataclass
class EnvSuite:
    """
    A named, versioned collection of environment specs from one suite.

    The `specs` list is the single source of truth for which environments
    the suite ships. The `envs` property derives short names from `specs`
    so that iteration, slicing, and display work without a parallel list.

    Subclasses pin `prefix`, `category`, `version`, `required_packages`
    and provide their `specs` via `default_factory`. They must also
    override `get_name` to produce canonical IDs (e.g. `"mjx/cartpole-v0"`).

    Parameters
    ----------
    prefix : str
        Namespace prefix for environment names (e.g. `"mjx"`).
    category : str
        Human-readable category label (e.g. `"MuJoCo Playground"`).
    version : str (optional)
        Version suffix applied by `get_name` (e.g. `"v0"`). Default is `v0`
    required_packages : List[str]
        Python packages that must be importable for this suite to work.
    specs : List[EnvSpec]
        Environment specifications shipped by this suite.
    """

    prefix: str = ""
    category: str = ""
    version: str = "v0"
    required_packages: List[str] = field(default_factory=list)
    specs: List[EnvSpec] = field(default_factory=list)

    @property
    def envs(self) -> List[str]:
        """Short names of all environments in this suite, derived from specs."""
        return [s.name for s in self.specs]

    @property
    def n_envs(self) -> int:
        """Number of environments in this suite."""
        return len(self.specs)

    def get_name(self, name: str, version: str | None = None) -> str:
        """
        Return the canonical ID string for a single environment.

        Parameters
        ----------
        name : str
            Short environment name as stored on a spec.
        version : str (optional)
            Override the suite's default version suffix.

        Returns
        -------
        canonical : str
            Full environment ID (e.g. `"mjx/cartpole-v0"`).
        """
        return f"{self.prefix}/{name}-{version or self.version}"

    def all_names(self, version: str | None = None) -> List[str]:
        """
        Return canonical IDs for every environment in this suite.

        Parameters
        ----------
        version : str (optional)
            Override the suite's default version suffix.

        Returns
        -------
        names : List[str]
            One canonical ID per spec.
        """
        return [self.get_name(s.name, version) for s in self.specs]

    def __contains__(self, name: str) -> bool:
        """Return `True` if a short name is in this suite's specs."""
        return any(s.name == name for s in self.specs)

    def __getitem__(self, key: Union[int, slice]) -> "EnvSuite":
        """
        Return a new suite containing only the selected spec(s).

        Parameters
        ----------
        key : int | slice
            Index or slice into `self.specs`.

        Returns
        -------
        suite : EnvSuite
            Same class as `self`, with the subset of specs.
        """
        if isinstance(key, int):
            selected = [self.specs[key]]
        else:
            selected = self.specs[key]

        return self.__class__(
            prefix=self.prefix,
            category=self.category,
            version=self.version,
            required_packages=self.required_packages,
            specs=selected,
        )

    def __iter__(self) -> Iterator[str]:
        """Yield canonical ID strings for all environments."""
        for spec in self.specs:
            yield self.get_name(spec.name)

    def __len__(self) -> int:
        """Return number of environments."""
        return len(self.specs)

    def check(self) -> Dict[str, bool]:
        """
        Check whether each required package is importable.

        Returns
        -------
        status : Dict[str, bool]
            Mapping of package name → installed flag.
        """
        return {pkg: find_spec(pkg) is not None for pkg in self.required_packages}

    def is_available(self) -> bool:
        """
        Return `True` if all required packages are installed.

        Returns
        -------
        available : bool
            If required packages are installed.
        """
        return all(self.check().values())
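
Below is a minimal sketch of the subclass pattern described above, assuming a suite of two MuJoCo Playground tasks. The EnvSuite here is a trimmed stand-in that copies only the fields plus the envs, n_envs and get_name members from the source. MjxSuite, the spec names and the "mujoco" package requirement are illustrative. EnvSpec is reduced to its name field.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List


@dataclass(frozen=True)
class EnvSpec:  # trimmed stand-in: only the name matters here
    name: str


@dataclass
class EnvSuite:  # trimmed copy of the fields and members shown in the source
    prefix: str = ""
    category: str = ""
    version: str = "v0"
    required_packages: List[str] = field(default_factory=list)
    specs: List[EnvSpec] = field(default_factory=list)

    @property
    def envs(self) -> List[str]:
        return [s.name for s in self.specs]

    @property
    def n_envs(self) -> int:
        return len(self.specs)

    def get_name(self, name: str, version: str | None = None) -> str:
        return f"{self.prefix}/{name}-{version or self.version}"


@dataclass
class MjxSuite(EnvSuite):  # hypothetical subclass pinning its fields
    prefix: str = "mjx"
    category: str = "MuJoCo Playground"
    required_packages: List[str] = field(default_factory=lambda: ["mujoco"])
    # default_factory gives every instance its own fresh list of specs
    specs: List[EnvSpec] = field(
        default_factory=lambda: [EnvSpec("cartpole"), EnvSpec("acrobot")]
    )


suite = MjxSuite()
print(suite.envs)                  # ['cartpole', 'acrobot']
print(suite.n_envs)                # 2
print(suite.get_name("cartpole"))  # mjx/cartpole-v0
```

Re-declaring a field in a dataclass subclass replaces its default but keeps the base class's field order. MjxSuite therefore still accepts the same keyword arguments as EnvSuite.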

envs property

Short names of all environments in this suite, derived from specs.

n_envs property

Number of environments in this suite.

get_name(name, version=None)

Return the canonical ID string for a single environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Short environment name as stored on a spec. | required |
| version | str | Override the suite's default version suffix. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| canonical | str | Full environment ID (e.g. "mjx/cartpole-v0"). |

Source code in envrax/suite.py
def get_name(self, name: str, version: str | None = None) -> str:
    """
    Return the canonical ID string for a single environment.

    Parameters
    ----------
    name : str
        Short environment name as stored on a spec.
    version : str (optional)
        Override the suite's default version suffix.

    Returns
    -------
    canonical : str
        Full environment ID (e.g. `"mjx/cartpole-v0"`).
    """
    return f"{self.prefix}/{name}-{version or self.version}"
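
get_name uses version or self.version, so any falsy override falls back to the suite default. That covers both None and the empty string. The standalone sketch below isolates that expression, with prefix "mjx" and default version "v0" assumed for illustration.

```python
from __future__ import annotations


def get_name(prefix: str, default_version: str, name: str, version: str | None = None) -> str:
    # Same f-string as EnvSuite.get_name, with the suite fields passed in explicitly
    return f"{prefix}/{name}-{version or default_version}"


print(get_name("mjx", "v0", "cartpole"))        # mjx/cartpole-v0
print(get_name("mjx", "v0", "cartpole", "v1"))  # mjx/cartpole-v1
print(get_name("mjx", "v0", "cartpole", ""))    # mjx/cartpole-v0 ("" is falsy)
```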

all_names(version=None)

Return canonical IDs for every environment in this suite.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| version | str | Override the suite's default version suffix. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| names | List[str] | One canonical ID per spec. |

Source code in envrax/suite.py
def all_names(self, version: str | None = None) -> List[str]:
    """
    Return canonical IDs for every environment in this suite.

    Parameters
    ----------
    version : str (optional)
        Override the suite's default version suffix.

    Returns
    -------
    names : List[str]
        One canonical ID per spec.
    """
    return [self.get_name(s.name, version) for s in self.specs]
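
all_names maps get_name over specs, so a single version override applies to every environment in the suite. The sketch uses a trimmed stand-in suite. The "mjx" prefix and the spec names are illustrative.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List


@dataclass(frozen=True)
class EnvSpec:  # trimmed stand-in
    name: str


@dataclass
class EnvSuite:  # trimmed stand-in with get_name/all_names as in the source
    prefix: str = ""
    version: str = "v0"
    specs: List[EnvSpec] = field(default_factory=list)

    def get_name(self, name: str, version: str | None = None) -> str:
        return f"{self.prefix}/{name}-{version or self.version}"

    def all_names(self, version: str | None = None) -> List[str]:
        # One canonical ID per spec, all sharing the same (optional) override
        return [self.get_name(s.name, version) for s in self.specs]


suite = EnvSuite(prefix="mjx", specs=[EnvSpec("cartpole"), EnvSpec("acrobot")])
print(suite.all_names())      # ['mjx/cartpole-v0', 'mjx/acrobot-v0']
print(suite.all_names("v1"))  # ['mjx/cartpole-v1', 'mjx/acrobot-v1']
```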

__contains__(name)

Return True if a short name is in this suite's specs.

Source code in envrax/suite.py
def __contains__(self, name: str) -> bool:
    """Return `True` if a short name is in this suite's specs."""
    return any(s.name == name for s in self.specs)
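
Membership and iteration use different name forms. The in operator matches the short names stored on the specs, while iterating yields canonical IDs. The sketch below uses a trimmed stand-in with illustrative names to make the asymmetry explicit: a canonical ID is not found by in.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass(frozen=True)
class EnvSpec:  # trimmed stand-in
    name: str


@dataclass
class EnvSuite:  # trimmed stand-in: get_name, __contains__, __iter__ as in the source
    prefix: str = ""
    version: str = "v0"
    specs: List[EnvSpec] = field(default_factory=list)

    def get_name(self, name: str, version: str | None = None) -> str:
        return f"{self.prefix}/{name}-{version or self.version}"

    def __contains__(self, name: str) -> bool:
        # Compares against short names only
        return any(s.name == name for s in self.specs)

    def __iter__(self) -> Iterator[str]:
        # Yields canonical IDs
        for spec in self.specs:
            yield self.get_name(spec.name)


suite = EnvSuite(prefix="mjx", specs=[EnvSpec("cartpole")])
print("cartpole" in suite)               # True  (short name)
print("mjx/cartpole-v0" in suite)        # False (canonical ID)
print("mjx/cartpole-v0" in list(suite))  # True  (search the yielded IDs instead)
```

To test a canonical ID, search list(suite) or suite.all_names() rather than the suite itself.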

__getitem__(key)

Return a new suite containing only the selected spec(s).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | int \| slice | Index or slice into self.specs. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| suite | EnvSuite | Same class as self, with the subset of specs. |

Source code in envrax/suite.py
def __getitem__(self, key: Union[int, slice]) -> "EnvSuite":
    """
    Return a new suite containing only the selected spec(s).

    Parameters
    ----------
    key : int | slice
        Index or slice into `self.specs`.

    Returns
    -------
    suite : EnvSuite
        Same class as `self`, with the subset of specs.
    """
    if isinstance(key, int):
        selected = [self.specs[key]]
    else:
        selected = self.specs[key]

    return self.__class__(
        prefix=self.prefix,
        category=self.category,
        version=self.version,
        required_packages=self.required_packages,
        specs=selected,
    )
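
Indexing always returns a suite, never a bare spec. An integer key is wrapped in a one-element list, and the result is built with self.__class__, so a subclass is preserved by slicing. The source also shows that required_packages is passed by reference, so the subset shares the parent's list object. The sketch below shows all three points using a trimmed stand-in and a hypothetical MjxSuite subclass with illustrative names.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Union


@dataclass(frozen=True)
class EnvSpec:  # trimmed stand-in
    name: str


@dataclass
class EnvSuite:  # trimmed stand-in with __getitem__ as in the source
    prefix: str = ""
    category: str = ""
    version: str = "v0"
    required_packages: List[str] = field(default_factory=list)
    specs: List[EnvSpec] = field(default_factory=list)

    def __getitem__(self, key: Union[int, slice]) -> EnvSuite:
        # An int selects one spec, which is wrapped in a list
        if isinstance(key, int):
            selected = [self.specs[key]]
        else:
            selected = self.specs[key]
        # self.__class__ keeps the subclass; required_packages is not copied
        return self.__class__(
            prefix=self.prefix,
            category=self.category,
            version=self.version,
            required_packages=self.required_packages,
            specs=selected,
        )


@dataclass
class MjxSuite(EnvSuite):  # hypothetical subclass
    prefix: str = "mjx"
    required_packages: List[str] = field(default_factory=lambda: ["mujoco"])
    specs: List[EnvSpec] = field(
        default_factory=lambda: [EnvSpec(n) for n in ("cartpole", "acrobot", "hopper")]
    )


suite = MjxSuite()
first = suite[0]
tail = suite[1:]
print(type(first).__name__, [s.name for s in first.specs])  # MjxSuite ['cartpole']
print([s.name for s in tail.specs])                          # ['acrobot', 'hopper']
print(tail.required_packages is suite.required_packages)     # True (shared list)
```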

__iter__()

Yield canonical ID strings for all environments.

Source code in envrax/suite.py
def __iter__(self) -> Iterator[str]:
    """Yield canonical ID strings for all environments."""
    for spec in self.specs:
        yield self.get_name(spec.name)

__len__()

Return number of environments.

Source code in envrax/suite.py
def __len__(self) -> int:
    """Return number of environments."""
    return len(self.specs)
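
With __iter__ and __len__ defined, a suite can be used wherever a sized iterable of canonical IDs is expected, for example to build an ID-to-index mapping. __iter__ calls get_name without a version, so iteration always uses the suite's default version; all_names(version) is the way to apply an override. The sketch uses a trimmed stand-in with illustrative names.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass(frozen=True)
class EnvSpec:  # trimmed stand-in
    name: str


@dataclass
class EnvSuite:  # trimmed stand-in: get_name, __iter__, __len__ as in the source
    prefix: str = ""
    version: str = "v0"
    specs: List[EnvSpec] = field(default_factory=list)

    def get_name(self, name: str, version: str | None = None) -> str:
        return f"{self.prefix}/{name}-{version or self.version}"

    def __iter__(self) -> Iterator[str]:
        for spec in self.specs:
            yield self.get_name(spec.name)

    def __len__(self) -> int:
        return len(self.specs)


suite = EnvSuite(prefix="mjx", specs=[EnvSpec("cartpole"), EnvSpec("acrobot")])
task_index = {env_id: i for i, env_id in enumerate(suite)}
print(len(suite))  # 2
print(task_index)  # {'mjx/cartpole-v0': 0, 'mjx/acrobot-v0': 1}
```

Python's truth testing falls back to __len__ when no __bool__ is defined. An EnvSuite with no specs is therefore falsy in an if statement.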

check()

Check whether each required package is importable.

Returns:

| Name | Type | Description |
|---|---|---|
| status | Dict[str, bool] | Mapping of package name → installed flag. |

Source code in envrax/suite.py
def check(self) -> Dict[str, bool]:
    """
    Check whether each required package is importable.

    Returns
    -------
    status : Dict[str, bool]
        Mapping of package name → installed flag.
    """
    return {pkg: find_spec(pkg) is not None for pkg in self.required_packages}
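
check relies on importlib.util.find_spec. For a top-level name, find_spec locates the module without executing it and returns None when nothing is found. For a dotted name, find_spec imports the parent package first. If that parent is missing, it raises ModuleNotFoundError instead of returning None. The standalone sketch below reproduces the same comprehension. "json" is a stdlib module, and "surely_not_installed_pkg" is a made-up name.

```python
from importlib.util import find_spec
from typing import Dict, List


def check(required_packages: List[str]) -> Dict[str, bool]:
    # Same expression as EnvSuite.check: None means "not importable"
    return {pkg: find_spec(pkg) is not None for pkg in required_packages}


status = check(["json", "surely_not_installed_pkg"])
print(status)  # {'json': True, 'surely_not_installed_pkg': False}
```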

is_available()

Return True if all required packages are installed.

Returns:

| Name | Type | Description |
|---|---|---|
| available | bool | True if all required packages are installed, False otherwise. |

Source code in envrax/suite.py
def is_available(self) -> bool:
    """
    Return `True` if all required packages are installed.

    Returns
    -------
    available : bool
        If required packages are installed.
    """
    return all(self.check().values())
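
is_available applies all() to the values returned by check. Because all() of an empty iterable is True, a suite with an empty required_packages list always counts as available. A standalone sketch of the same logic, using "json" as a stdlib module and a made-up missing package:

```python
from importlib.util import find_spec
from typing import List


def is_available(required_packages: List[str]) -> bool:
    # Mirrors EnvSuite.is_available: every package must be findable
    status = {pkg: find_spec(pkg) is not None for pkg in required_packages}
    return all(status.values())


print(is_available([]))                                    # True
print(is_available(["json"]))                              # True
print(is_available(["json", "surely_not_installed_pkg"]))  # False
```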

envrax.suite.EnvSet

An ordered collection of EnvSuite instances.

Combines multiple suites into a single iterable that yields canonical environment ID strings. Supports merging two EnvSet objects with +.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| *suites | EnvSuite | Variable number of environment suites to combine. | required |

Source code in envrax/suite.py
class EnvSet:
    """
    An ordered collection of `EnvSuite` instances.

    Combines multiple suites into a single iterable that yields canonical
    environment ID strings. Supports merging two `EnvSet` objects with `+`.

    Parameters
    ----------
    *suites : EnvSuite
        Variable number of environment suites to combine.
    """

    def __init__(self, *suites: EnvSuite) -> None:
        self._suites: List[EnvSuite] = list(suites)

    @property
    def n_envs(self) -> int:
        """Total number of environments across all suites."""
        return sum(s.n_envs for s in self._suites)

    @property
    def suites(self) -> List[EnvSuite]:
        """List of environment suites in this set."""
        return self._suites

    def all_names(self, version: str | None = None) -> List[str]:
        """
        Return canonical IDs for every environment across all suites.

        Parameters
        ----------
        version : str (optional)
            Override the default version suffix for all suites.

        Returns
        -------
        names : List[str]
            List of canonical IDs for every environment
        """
        names: List[str] = []
        for suite in self._suites:
            names.extend(suite.all_names(version))
        return names

    def env_categories(self) -> Dict[str, int]:
        """
        Return a mapping of category name → environment count.

        Returns
        -------
        categories : Dict[str, int]
            category name → environment count mapping
        """
        counts: Dict[str, int] = {}
        for s in self._suites:
            counts[s.category] = counts.get(s.category, 0) + s.n_envs
        return counts

    def __iter__(self) -> Iterator[str]:
        """Yield canonical ID strings from all suites in order."""
        for suite in self._suites:
            yield from suite

    def __len__(self) -> int:
        """Total number of environments."""
        return self.n_envs

    def __add__(self, other: Self) -> Self:
        """Merge two EnvSets into a new one."""
        return type(self)(*self._suites, *other._suites)

    @classmethod
    def from_names(cls, names: List[str]) -> Self:
        """
        Build an `EnvSet` from a list of registered canonical IDs.

        Names are looked up via the registry and grouped by their suite
        category tag (`EnvSpec.suite`). Used to reconstruct an `EnvSet`
        from persisted state — e.g. checkpoint metadata — without needing
        the original suite class hierarchy.

        Parameters
        ----------
        names : List[str]
            Registered canonical env IDs (e.g. `["mjx/cartpole-v0", ...]`).

        Returns
        -------
        env_set : EnvSet
            One suite per distinct category, each holding the matching specs.

        Raises
        ------
        unknown_env : ValueError
            Propagated from `get_spec` if any name is not registered.
        """
        from envrax.registry import get_spec  # local import to avoid cycles

        by_cat: Dict[str, List[EnvSpec]] = defaultdict(list)
        for name in names:
            spec = get_spec(name)
            by_cat[spec.suite].append(spec)

        suites = [
            _RegisteredSuite(category=category, specs=specs)
            for category, specs in by_cat.items()
        ]
        return cls(*suites)

    def verify_packages(self) -> None:
        """
        Verify that every suite has its required packages installed.

        Raises
        ------
        error : MissingPackageError
            If any suite has missing required packages.
        """
        missing: Dict[str, List[str]] = {}
        for suite in self._suites:
            status = suite.check()
            not_installed = [pkg for pkg, ok in status.items() if not ok]
            if not_installed:
                missing[suite.category] = not_installed

        if missing:
            lines = [f"  {cat}: {', '.join(pkgs)}" for cat, pkgs in missing.items()]
            raise MissingPackageError(
                "Missing required packages for environment suites:\n" + "\n".join(lines)
            )

    def __repr__(self) -> str:
        suite_info = ", ".join(
            f"{s.__class__.__name__}({s.n_envs})" for s in self._suites
        )
        return f"EnvSet({suite_info}, total={self.n_envs})"

n_envs property

Total number of environments across all suites.

suites property

List of environment suites in this set.

all_names(version=None)

Return canonical IDs for every environment across all suites.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| version | str | Override the default version suffix for all suites. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| names | List[str] | List of canonical IDs for every environment. |

Source code in envrax/suite.py
def all_names(self, version: str | None = None) -> List[str]:
    """
    Return canonical IDs for every environment across all suites.

    Parameters
    ----------
    version : str (optional)
        Override the default version suffix for all suites.

    Returns
    -------
    names : List[str]
        List of canonical IDs for every environment
    """
    names: List[str] = []
    for suite in self._suites:
        names.extend(suite.all_names(version))
    return names
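
EnvSet.all_names concatenates each suite's all_names in insertion order and forwards the same version override to every suite. Without an override, each suite keeps its own default version. The sketch uses trimmed stand-ins in which a suite holds short names instead of EnvSpec objects. The "mjx" and "brax" prefixes are illustrative.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List


@dataclass
class EnvSuite:  # trimmed stand-in: short names instead of EnvSpec objects
    prefix: str = ""
    version: str = "v0"
    names: List[str] = field(default_factory=list)

    def all_names(self, version: str | None = None) -> List[str]:
        return [f"{self.prefix}/{n}-{version or self.version}" for n in self.names]


class EnvSet:  # trimmed stand-in with the source's all_names
    def __init__(self, *suites: EnvSuite) -> None:
        self._suites: List[EnvSuite] = list(suites)

    def all_names(self, version: str | None = None) -> List[str]:
        names: List[str] = []
        for suite in self._suites:
            # The same override (or None) is passed to every suite
            names.extend(suite.all_names(version))
        return names


env_set = EnvSet(
    EnvSuite(prefix="mjx", names=["cartpole"]),
    EnvSuite(prefix="brax", version="v2", names=["ant"]),
)
print(env_set.all_names())      # ['mjx/cartpole-v0', 'brax/ant-v2']
print(env_set.all_names("v1"))  # ['mjx/cartpole-v1', 'brax/ant-v1']
```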

env_categories()

Return a mapping of category name → environment count.

Returns:

| Name | Type | Description |
|---|---|---|
| categories | Dict[str, int] | Mapping of category name → environment count. |

Source code in envrax/suite.py
def env_categories(self) -> Dict[str, int]:
    """
    Return a mapping of category name → environment count.

    Returns
    -------
    categories : Dict[str, int]
        category name → environment count mapping
    """
    counts: Dict[str, int] = {}
    for s in self._suites:
        counts[s.category] = counts.get(s.category, 0) + s.n_envs
    return counts
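
Counts are keyed by category, not by suite class, so two suites that share a category are summed into one entry. The sketch uses trimmed stand-ins (short names instead of specs). The categories and counts are illustrative.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class EnvSuite:  # trimmed stand-in: short names instead of EnvSpec objects
    category: str = ""
    names: List[str] = field(default_factory=list)

    @property
    def n_envs(self) -> int:
        return len(self.names)


class EnvSet:  # trimmed stand-in with the source's env_categories
    def __init__(self, *suites: EnvSuite) -> None:
        self._suites: List[EnvSuite] = list(suites)

    def env_categories(self) -> Dict[str, int]:
        counts: Dict[str, int] = {}
        for s in self._suites:
            # Suites sharing a category accumulate into one entry
            counts[s.category] = counts.get(s.category, 0) + s.n_envs
        return counts


env_set = EnvSet(
    EnvSuite("MuJoCo Playground", ["cartpole", "acrobot"]),
    EnvSuite("MuJoCo Playground", ["hopper"]),
    EnvSuite("Classic Control", ["pendulum"]),
)
print(env_set.env_categories())  # {'MuJoCo Playground': 3, 'Classic Control': 1}
```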

__iter__()

Yield canonical ID strings from all suites in order.

Source code in envrax/suite.py
def __iter__(self) -> Iterator[str]:
    """Yield canonical ID strings from all suites in order."""
    for suite in self._suites:
        yield from suite

__len__()

Total number of environments.

Source code in envrax/suite.py
def __len__(self) -> int:
    """Total number of environments."""
    return self.n_envs

__add__(other)

Merge two EnvSets into a new one.

Source code in envrax/suite.py
def __add__(self, other: Self) -> Self:
    """Merge two EnvSets into a new one."""
    return type(self)(*self._suites, *other._suites)
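
__add__ builds a new set from both suite lists, with the left operand's suites first. Neither operand is mutated, and using type(self) keeps an EnvSet subclass as that subclass. The right operand must itself be an EnvSet, because __add__ reads other._suites. A bare suite should be wrapped as EnvSet(suite) first. The sketch uses trimmed stand-ins; the names field and the prefixes are illustrative.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class EnvSuite:  # trimmed stand-in: short names instead of EnvSpec objects
    prefix: str = ""
    names: List[str] = field(default_factory=list)

    def __iter__(self) -> Iterator[str]:
        for name in self.names:
            yield f"{self.prefix}/{name}-v0"


class EnvSet:  # trimmed stand-in with the source's __iter__ and __add__
    def __init__(self, *suites: EnvSuite) -> None:
        self._suites: List[EnvSuite] = list(suites)

    def __iter__(self) -> Iterator[str]:
        for suite in self._suites:
            yield from suite

    def __add__(self, other: EnvSet) -> EnvSet:
        # A fresh instance of the left operand's type; operands are untouched
        return type(self)(*self._suites, *other._suites)


left = EnvSet(EnvSuite("mjx", ["cartpole"]))
right = EnvSet(EnvSuite("brax", ["ant", "humanoid"]))
merged = left + right
print(list(merged))  # ['mjx/cartpole-v0', 'brax/ant-v0', 'brax/humanoid-v0']
print(list(left))    # ['mjx/cartpole-v0'] (operands unchanged)
```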

from_names(names) classmethod

Build an EnvSet from a list of registered canonical IDs.

Names are looked up via the registry and grouped by their suite category tag (EnvSpec.suite). Used to reconstruct an EnvSet from persisted state — e.g. checkpoint metadata — without needing the original suite class hierarchy.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| names | List[str] | Registered canonical env IDs (e.g. ["mjx/cartpole-v0", ...]). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| env_set | EnvSet | One suite per distinct category, each holding the matching specs. |

Raises:

| Name | Type | Description |
|---|---|---|
| unknown_env | ValueError | Propagated from get_spec if any name is not registered. |

Source code in envrax/suite.py
@classmethod
def from_names(cls, names: List[str]) -> Self:
    """
    Build an `EnvSet` from a list of registered canonical IDs.

    Names are looked up via the registry and grouped by their suite
    category tag (`EnvSpec.suite`). Used to reconstruct an `EnvSet`
    from persisted state — e.g. checkpoint metadata — without needing
    the original suite class hierarchy.

    Parameters
    ----------
    names : List[str]
        Registered canonical env IDs (e.g. `["mjx/cartpole-v0", ...]`).

    Returns
    -------
    env_set : EnvSet
        One suite per distinct category, each holding the matching specs.

    Raises
    ------
    unknown_env : ValueError
        Propagated from `get_spec` if any name is not registered.
    """
    from envrax.registry import get_spec  # local import to avoid cycles

    by_cat: Dict[str, List[EnvSpec]] = defaultdict(list)
    for name in names:
        spec = get_spec(name)
        by_cat[spec.suite].append(spec)

    suites = [
        _RegisteredSuite(category=category, specs=specs)
        for category, specs in by_cat.items()
    ]
    return cls(*suites)
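
from_names depends on the registry's get_spec and on the private _RegisteredSuite, and neither is shown on this page. The sketch below therefore reproduces only the grouping step. A plain dict stands in for the registry, and specs are bucketed by their suite tag in first-seen order. The registry contents, the "gym/pendulum-v0" ID and this get_spec stand-in are hypothetical. Raising ValueError for an unknown name mirrors the documented behavior, not the real implementation.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class EnvSpec:  # trimmed stand-in (no env_class/default_config)
    name: str
    suite: str = ""


# Hypothetical registry contents, keyed by canonical ID
_REGISTRY: Dict[str, EnvSpec] = {
    "mjx/cartpole-v0": EnvSpec("mjx/cartpole-v0", "MuJoCo Playground"),
    "mjx/acrobot-v0": EnvSpec("mjx/acrobot-v0", "MuJoCo Playground"),
    "gym/pendulum-v0": EnvSpec("gym/pendulum-v0", "Classic Control"),
}


def get_spec(name: str) -> EnvSpec:  # hypothetical stand-in for envrax.registry.get_spec
    try:
        return _REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown environment: {name}") from None


def group_by_category(names: List[str]) -> Dict[str, List[str]]:
    # Same grouping as from_names: bucket specs by their suite tag
    by_cat: Dict[str, List[EnvSpec]] = defaultdict(list)
    for name in names:
        spec = get_spec(name)
        by_cat[spec.suite].append(spec)
    return {cat: [s.name for s in specs] for cat, specs in by_cat.items()}


groups = group_by_category(["mjx/cartpole-v0", "gym/pendulum-v0", "mjx/acrobot-v0"])
print(groups)
```

Because specs are grouped by category, an EnvSet rebuilt this way yields IDs grouped by category. When categories are interleaved in names, that order can differ from the input order: in the example, "mjx/acrobot-v0" moves ahead of "gym/pendulum-v0".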

verify_packages()

Verify that every suite has its required packages installed.

Raises:

| Name | Type | Description |
|---|---|---|
| error | MissingPackageError | If any suite has missing required packages. |

Source code in envrax/suite.py
def verify_packages(self) -> None:
    """
    Verify that every suite has its required packages installed.

    Raises
    ------
    error : MissingPackageError
        If any suite has missing required packages.
    """
    missing: Dict[str, List[str]] = {}
    for suite in self._suites:
        status = suite.check()
        not_installed = [pkg for pkg, ok in status.items() if not ok]
        if not_installed:
            missing[suite.category] = not_installed

    if missing:
        lines = [f"  {cat}: {', '.join(pkgs)}" for cat, pkgs in missing.items()]
        raise MissingPackageError(
            "Missing required packages for environment suites:\n" + "\n".join(lines)
        )
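
verify_packages collects every missing package before raising, so a single error reports all affected suites at once instead of stopping at the first. The sketch below is a trimmed copy of that logic. MissingPackageError is defined locally as a stand-in (its real base class is not shown on this page), and the suite contents are illustrative.

```python
from dataclasses import dataclass, field
from importlib.util import find_spec
from typing import Dict, List


class MissingPackageError(Exception):  # stand-in; real base class not shown
    pass


@dataclass
class EnvSuite:  # trimmed stand-in with category and check()
    category: str = ""
    required_packages: List[str] = field(default_factory=list)

    def check(self) -> Dict[str, bool]:
        return {pkg: find_spec(pkg) is not None for pkg in self.required_packages}


def verify_packages(suites: List[EnvSuite]) -> None:
    # Same steps as EnvSet.verify_packages: gather first, raise once
    missing: Dict[str, List[str]] = {}
    for suite in suites:
        not_installed = [pkg for pkg, ok in suite.check().items() if not ok]
        if not_installed:
            missing[suite.category] = not_installed
    if missing:
        lines = [f"  {cat}: {', '.join(pkgs)}" for cat, pkgs in missing.items()]
        raise MissingPackageError(
            "Missing required packages for environment suites:\n" + "\n".join(lines)
        )


suites = [
    EnvSuite("Stdlib Only", ["json"]),
    EnvSuite("Fake Suite", ["surely_not_installed_pkg", "another_fake_pkg"]),
]
message = ""
try:
    verify_packages(suites)
except MissingPackageError as err:
    message = str(err)
print(message)
```

The missing dict is keyed by category. As written, a later suite with the same category as an earlier one replaces that earlier entry, so only the later suite's missing packages appear in the message.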