Environment State¶
Welcome to your first Envrax tutorial!
Before we build anything, we first need to understand two foundational concepts: the environment's state and environment spaces. In this tutorial, we'll focus on the state.
What is a State?¶
In its simplest form, an environment state is a single snapshot of the current internal representation of the environment that provides a full description of the world.
This is distinct from two related concepts often used in the RL setting: observations and dynamics. Here's how:
- Observations - are a subset/transformation of the state, limiting what the agent gets to see each step.
- Dynamics - are the rules of the environment that compute the next state from the current one.
For a ball moving in 2D, the environment state might contain:
- The ball's current position
- Its velocity
- How many steps have passed
- Whether the episode has ended
If we know the current state of the environment, we can compute the next state from it and an action alone, without needing any earlier history. This is the Markov property that RL algorithms rely on.
The Base State¶
By design, Envrax represents state as a @chex.dataclass: a Python dataclass that JAX treats as a "PyTree" and that we treat as immutable. This lets JAX transformations work on states directly and lets jax.vmap run thousands of environments at once.
But really, why @chex.dataclass?
As mentioned, @chex.dataclass registers your class as a JAX PyTree, which gives you four things for free:
- Automatic traversal by jax.jit, jax.vmap, jax.lax.scan, etc.
- .replace(...) for immutable updates
- Batching: VecEnv can stack N states into a single PyTree with a leading batch dimension
- Testing helpers: works out of the box with chex's assertion utilities (chex.assert_trees_all_close, chex.assert_shape, etc.) for verifying state transitions in unit tests
Plain @dataclasses.dataclass classes won't work: they aren't registered as PyTrees, so JAX can't trace them!
Every Envrax state must inherit from envrax.EnvState. By default, it provides three mandatory fields essential to all environments:
- rng: a PRNG key created with jax.random.key(), threaded through the episode
- step: the current environment timestep (int32)
- done: the environment termination flag (bool)
Using class inheritance, you can extend it with whatever your environment needs and keep those fields for free!
Sticking to our 2D ball example, we could add its current x and y position like so:
Notice how we don't use the Python float type here. There's a reason for that and we'll explain that shortly.
But now, whenever we use BallState, we have access to all five fields: rng, step, done, ball_x and ball_y.
We'll use this BallState throughout the next couple of tutorials, so make sure you get familiar with it!
Field Rules/Types¶
Chex Arrays
chex.Array is a type alias covering JAX and NumPy arrays, which makes it a convenient annotation for "any array-like field". It doesn't wrap or modify values at runtime; it just makes type hints more readable.
For consistency and convenience, we use it throughout the tutorials anywhere a field holds an array.
Fields on an EnvState subclass must be JAX-compatible and traceable. This means each field should be either:
- A JAX array (jnp.float32, jnp.int32, jnp.bool_, jax.Array, chex.Array, etc.), or
- A nested @chex.dataclass instance
Avoid fields that hold:
- Python list, dict or tuple containers
- Python objects, strings or None
- Python int, float or bool values
Traceable values are really important for the flow of JAX JIT-compiled functions. They act as runtime data, allowing them to be changed during each function call without triggering a re-compile.
JIT-compiling can take a long time depending on the size of the computation graph, so ideally we want a single "setup" compile when we start using an environment, which drastically reduces wall-clock time.
We avoid Python types like int, float and bool because JAX often treats them as static values (for example, when they are closed over or passed as static arguments), and every time a static value changes, the function has to be re-traced and re-compiled. These are a great fit for EnvConfig instead (more on that in a later tutorial)!
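To make the traced-versus-static distinction concrete, this sketch counts how many times JAX traces a function. The trace_counts dict and the advance_* functions are hypothetical; the Python side effect only runs while JAX is tracing, so it counts traces rather than calls. Marking the argument static with static_argnums makes every new value trigger a fresh trace and compile, while a traced int32 argument is compiled once.

```python
import jax
import jax.numpy as jnp

trace_counts = {"traced": 0, "static": 0}


def advance_traced(step):
    # Python side effects run only while JAX traces, so this counts traces.
    trace_counts["traced"] += 1
    return step + 1


def advance_static(step):
    trace_counts["static"] += 1
    return step + 1


# `step` is traced: one compile serves every int32 scalar value.
advance_traced_jit = jax.jit(advance_traced)
# `step` is static: every new Python value triggers a new trace and compile.
advance_static_jit = jax.jit(advance_static, static_argnums=0)

for i in range(3):
    advance_traced_jit(jnp.int32(i))
    advance_static_jit(i)

print(trace_counts)
```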
If you need a fixed-size collection in your EnvState, use a JAX array:
Array Shapes
Arrays must have a fixed shape. If a collection's logical length varies, pad it to a maximum size and track the number of valid entries in a separate scalar field.
You want to avoid re-compiling whenever possible!
Updating State¶
Because JAX arrays can't be modified in place and JIT-compiled functions should be pure, we never mutate a state. Instead, we use the .replace(...) method that @chex.dataclass provides whenever we want to make state adjustments.
This returns a new state with the requested fields changed and copies the other fields over automatically:
Keeping to the JAX theme, our state transitions stay "pure" for JIT compatibility, without making you rebuild every field by hand.
Threading the PRNG Key¶
RL environments often need randomness, for example for random starting positions, noisy transitions, and stochastic rewards.
JAX handles randomness through explicit PRNG keys that you split before consuming, rather than through a hidden global random state stored inside your environment. We create these keys with jax.random.key().
Envrax threads a key through the episode by storing it on the state. The pattern is always the same:
One half of the split (rng) goes back on the state for the next call, while the other half (sub) is consumed now for this step's randomness.
Never Reuse a Key!
Reusing the same rng twice gives you the same sample twice. This is a common source of silent determinism bugs.
Always split your keys before consuming them!
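A short demonstration of why reuse is dangerous: drawing twice from the same key returns the same number, while drawing from two split subkeys does not. The seed 42 is arbitrary.

```python
import jax

key = jax.random.key(42)

# Reusing one key: both draws are identical.
a = jax.random.uniform(key)
b = jax.random.uniform(key)

# Splitting first: each subkey gives an independent draw.
k1, k2 = jax.random.split(key)
c = jax.random.uniform(k1)
d = jax.random.uniform(k2)

print(a == b, c == d)
```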
We'll explore this in more detail when we put this into a real step method in the "Your First Environment" tutorial. For now, just remember the split-then-consume pattern.
Nested States¶
Sometimes environment logic might need to remember something between steps (e.g., a rolling buffer of frames or a running reward total). Rather than mutating the inner state, we can wrap it in a larger state with its own extra fields.
The inner state stays untouched and we can still read its information whenever we need it.
A common pattern for this is stacked Wrappers. If you apply a wrapper on top of an environment, the wrapper needs to be able to read the environment's base fields (rng, step, done).
The pattern is the same @chex.dataclass approach we've used so far, except that the state now has an env_state field that gives us access to an "inner" EnvState:
That's it! Everything the inner environment provided is preserved, plus whatever its wrapper needs to remember. This is a more advanced topic so we'll build on this in a later tutorial.
For those curious, you can check it out in the Creating a Custom Wrapper tutorial.
Common Pitfalls¶
When building your custom EnvStates, there are a few common "gotchas" to be mindful of:
- AttributeError: 'BallState' object has no attribute 'replace': you forgot to add the @chex.dataclass decorator.
- TypeError: Argument ... is not a valid JAX type: a field is a Python object or None value. Convert it to a JAX array.
- Silent determinism bugs: reset was called twice with the same key and produced the "same episode" that you expected to be different. Make sure you split the key with jax.random.split so each environment gets its own key.
- Shape mismatches under vmap: a field has a Python int instead of a JAX scalar. Wrap it with jnp.int32(...).
Recap¶
And that covers the basics of EnvState! Great job getting this far!
To recap:
- EnvState is a full snapshot of the environment at one timestep
- EnvState fields must be both JAX-compatible and traceable: use jnp.* types, not Python int/float/bool
- In Envrax, we treat EnvState as immutable because JAX needs pure functions, so we use .replace(...) rather than mutating it
- We extend EnvState with @chex.dataclass and JAX-compatible fields
- We thread the PRNG key through the episode by splitting it each step
- We can nest one EnvState inside another when wrappers need extra state of their own
Next Steps¶
Next, we'll look at the second foundational concept of Envrax - spaces!
- Spaces: learn how to describe observations and actions with Spaces.