Environment Configuration¶
Before we wire state and spaces into a working JaxEnv, there's one more piece to introduce: how to handle the environment's core/unique properties. This is where EnvConfig comes in.
Every JaxEnv holds one under its self.config property. The base class has a single field, which defines the default episode length of the environment. Nice and simple!
EnvConfig is designed as static data that is set once at construction and never changed during an episode. If your environment has gravity, reward scales, difficulty modes, or level seeds, this is where to put them.
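As a rough sketch of the "set once, never changed" idea, here is a stand-in for a one-field config. It uses the standard library's frozen dataclass rather than the library's own EnvConfig, and the field name max_episode_steps is an assumption made for illustration. Freezing the class is one way to enforce the rule, not necessarily how the library does it.

```python
import dataclasses


# Stand-in for the library's base config (NOT the real EnvConfig).
# `max_episode_steps` is a hypothetical field name.
@dataclasses.dataclass(frozen=True)
class SketchEnvConfig:
    max_episode_steps: int = 1000  # default episode length


config = SketchEnvConfig()

# A frozen dataclass rejects mutation after construction.
try:
    config.max_episode_steps = 2000
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# To use different values, build a new config instead of editing the old one.
longer = dataclasses.replace(config, max_episode_steps=2000)
```

Building a new config (and a new environment from it) keeps every compiled function tied to exactly one set of values.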
EnvConfig vs. EnvState¶
Now you may be wondering: "Can't everything just live in EnvState?" Technically it could, but the key distinction is in how JAX handles static vs. traceable data. As we mentioned in our earlier tutorials, we need to be careful not to mix static data with traceable data.
Traceable values act as runtime data: they can change on every function call without triggering a new JIT compile, as long as their shape and dtype stay the same. Static values, on the other hand, are baked into the compiled function, so the function has to be re-traced and re-compiled whenever one of them changes.
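The difference can be observed directly with plain jax.jit. In the sketch below, the function fall and its gravity argument are illustrative names, not part of the library; a Python counter records how often JAX actually traces the function.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # Python side effect: only runs while JAX traces `fall`


@partial(jax.jit, static_argnames=("gravity",))
def fall(velocity, gravity):
    global trace_count
    trace_count += 1
    return velocity - gravity * 0.1


fall(jnp.array(0.0), gravity=9.8)  # first call: trace and compile
fall(jnp.array(5.0), gravity=9.8)  # new traced value, same shape/dtype: cache hit
traces_after_runtime_change = trace_count

fall(jnp.array(5.0), gravity=1.6)  # new static value: trace and compile again
traces_after_static_change = trace_count
```

Changing velocity reuses the compiled function, while changing gravity costs a fresh trace and compile. That is why per-step data belongs in the state and fixed parameters belong in the config.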
As a rule of thumb:
- If an item is fixed for the whole episode, needs to be known at construction time, is a Python scalar or has a static shape, it goes in EnvConfig.
- If an item changes during the episode, only needs to be known at runtime, is a JAX array or JAX-compatible, it goes in EnvState.
Remember: you should only ever need to set the config once, at environment creation. If you change it afterwards, already-compiled functions can keep using the old values that were baked in when they were traced, which breaks your training loop without any warning.
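This pitfall is easy to reproduce with plain jax.jit. The sketch below assumes the config is read as an ordinary Python attribute while the function is being traced; MutableConfig and step are hypothetical names used only to show the failure mode.

```python
import jax
import jax.numpy as jnp


class MutableConfig:
    """Hypothetical mutable config, used only to show the pitfall."""

    def __init__(self, gravity):
        self.gravity = gravity


cfg = MutableConfig(gravity=9.8)


@jax.jit
def step(velocity):
    # cfg.gravity is read once, at trace time, and baked in as a constant.
    return velocity - cfg.gravity


before = float(step(jnp.array(0.0)))

cfg.gravity = 1.6  # mutate the config after compilation
after = float(step(jnp.array(0.0)))  # cache hit: still computed with 9.8
```

The second call returns the same result as the first, and no error is raised. Creating a new environment with a new config avoids this.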
Extending EnvConfig¶
Now let's look at how we can extend EnvConfig. Just like EnvState, we use the @chex.dataclass decorator and subclass the parent class (EnvConfig), then add the fields we want. The base class's episode-length field is included for free; to change this environment's default episode length from 1000, override it in the subclass.
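A minimal sketch of such an extension might look like the following. To stay runnable without the library, it uses standard-library dataclasses and a stand-in base class; real code would use @chex.dataclass and subclass the library's EnvConfig. All field names here (max_episode_steps, gravity, bounce_damping, reward_scale) are assumptions for illustration.

```python
import dataclasses


# Stand-in for the library's EnvConfig; the field name is hypothetical.
@dataclasses.dataclass(frozen=True)
class SketchEnvConfig:
    max_episode_steps: int = 1000


# Hypothetical fields for a bouncing-ball environment.
@dataclasses.dataclass(frozen=True)
class BallConfig(SketchEnvConfig):
    gravity: float = 9.8
    bounce_damping: float = 0.9
    reward_scale: float = 1.0
    # max_episode_steps: int = 2000  # uncomment to override the inherited default


default_cfg = BallConfig()
moon_cfg = BallConfig(gravity=1.6)  # same environment, different physics
```

Every field has a Python scalar default, so BallConfig() works with no arguments, and individual values can be overridden when the environment is created.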
Pinning the Config to Your Environment¶
When you build a JaxEnv subclass, you'll need to pin your custom config (BallConfig) as the fourth generic parameter of JaxEnv.
We'll discuss this in more detail in the next tutorial (Your First Environment).
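Until then, here is a rough sketch of what pinning a fourth generic parameter means in plain Python typing. SketchJaxEnv is a stand-in, not the library class. Only "the config is the fourth parameter" comes from this tutorial; the names and roles of the first three type parameters are placeholders.

```python
import dataclasses
from typing import Generic, TypeVar, get_args

# Placeholder type parameters; only the fourth position is taken from the text.
StateT = TypeVar("StateT")
ObsT = TypeVar("ObsT")
ActT = TypeVar("ActT")
ConfigT = TypeVar("ConfigT")


class SketchJaxEnv(Generic[StateT, ObsT, ActT, ConfigT]):
    """Stand-in for JaxEnv (not the library class)."""

    def __init__(self, config: ConfigT):
        self.config = config


@dataclasses.dataclass(frozen=True)
class BallConfig:
    gravity: float = 9.8  # hypothetical field


class BallState:  # placeholder for the environment's state type
    pass


# BallConfig sits in the fourth slot, so type checkers know self.config's type.
class BallEnv(SketchJaxEnv[BallState, object, object, BallConfig]):
    pass


env = BallEnv(BallConfig())
pinned = get_args(BallEnv.__orig_bases__[0])
```

With the config pinned, a type checker can tell that env.config.gravity exists, rather than treating self.config as the generic base config.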
Recap¶
To recap:
- EnvConfig holds static per-env data; EnvState holds dynamic per-episode data
- Extend EnvConfig with @chex.dataclass and add fields with Python scalar defaults
- Pin your custom config as the 4th generic parameter on JaxEnv
Next Steps¶
You've now seen all three foundational pieces — state, spaces, and config. Time to wire them into a working environment!
- Your First Environment: Subclass JaxEnv, implement reset and step, and use BallConfig to drive the dynamics.