Creating a Custom Space¶
In our Spaces tutorial we discussed how Envrax ships with three space types: Discrete, Box, and MultiDiscrete. These are the simplest variants and cover the most common RL environment use cases, but sometimes you need something more specialized.
Maybe you want a one-hot encoded categorical space, a bitstring, a truncated normal distribution, or a weighted distribution. None of those are obvious fits for the current built-in spaces. So, you'll need to build your own!
In this tutorial, we'll walk through how to do exactly that. Let's get into it!
Space Requirements¶
Every space must inherit from the envrax.spaces.Space base class and implement three methods:
| Method | Purpose | Returns |
|---|---|---|
| `sample(rng)` | Samples a random action from the space | `chex.Array` |
| `contains(x)` | Checks if `x` is a valid item in the space | `bool` |
| `batch(n)` | Returns a space with a leading batch dimension `n` | `Space` |
We also recommend making the custom space class a frozen dataclass with `@dataclasses.dataclass(frozen=True)` so that its metadata is immutable. It's a useful practice that helps avoid accidental changes to something that should be static. Envrax does this with its own built-in spaces too.
Okay, so that's the basics of Space requirements. Let's now build a custom one to get a better feel for it.
Working Example: OneHot(n)¶
For this tutorial, we'll build a one-hot encoded categorical action space.
For those unfamiliar with the concept, we take a standard set of categorical options and turn each one into a vector of 0s with a single 1 at its respective index.
Let's say we have 3 actions: up, down, and stay. Each one represents a different category. As a vector, they would look like this:
| up | down | stay |
|---|---|---|
| 1 | 0 | 0 |
| 0 | 1 | 0 |
| 0 | 0 | 1 |
That's it! That's what our space would hold. For now, we'll call the number of our categories n and think about how we can implement this later.
Next, we need to consider how we want to support VecEnv compatibility. For the `batch(n)` method, we need a way to define and store the batch dimension.
With Envrax's Discrete space, we turned it into a MultiDiscrete space. That's one way of doing it (turning the space into another one entirely), but it feels a bit overkill here. Instead, we can just track the leading shape dimension for our batch as a Tuple[int, ...]. We'll call this our batch_shape. Again, we'll get into this in a bit more detail shortly.
That's our two parameters sorted. The next question is: what can this space actually be useful for? As an action space, downstream networks (policy heads, critics, value functions) may expect fixed-size vectors. With this new space, we can skip the manual jax.nn.one_hot() conversion that a regular Discrete action space would otherwise need. It's a small difference, but it can go a long way.
We can also push it a step further. Envrax's Discrete space only supports uniform sampling, but in many RL setups (curriculum learning, biased exploration, weighted task selection) you may want some categories to be picked more often than others. We'll add an optional probs parameter that lets us supply per-category sampling probabilities to handle this.
Great! That gives us three parameters to work with as we build our OneHot space. We'll break its development into four key steps:
- Defining the class skeleton
- Implementing `sample()`
- Implementing `contains()`
- Implementing `batch()`
Step 1: Class Skeleton¶
First up, let's put together the Space dataclass.
We've already touched on the parameters briefly (n, batch_shape, and probs), but there's one more we need to consider: the dtype of the space.
Based on the table we've seen, you might be thinking that an integer dtype (e.g., uint8 or int32) is the right fit here. It's a valid choice, but for better compatibility with neural networks, we'd recommend using float32. It's the standard for most deep learning workflows and the precision is good enough without needing a dtype conversion downstream.
With that in mind, let's put our Space together and document it with a suitable docstring:
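The skeleton might look something like this; note that the bare `Space` class below is a minimal stand-in for `envrax.spaces.Space` so the sketch runs on its own:

```python
import dataclasses
from typing import Optional, Tuple

import jax.numpy as jnp


# Stand-in for envrax.spaces.Space so this sketch runs on its own; in a
# real project you would inherit from the actual base class instead.
class Space:
    pass


@dataclasses.dataclass(frozen=True)
class OneHot(Space):
    """A one-hot encoded categorical space with n categories.

    Attributes:
        n: Number of categories.
        dtype: Element dtype of sampled vectors (float32 for easy
            consumption by downstream networks).
        batch_shape: Leading batch dimensions; empty for a single space.
        probs: Optional per-category sampling probabilities, stored as a
            tuple so the frozen dataclass stays hashable.
    """

    n: int
    dtype: jnp.dtype = jnp.float32
    batch_shape: Tuple[int, ...] = ()
    probs: Optional[Tuple[float, ...]] = None
```

Because the dataclass is frozen, two spaces with the same parameters compare equal and hash identically, which is exactly what we want for static metadata.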
Step 2: Implement sample()¶
Next, we'll tackle the sample() method. This is pretty simple. All we need to do is create the logic for randomly sampling a one-hot encoded vector.
We can use the jax.random module for this. For the probs case, we can use jax.random.choice, otherwise we fall back to uniform sampling via jax.random.randint:
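A sketch of `sample()` along those lines (the Step 1 fields are repeated so the snippet runs on its own, and `jax.Array` stands in for the `chex.Array` annotation to keep dependencies minimal):

```python
import dataclasses
from typing import Optional, Tuple

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class OneHot:
    # Fields from Step 1, repeated so this snippet runs on its own.
    n: int
    dtype: jnp.dtype = jnp.float32
    batch_shape: Tuple[int, ...] = ()
    probs: Optional[Tuple[float, ...]] = None

    def sample(self, rng: jax.Array) -> jax.Array:
        """Draw a random one-hot vector (or a batch of them)."""
        if self.probs is not None:
            # Weighted sampling: the tuple is converted to an array at
            # call time, not stored on the (hashable) dataclass.
            idx = jax.random.choice(
                rng, self.n, shape=self.batch_shape, p=jnp.asarray(self.probs)
            )
        else:
            # Uniform sampling over [0, n).
            idx = jax.random.randint(rng, self.batch_shape, 0, self.n)
        # one_hot broadcasts: scalar idx -> (n,), (k,) idx -> (k, n).
        return jax.nn.one_hot(idx, self.n, dtype=self.dtype)
```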
There are a few things to note here:
- `self.probs is not None` is checked Python-side (at sample time, not under the JIT trace), so this branch is fine — the resulting traced graph is one or the other, not both.
- `jax.nn.one_hot(idx, n)` broadcasts cleanly over any shape — a scalar `idx` produces a vector of length `n`, and a `(k,)` `idx` produces a `(k, n)` matrix. That's exactly what we want for batching, with no branching needed!
- We convert `probs` from a Python tuple to a `jnp.array` inside `sample()` rather than at `__init__` time. This keeps the dataclass itself hashable (tuples are; arrays aren't) without losing JAX compatibility at sample time.
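To see the `jax.nn.one_hot` broadcasting behaviour concretely:

```python
import jax
import jax.numpy as jnp

# A scalar index yields a single one-hot vector of length n.
print(jax.nn.one_hot(2, 4))
# -> [0. 0. 1. 0.]

# A (k,) index array yields a (k, n) matrix, with no extra branching.
print(jax.nn.one_hot(jnp.array([0, 2]), 4).shape)
# -> (2, 4)
```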
Step 3: Implement contains()¶
Now for the membership check. This is a little more extensive.
We need to verify that the input matches the expected shape, has exactly one 1 per row, and that all other values are 0:
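A sketch of those checks (only the fields `contains()` needs are included so the snippet runs on its own):

```python
import dataclasses
from typing import Tuple

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class OneHot:
    # Only the fields contains() uses, so the snippet runs on its own.
    n: int
    batch_shape: Tuple[int, ...] = ()

    def contains(self, x) -> bool:
        """Check that x is a valid one-hot element of this space."""
        arr = jnp.asarray(x)
        # Cheap structural check first: shape must be (*batch_shape, n).
        if arr.shape != (*self.batch_shape, self.n):
            return False
        # Every entry must be exactly 0 or 1 ...
        is_binary = jnp.all((arr == 0) | (arr == 1))
        # ... and each row must sum to exactly one.
        one_per_row = jnp.all(arr.sum(axis=-1) == 1)
        return bool(is_binary & one_per_row)
```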
This performs three checks: shape, binary-ness, and exactly-one-per-row. If any of them fail, we return False.
Step 4: Implement batch()¶
Lastly, we prepend n to batch_shape and carry the other parameters through unchanged:
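A sketch of `batch()` (the other fields are repeated so the snippet runs on its own):

```python
import dataclasses
from typing import Optional, Tuple

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class OneHot:
    # Fields from Step 1, repeated so this snippet runs on its own.
    n: int
    dtype: jnp.dtype = jnp.float32
    batch_shape: Tuple[int, ...] = ()
    probs: Optional[Tuple[float, ...]] = None

    def batch(self, n: int) -> "OneHot":
        """Return a copy of this space with a leading batch dimension n."""
        # Prepend the new batch dim; forward every other field unchanged.
        return OneHot(
            n=self.n,
            dtype=self.dtype,
            batch_shape=(n, *self.batch_shape),
            probs=self.probs,
        )
```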
Easy enough! Since we have a batch_shape parameter there's no need to do any crazy space type conversion.
Notice that probs is also forwarded as-is - every batched sample uses the same per-category distribution. If you wanted the distribution itself to vary across batch elements, you'd need a different design (e.g., storing probs with a leading batch dim too), but for the typical "same env, multiple parallel copies" case, sharing the distribution is what you want.
Running It¶
With all methods in place, we can do a quick dummy run to confirm everything is working correctly:
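Putting the four pieces together, a self-contained version of the space plus a dummy run could look like this (again with a bare `Space` stand-in for `envrax.spaces.Space`):

```python
import dataclasses
from typing import Optional, Tuple

import jax
import jax.numpy as jnp


class Space:  # Stand-in for envrax.spaces.Space.
    pass


@dataclasses.dataclass(frozen=True)
class OneHot(Space):
    n: int
    dtype: jnp.dtype = jnp.float32
    batch_shape: Tuple[int, ...] = ()
    probs: Optional[Tuple[float, ...]] = None

    def sample(self, rng: jax.Array) -> jax.Array:
        if self.probs is not None:
            # Weighted sampling over the n categories.
            idx = jax.random.choice(
                rng, self.n, shape=self.batch_shape, p=jnp.asarray(self.probs)
            )
        else:
            # Uniform sampling over [0, n).
            idx = jax.random.randint(rng, self.batch_shape, 0, self.n)
        return jax.nn.one_hot(idx, self.n, dtype=self.dtype)

    def contains(self, x) -> bool:
        arr = jnp.asarray(x)
        if arr.shape != (*self.batch_shape, self.n):
            return False
        is_binary = jnp.all((arr == 0) | (arr == 1))
        one_per_row = jnp.all(arr.sum(axis=-1) == 1)
        return bool(is_binary & one_per_row)

    def batch(self, n: int) -> "OneHot":
        return OneHot(
            n=self.n,
            dtype=self.dtype,
            batch_shape=(n, *self.batch_shape),
            probs=self.probs,
        )


# Dummy run: sample, validate, and batch the space.
rng = jax.random.PRNGKey(0)
space = OneHot(n=3, probs=(0.7, 0.2, 0.1))

action = space.sample(rng)
print(action.shape)            # (3,)
print(space.contains(action))  # True

batched = space.batch(8)
print(batched.sample(rng).shape)  # (8, 3)
```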
And there we have it! A new space created and ready for use!
Using It on a JaxEnv¶
Custom spaces plug into any JaxEnv subclass the same way the built-ins do. Since we built OneHot as an action space, it slots into the second generic position (and the second @property):
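As a rough sketch: the `JaxEnv` base class below is a minimal stand-in, and `MyGridEnv`, `ObsSpaceT`, and `ActSpaceT` are hypothetical names; the real Envrax generics may differ.

```python
import dataclasses
from typing import Generic, TypeVar

ObsSpaceT = TypeVar("ObsSpaceT")
ActSpaceT = TypeVar("ActSpaceT")


# Stand-in for envrax's JaxEnv so the sketch runs on its own; the real
# base class also carries the environment logic (reset, step, ...).
class JaxEnv(Generic[ObsSpaceT, ActSpaceT]):
    pass


@dataclasses.dataclass(frozen=True)
class OneHot:
    n: int


# OneHot fills the second generic position (the action space) ...
class MyGridEnv(JaxEnv[object, OneHot]):
    # ... and is returned from the second @property.
    @property
    def action_space(self) -> OneHot:
        return OneHot(n=3)
```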
And it slots into VecEnv just as cleanly:
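A hedged sketch of that wiring; `VecEnvStandIn` and `FakeEnv` below are hypothetical stand-ins for Envrax's real `VecEnv` and an environment, but they show the mechanism: the wrapper derives its batched action space by calling `batch(num_envs)` on the inner space.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class OneHot:
    n: int
    batch_shape: Tuple[int, ...] = ()

    def batch(self, n: int) -> "OneHot":
        return OneHot(n=self.n, batch_shape=(n, *self.batch_shape))


class FakeEnv:  # Hypothetical environment exposing a OneHot action space.
    action_space = OneHot(n=3)


# Hypothetical stand-in for envrax's VecEnv: the key point is that the
# wrapper batches the inner space via batch(num_envs).
class VecEnvStandIn:
    def __init__(self, env, num_envs: int):
        self.num_envs = num_envs
        self.action_space = env.action_space.batch(num_envs)


vec = VecEnvStandIn(FakeEnv(), num_envs=8)
print(vec.action_space)
# -> OneHot(n=3, batch_shape=(8,))
```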
Common Pitfalls¶
Be wary of the following "gotchas":
- Forgetting `@dataclass(frozen=True)`. Without `frozen=True`, your space is mutable and unhashable — it can no longer be used as a `dict` key or `set` member, and any code that caches spaces by hash silently breaks. Every Envrax built-in is `frozen=True`; your custom spaces should match.
- Not threading `dtype`. We deliberately exposed `dtype` as a parameter on `OneHot` so users could pick `int32`/`uint8`/`float16` for memory savings or downstream-network compatibility. If you hard-code `jnp.float32` inside `sample()` instead of using `self.dtype`, users lose that flexibility — your space silently ignores their `dtype=...` argument.
- Using `jax.random.PRNGKey` instead of `chex.Array`. Functionally identical at runtime, but `chex.Array` is the convention used across every method signature in Envrax (`Space.sample`, `JaxEnv.reset`, every wrapper). Sticking to it keeps your custom space consistent with the API standard, so type checkers and IDE hovers behave the same way as the built-ins.
- Leaking Python-side computation into `sample()`. The method runs inside JAX traces (`jax.jit`, `jax.vmap`, `jax.lax.scan`), so any `if` or `for` that branches on a traced value will raise a `ConcretizationTypeError`. The `self.probs is not None` check in our `sample()` method is fine because `self.probs` is a Python attribute on the dataclass — it resolves before tracing kicks in, so JAX only ever sees one branch baked into the compiled graph.
- Using a `jnp.array` instead of a `Tuple` for `probs`-style fields. We chose `Tuple[float, ...]` for `probs` rather than `jnp.array(...)` for a specific reason: `frozen=True` dataclasses need their fields to be hashable, and JAX arrays aren't (they're mutable buffers underneath). Storing `probs` as a tuple keeps the dataclass valid, and the inline `jnp.array(self.probs)` conversion inside `sample()` is the only place we pay the cost — once per call, not stored.
- Forgetting to forward optional parameters in `batch()`. Our `batch()` explicitly passes `dtype=self.dtype` and `probs=self.probs` through to the new instance. If you skip those, the batched copy silently falls back to the defaults — so a `OneHot(n=3, probs=(0.7, 0.2, 0.1))` would suddenly become uniform after `batch(64)`, and your weighted sampling stops working with `VecEnv`. Always forward every instance field.
Recap¶
Excellent work! You've just built your first custom Envrax space from scratch!
Let's recap what we've covered and discuss some key points to consider when building your own custom spaces:
- Custom `Space` requirements — every custom space must inherit from `envrax.spaces.Space`, use `@dataclass(frozen=True)` for immutable metadata, and implement three abstract methods: `sample(rng)`, `contains(x)`, and `batch(n)`.
- Designing your parameters — pick a single required field that defines the space's identity (`n` for us, but it could be `dim` for a simplex, `low`/`high` for a truncated normal, etc.), then layer on optional knobs (`dtype`, `batch_shape`, distribution shape) with sensible defaults. Frozen dataclass fields make this declarative and immutable for free.
- Writing `sample()` — keep it pure JAX so it composes with `jax.jit`, `jax.vmap`, and `jax.lax.scan`. Branching on Python-side attributes (like our `if self.probs is not None`) is safe because it resolves before tracing kicks in. This is handy when one space needs to support multiple sampling regimes.
- Writing `contains()` — combine the structural checks your space requires (shape, dtype, value bounds, invariants) into a single `bool`. Bail out early on cheap shape mismatches before doing the more expensive elementwise checks to keep things fast.
- Picking a `batch()` strategy — we recommend one of two patterns: stay within your own type by tracking `batch_shape` on the instance (clean when the "element" stays the same shape, like `OneHot`), or switch to a different space type when the batched semantics warrant it (the way `Discrete → MultiDiscrete` does).
- Storing JAX-incompatible config on a frozen dataclass — frozen dataclasses need every field to be hashable. The best approach is to store the values as a Python `Tuple` (or other immutable container) on the instance, then convert it to a `jnp.array` inside `sample()` at call time. This works for probability vectors (like our example), weight tables, level grids, and anything else you can think of.
- Forwarding state through `batch()` — every field you add to the instance has to be explicitly carried through to the new batched copy. Skip even one and the configuration silently disappears the moment `VecEnv` wraps your environment.