Your First Environment¶
API Docs
Welcome back! So far, you've developed your understanding of the three foundational pieces for building Envrax environments:
- State — the immutable snapshot of your environment
- Spaces — the contracts describing observations and actions
- Configuration — the static settings that drive its dynamics
Now, it's time to wire them together into a working environment!
We'll build a tiny 2D ball world where a ball starts at a random location and an agent takes one of four discrete actions per step: [left, right, up, down]. By the end you'll have a runnable JaxEnv and understand the reset / step contract that every Envrax environment follows.
Without further ado, let's get to it!
Fundamental Components¶
From our first tutorial (State), we already created our `BallState`.
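As a refresher, here's a sketch of what it might look like. The `envrax` import path, and the assumption that `EnvState` is a struct-style frozen dataclass supplying the `key`, `step`, and `done` fields plus a `.replace()` method, are ours, not confirmed by the library:

```python
import jax

from envrax import EnvState  # import path assumed


class BallState(EnvState):
    # The ball's position; key/step/done are assumed to come from EnvState
    x: jax.Array
    y: jax.Array
```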
We'll also reuse the `BallConfig` from Configuration.
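A sketch of the config, with illustrative defaults (`max_steps` is assumed to be inherited from the `EnvConfig` base class):

```python
from envrax import EnvConfig  # import path assumed


class BallConfig(EnvConfig):
    friction: float = 0.9       # default values here are illustrative
    reward_scale: float = 1.0
```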
What we didn't discuss was the types of Spaces we were going to use. Recall that we need two: an observation space and an action space.
Based on our initial brief, the action space is easy - Discrete(n=4) to cover our four movement options.
However, the observation space is a little trickier. To help with this, let's consider the following:
- What does the agent see? We want the agent to watch how the ball moves towards a target, so it needs to monitor the ball's `(x, y)` position.
- What format are the positions in? Could we set them up as a `Discrete` space or a `Box`? Here we use `float` values, so `Box` comes naturally.
- What value range do we need? This defines the "world" the ball lives in and is purely a design choice. We could use absolute coordinates (e.g., pixel positions with a range of `[0, 800]`), but this is unnecessary complexity. Instead, we can use a normalized range between `0.0` and `1.0`. This is a natural fit for neural networks too!
That gives us Box(low=0.0, high=1.0, shape=(2,), dtype=jnp.float32) - a continuous 2-vector bounded observation space with values between 0 and 1, matching the jnp.float32 dtype we used on BallState.
Perfect! Now we have everything we need. Let's build our BallEnv!
Building the Environment¶
Full Code
If needed, the full code used throughout this tutorial can be dropped into a file called `ball_env.py` and run; it should work "as is".
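Here's a sketch of that file, assembled from the snippets built below. The `envrax` import paths, the base-class fields (`key`, `step`, `done`, `max_steps`), and the exact method signatures are assumptions, so treat this as a reconstruction rather than the library's canonical listing:

```python
import jax
import jax.numpy as jnp

# Import paths are assumptions; adjust to your Envrax install
from envrax import EnvConfig, EnvState, JaxEnv
from envrax.spaces import Box, Discrete


class BallState(EnvState):
    # The ball's position; key/step/done are assumed to come from EnvState
    x: jax.Array
    y: jax.Array


class BallConfig(EnvConfig):
    friction: float = 0.9       # default values here are illustrative
    reward_scale: float = 1.0


class BallEnv(JaxEnv[Box, Discrete, BallState, BallConfig]):
    @property
    def observation_space(self) -> Box:
        return Box(low=0.0, high=1.0, shape=(2,), dtype=jnp.float32)

    @property
    def action_space(self) -> Discrete:
        return Discrete(n=4)

    def reset(self, key: jax.Array) -> tuple[jax.Array, BallState]:
        # One split for the state key, another for the starting position
        state_key, pos_key = jax.random.split(key)
        x_key, y_key = jax.random.split(pos_key)

        state = BallState(
            key=state_key,
            step=jnp.asarray(0),
            done=jnp.asarray(False),
            x=jax.random.uniform(x_key, dtype=jnp.float32),
            y=jax.random.uniform(y_key, dtype=jnp.float32),
        )
        obs = jnp.array([state.x, state.y], dtype=jnp.float32)
        return obs, state

    def step(self, state: BallState, action: jax.Array):
        # Keep a fresh key in the new state for the next timestep
        new_key, _subkey = jax.random.split(state.key)

        # Static displacement tables, indexed by action: [left, right, up, down]
        dx = jnp.array([-0.01, 0.01, 0.0, 0.0], dtype=jnp.float32)
        dy = jnp.array([0.0, 0.0, 0.01, -0.01], dtype=jnp.float32)

        # Scale by friction and clip to the [0, 1] world bounds
        new_x = jnp.clip(state.x + dx[action] * self.config.friction, 0.0, 1.0)
        new_y = jnp.clip(state.y + dy[action] * self.config.friction, 0.0, 1.0)

        new_state = state.replace(
            key=new_key, step=state.step + 1, x=new_x, y=new_y
        )
        obs = jnp.array([new_state.x, new_state.y], dtype=jnp.float32)

        reward = jnp.asarray(1.0, dtype=jnp.float32) * self.config.reward_scale
        done = new_state.step >= self.config.max_steps
        info = {"step": new_state.step}

        return obs, new_state.replace(done=done), reward, done, info


if __name__ == "__main__":
    env = BallEnv(BallConfig(max_steps=100))
    obs, state = env.reset(jax.random.key(0))
    print("initial obs:", obs)

    obs, state, reward, done, info = env.step(state, jnp.asarray(1))
    print("obs:", obs, "reward:", reward, "done:", done, "info:", info)
```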
We can build an Envrax environment in three easy steps:
- Choosing a class name and assigning the generic types
- Defining the environment's spaces
- Implementing the methods -
resetandstep
Step 1: Declaring our Class¶
JaxEnv Base Class
Curious what's under the hood? At its core, `JaxEnv` is a small abstract, generic base class.
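A sketch of what that base class plausibly looks like, based on the description that follows. The import paths, TypeVar names, and exact method signatures are assumptions:

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

import jax

from envrax import EnvConfig, EnvState  # import paths assumed
from envrax.spaces import Space

# Four type parameters, each bound to its base type
ObsSpaceT = TypeVar("ObsSpaceT", bound=Space)
ActSpaceT = TypeVar("ActSpaceT", bound=Space)
StateT = TypeVar("StateT", bound=EnvState)
ConfigT = TypeVar("ConfigT", bound=EnvConfig)


class JaxEnv(ABC, Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]):
    def __init__(self, config: ConfigT) -> None:
        self.config = config

    @property
    @abstractmethod
    def observation_space(self) -> ObsSpaceT: ...

    @property
    @abstractmethod
    def action_space(self) -> ActSpaceT: ...

    @abstractmethod
    def reset(self, key: jax.Array) -> tuple[jax.Array, StateT]: ...

    @abstractmethod
    def step(
        self, state: StateT, action: jax.Array
    ) -> tuple[jax.Array, StateT, jax.Array, jax.Array, dict[str, Any]]: ...
```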
Two things worth mentioning:
- `ABC` - marks the class as abstract, forcing subclasses to implement every method marked with `@abstractmethod` before they can be instantiated.
- `Generic[ObsSpaceT, ActSpaceT, StateT, ConfigT]` - declares four type parameters, each bound to its base type (`Space`, `EnvState`, or `EnvConfig`). So, when you write `JaxEnv[Box, Discrete, BallState, BallConfig]`, you're pinning those TypeVars to concrete types for this subclass. This lets your IDE know exactly which types are in use and autocomplete correctly, without hacky overrides or `# type: ignore`.
Every Envrax environment must subclass JaxEnv and pin four data types for IDE support. These are (in order): the observation space, action space, the environment state, and the environment config.
In our case, we have `Box`, `Discrete`, and our custom `BallState` and `BallConfig`.
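A sketch of the declaration, with the generic parameters pinned in the order described (import paths assumed):

```python
import jax.numpy as jnp

from envrax import JaxEnv                # import paths assumed
from envrax.spaces import Box, Discrete


class BallEnv(JaxEnv[Box, Discrete, BallState, BallConfig]):
    ...
```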
- Format: `[observation_space, action_space, EnvState, EnvConfig]`
Step 2: Defining our Spaces¶
Next, we declare the `observation_space` and `action_space` as properties on the class.
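A sketch of those two properties, matching the spaces we chose above (the `Box` and `Discrete` constructor signatures are assumptions):

```python
class BallEnv(JaxEnv[Box, Discrete, BallState, BallConfig]):
    @property
    def observation_space(self) -> Box:
        # A normalized 2-vector for the ball's (x, y) position
        return Box(low=0.0, high=1.0, shape=(2,), dtype=jnp.float32)

    @property
    def action_space(self) -> Discrete:
        # Four discrete actions: [left, right, up, down]
        return Discrete(n=4)
```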
Step 3: Implementing our Methods¶
Before writing any code, let's first consider what the reset and step methods actually do:
- `reset` - takes a `jax.random.key()` and outputs an initial observation and an initial `EnvState`.
- `step` - takes the current `EnvState` and an agent's `action`, transitions the environment to a new observation, and produces a new `EnvState`, a reward, a termination flag defining whether the episode has ended, and additional metadata.
Reset Method¶
reset is the easier of the two, so we'll start there. Looking at our description, we can unpack it into three key steps:
- Handling the PRNG key
- Creating the initial state
- Creating the initial observation
For the PRNG key, we split it twice: the first split yields a key for the `BallState` and a position-splitting key, and splitting the latter again yields the `x` and `y` RNG keys for the ball's random starting position.
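Here's that splitting pattern in isolation; this part is plain JAX and runnable as-is (inside the real `reset`, the incoming argument plays the role of `key`):

```python
import jax

key = jax.random.key(0)                     # the key passed into reset()
state_key, pos_key = jax.random.split(key)  # BallState key + position key
x_key, y_key = jax.random.split(pos_key)    # keys for the x and y positions
```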
Now, we can create the initial state with starting values using the random keys.
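In the environment, this might look like the following sketch (the `key`, `step`, and `done` field names assume they are provided by the `EnvState` base class):

```python
state = BallState(
    key=state_key,
    step=jnp.asarray(0),
    done=jnp.asarray(False),
    x=jax.random.uniform(x_key, dtype=jnp.float32),
    y=jax.random.uniform(y_key, dtype=jnp.float32),
)
```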
Since we are using `jnp.float32` values for the ball's `(x, y)` position, we sample from a uniform `[0, 1)` distribution to get a random starting state that's different with every key.
Finally, we can create the initial observation using the initial positions and return the required values.
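A sketch of that final step, stacking the positions into the 2-vector our `Box` observation space expects:

```python
obs = jnp.array([state.x, state.y], dtype=jnp.float32)
return obs, state
```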
Great! That's the reset method done!
Putting those pieces together completes the method.
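A sketch of the assembled method (the signature and return type are assumptions based on the reset contract above):

```python
def reset(self, key: jax.Array) -> tuple[jax.Array, BallState]:
    # One split for the state key, another for the starting position
    state_key, pos_key = jax.random.split(key)
    x_key, y_key = jax.random.split(pos_key)

    state = BallState(
        key=state_key,
        step=jnp.asarray(0),
        done=jnp.asarray(False),
        x=jax.random.uniform(x_key, dtype=jnp.float32),
        y=jax.random.uniform(y_key, dtype=jnp.float32),
    )

    obs = jnp.array([state.x, state.y], dtype=jnp.float32)
    return obs, state
```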
Step Method¶
Now, the step method. Recall that:
- `step` - takes the current `EnvState` and an agent's `action`, transitions the environment to a new observation, and produces a new `EnvState`, a reward, a termination flag defining whether the episode has ended, and additional metadata.
Yikes! There's a lot to unpack there so let's think about this carefully. We need to:
- Manage the PRNG randomness to get a new observation and state (required for JAX)
- Create a new `EnvState`
- Take an action through the environment to create a new observation
- Get a reward signal
- Check if the environment is done
- Get the metadata for the environment step
- Return the required values
That's a lot! Let's take it one step at a time, starting with the PRNG management. For this, we want to extract the rng key from the provided state and split it for reuse on the next timestep.
We can do this in one line using our handy-dandy `jax.random.split()` approach.
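A sketch of the method opening (the signature is an assumption; the unused subkey is where stochastic dynamics would draw their randomness):

```python
def step(self, state: BallState, action: jax.Array):
    # Keep a fresh key in the new state for the next timestep
    new_key, _subkey = jax.random.split(state.key)
```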
Easy enough! Next, let's create the new EnvState and observation.
Here, we'll create static lookup tables for the `x` and `y` displacements and use the action value as our index into them. For example, if `action=0`, the displacement is `x=-0.01, y=0.0`.
Then, we'll use `jnp.clip` to apply the displacement to our ball state while keeping its values within the bounds of the observation space.
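Here's that logic in isolation, runnable with plain JAX; the stand-in scalars `ball_x`, `ball_y`, and `friction` replace `state.x`, `state.y`, and `self.config.friction` from the real method:

```python
import jax.numpy as jnp

# Static displacement tables, indexed by the action: [left, right, up, down]
dx = jnp.array([-0.01, 0.01, 0.0, 0.0], dtype=jnp.float32)
dy = jnp.array([0.0, 0.0, 0.01, -0.01], dtype=jnp.float32)

action = 0                   # "left"
friction = 0.9               # stand-in for self.config.friction
ball_x = jnp.float32(0.5)    # stand-in for state.x
ball_y = jnp.float32(0.5)    # stand-in for state.y

# Scale the displacement by friction and clip to the [0, 1] world bounds
new_x = jnp.clip(ball_x + dx[action] * friction, 0.0, 1.0)
new_y = jnp.clip(ball_y + dy[action] * friction, 0.0, 1.0)
```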
Notice how we've used the self.config.friction config field here (BallConfig.friction). To give that real ball feel, every per-step displacement is scaled by friction. If we reduce it to 0.5, the ball will move more sluggishly, but if we bump it up to 1.0, it moves at full speed.
If we wanted, we could move this out into a separate _act() method on the environment class to keep our step() method easy to read. We won't do that here for this simple tutorial, but something to think about when building more complex ones!
Now, we use the `.replace()` method to update the `EnvState`, then create the observation just like the initial one but from our `new_state` instead.
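A sketch of those two pieces (assuming `EnvState` provides a struct-style `.replace()`):

```python
new_state = state.replace(
    key=new_key,
    step=state.step + 1,  # increment so we can track episode length
    x=new_x,
    y=new_y,
)
obs = jnp.array([new_state.x, new_state.y], dtype=jnp.float32)
```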
Notice how we incremented our step here so that we can track things accordingly. Okay, 3/7 down! Next, the reward signal, our done flag and the metadata.
For this example, we'll give our agent a flat `1.0` per step, scaled by `self.config.reward_scale` from our `BallConfig`. Reward function design and reward shaping are a beast of their own and out of scope for this tutorial series. Google DeepMind provide a great post about Specification Gaming that highlights some of the challenges of building reward functions. We highly recommend reading it before building your own!
For our termination flag, we'll simply check to see if the current step matches the config.max_steps for our BallConfig (inherited from EnvConfig).
For our metadata we'll just return a Python Dict with the current step count.
All three of these fit in just a few lines.
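A sketch of those three values (the config field names match `BallConfig` above; `max_steps` is assumed inherited from `EnvConfig`):

```python
reward = jnp.asarray(1.0, dtype=jnp.float32) * self.config.reward_scale
done = new_state.step >= self.config.max_steps
info = {"step": new_state.step}
```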
Customization
These three values (reward, done, info) can be far more complicated and customized depending on your environment's complexity.
It's not uncommon to extract these into their own full-blown helper methods, such as `_reward()`, `_done()`, and `_info()`, just like an `_act()` method. In fact, it's good practice to do so!
Remember to check out Envrax's built-in Wrappers to find some existing customization options too!
Lastly, all we need to do is return the values.
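Here's a sketch of the complete method, return statement included. The return ordering (obs, state, reward, done, info) is an assumption based on how the values are used in this tutorial:

```python
def step(self, state: BallState, action: jax.Array):
    # 1. PRNG management: keep a fresh key for the next timestep
    new_key, _subkey = jax.random.split(state.key)

    # 2-3. New state and observation via displacement tables + clipping
    dx = jnp.array([-0.01, 0.01, 0.0, 0.0], dtype=jnp.float32)
    dy = jnp.array([0.0, 0.0, 0.01, -0.01], dtype=jnp.float32)
    new_x = jnp.clip(state.x + dx[action] * self.config.friction, 0.0, 1.0)
    new_y = jnp.clip(state.y + dy[action] * self.config.friction, 0.0, 1.0)

    new_state = state.replace(
        key=new_key, step=state.step + 1, x=new_x, y=new_y
    )
    obs = jnp.array([new_state.x, new_state.y], dtype=jnp.float32)

    # 4-6. Reward, termination flag, and metadata
    reward = jnp.asarray(1.0, dtype=jnp.float32) * self.config.reward_scale
    done = new_state.step >= self.config.max_steps
    info = {"step": new_state.step}

    # 7. Return, folding the done flag back into the state
    return obs, new_state.replace(done=done), reward, done, info
```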
Notice how we return an updated copy of our new_state with the updated done flag here to simplify our method a little more.
Running It¶
Nice work so far! Now let's try running this bad boy.
We can do that in 3 lines of code, plus a few `print()` statements for verification.
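A sketch of the driver script, with the constructor and return signatures carried over from the assumptions above:

```python
import jax
import jax.numpy as jnp

from ball_env import BallConfig, BallEnv

# The three core lines: build the env, reset it, then take a step
env = BallEnv(BallConfig(max_steps=100))
obs, state = env.reset(jax.random.key(0))
obs, state, reward, done, info = env.step(state, jnp.asarray(1))  # "right"

print("obs:", obs)
print("reward:", reward)
print("done:", done)
print("info:", info)
```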
That's it! The full reset → step loop!
Using Wrappers¶
Envrax ships with a set of wrappers that transform observations, rewards, or termination flags without touching your env's code. They're applied like onion layers: each takes an inner env and returns a new one with the same reset/step interface but with added functionality (where appropriate).
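The wrapper names below are hypothetical placeholders (check Available Wrappers for the real ones); the layering pattern is the point:

```python
from envrax.wrappers import ClipReward, FlattenObservation  # names hypothetical

env = BallEnv(BallConfig())
env = FlattenObservation(env)  # inner layer
env = ClipReward(env)          # outer layer: same reset/step interface
```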
For production setups, the `make()` factory method is useful for applying wrappers automatically.
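A sketch only: the exact `make()` signature is an assumption here, so see the Make Methods tutorial for the real one:

```python
import envrax  # assumed top-level API

# Hypothetical call: an env ID plus its config resolved in one place
env = envrax.make(
    "BallEnv",
    config=BallConfig(max_steps=100),
)
```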
We'll dive into specific wrappers in Available Wrappers and walk through every factory method in the Make Methods tutorial. For now, just know they exist!
Recap¶
Excellent job! You've just built your first JaxEnv environment!
Here's a quick recap of what we've covered:
- Declared a `BallEnv` class subclassing `JaxEnv[Box, Discrete, BallState, BallConfig]`
- Defined the `observation_space` and `action_space` as properties on the class
- Implemented the `reset` and `step` methods to drive the environment's transitions
- Tested it by running the `reset → step` loop
Next, we'll explore the EnvConfig and how to customize it.
Next Steps¶
- Vectorising with `VecEnv` - learn how to run `N` parallel copies of your environments.