Rendering¶
To truly understand something we need to see it in action. Let's say your agent isn't learning. Is that because the agent is broken or the environment isn't working as intended?
The only way to know for sure is to observe how your agent operates inside your environment. This is where rendering comes in.
Rendering turns an environment state (observation) into a snapshot (picture) that allows you to observe what the agent is doing and how it's interacting with the environment at a point in time.
Rendering is optional, but it's an invaluable tool for debugging your agents and environments, logging, and recording videos.
Throughout this tutorial, we'll discuss how we can apply this to our Envrax environments using a render() method. Let's get started!
render()¶
By default, JaxEnv provides a placeholder render(state) method that raises a NotImplementedError until you override it on your own environment.
Implementing it isn't a mandatory requirement, but it is highly valuable.
For example, let's expand the BallEnv we created earlier and give it a simple render() method:
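As a rough sketch, assuming a hypothetical BallState with x and y fields normalised to [0, 1] (the real BallEnv state may differ), a render() body could draw the ball as a filled circle on a fixed canvas. It is written here as a free function, render_ball, so that it runs without Envrax; on a real environment the same logic would live inside render(self, state).

```python
from typing import NamedTuple

import numpy as np


class BallState(NamedTuple):
    """Hypothetical stand-in for the BallEnv state; the fields are assumptions."""
    x: float  # horizontal position in [0, 1]
    y: float  # vertical position in [0, 1]


def render_ball(state: BallState, size: int = 64, radius: int = 4) -> np.ndarray:
    """Return an RGB frame of shape (size, size, 3) and dtype uint8."""
    # Pull scalars to the host once; float() also works on 0-d JAX arrays.
    cx = float(state.x) * (size - 1)
    cy = float(state.y) * (size - 1)

    frame = np.zeros((size, size, 3), dtype=np.uint8)  # black background

    # Boolean mask of every pixel within `radius` of the ball centre.
    rows, cols = np.ogrid[:size, :size]
    mask = (cols - cx) ** 2 + (rows - cy) ** 2 <= radius ** 2
    frame[mask] = (255, 80, 80)  # red ball
    return frame


frame = render_ball(BallState(x=0.5, y=0.5))
print(frame.shape, frame.dtype)  # (64, 64, 3) uint8
```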
Use NumPy arrays instead of JAX
For compatibility and usability with other packages such as Pillow, matplotlib, and Jupyter Notebooks, the render() method should return a NumPy (not JAX) array of shape (H, W, 3) and dtype uint8.
We can then run it to produce a single frame:
And save it to PNG or display it with whatever image library you prefer:
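For instance, assuming Pillow is installed (any image library that accepts NumPy arrays would work equally well), saving a frame could look like the sketch below. The blank_frame helper is a hypothetical stand-in for env.render(state) so that the snippet runs without Envrax.

```python
import numpy as np


def blank_frame(height: int = 64, width: int = 64) -> np.ndarray:
    """Hypothetical stand-in for env.render(state): a dark grey RGB frame."""
    return np.full((height, width, 3), 30, dtype=np.uint8)


def save_png(frame: np.ndarray, path: str) -> None:
    """Write an (H, W, 3) uint8 frame to disk as a PNG using Pillow."""
    from PIL import Image  # imported lazily so Pillow stays an optional dependency

    Image.fromarray(frame).save(path)
```

Calling save_png(blank_frame(), "frame.png") writes a 64×64 image to disk. In a Jupyter Notebook, evaluating Image.fromarray(frame) as the last expression of a cell displays the frame inline instead.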
Rendering from a VecEnv¶
VecEnv supports rendering by default with its own render() method. It uses the underlying environment's render() method and assumes that it has already been implemented. If not, it simply falls back to the base case and raises a NotImplementedError.
To use it, you need to provide a state and specify the index of the environment that you want to extract from the batched state PyTree:
No need to craft this behaviour by hand!
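For intuition, picking one environment out of a batch amounts to indexing every leaf of the state PyTree along its leading axis. The sketch below illustrates the idea with NumPy and a hypothetical BallState; it is not Envrax's implementation, and with real JAX states jax.tree_util.tree_map(lambda leaf: leaf[index], state) would play the same role.

```python
from typing import NamedTuple

import numpy as np


class BallState(NamedTuple):
    """Hypothetical state; in a batch, every field gains a leading axis."""
    x: np.ndarray
    y: np.ndarray


def select_env(batched: BallState, index: int) -> BallState:
    """Slice every field at `index` to recover a single environment's state."""
    return BallState(*(np.asarray(leaf)[index] for leaf in batched))


batched = BallState(x=np.array([0.1, 0.5, 0.9]), y=np.array([0.2, 0.4, 0.6]))
single = select_env(batched, index=1)
print(float(single.x), float(single.y))  # 0.5 0.4
```

The resulting single-environment state can then be handed to the underlying environment's render() as usual.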
Rendering Through Wrappers¶
Wrappers forward render() to their inner env by default, so you'll never lose rendering capabilities:
This applies even to stateful wrappers like FrameStackObservation, because the wrapper state keeps a forwarded copy of the inner state that gets passed down the chain.
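The forwarding pattern can be sketched in plain Python. Every class below is a hypothetical illustration rather than one of Envrax's wrapper classes, and the (inner_state, extra) layout of the stateful wrapper's state is an assumption made for the example.

```python
import numpy as np


class TinyEnv:
    """Hypothetical inner environment with a working render()."""

    def render(self, state) -> np.ndarray:
        # Fill a tiny frame with the state value so forwarding is easy to verify.
        return np.full((4, 4, 3), int(state), dtype=np.uint8)


class PassThroughWrapper:
    """Hypothetical wrapper: forwards render() to the wrapped env unchanged."""

    def __init__(self, env):
        self.env = env

    def render(self, state) -> np.ndarray:
        return self.env.render(state)


class StatefulWrapper(PassThroughWrapper):
    """Hypothetical stateful wrapper whose state is (inner_state, extra)."""

    def render(self, state) -> np.ndarray:
        inner_state, _extra = state  # unwrap, then pass the inner state down
        return self.env.render(inner_state)


env = StatefulWrapper(PassThroughWrapper(TinyEnv()))
frame = env.render((7, "frame-stack buffer"))
print(frame.shape, int(frame[0, 0, 0]))  # (4, 4, 3) 7
```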
Design Tips¶
Here are some design tips to consider when building your render() methods:
- Keep render fast but not JIT-compiled. render is called at human timescales (once per frame for video, once per log step). Don't JIT it! Keep it in NumPy so that any downstream drawing libraries (Pillow, OpenCV, matplotlib) can use it without JAX restrictions.
- Avoid dynamic shapes. Always return the same (H, W, 3) shape. If you need a HUD or overlay, draw it into the fixed canvas rather than concatenating arrays of varying sizes.
- Pull scalars once. Calling float(state.field) forces a device → host transfer. To maximise performance, grab all the values you need at the top of render and then create the renderable frame.
- Cache expensive static assets. If your render draws a fixed background (e.g. a maze layout or a level tileset), cache it on self in the __init__ method and copy it per frame.
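The tips above can be combined in one small renderer. This is a sketch under stated assumptions: BallState, its fields, and BallRenderer are hypothetical names, and the grid background and progress-bar HUD are invented purely for illustration.

```python
from typing import NamedTuple

import numpy as np


class BallState(NamedTuple):
    """Hypothetical state: a position in [0, 1] plus a step counter."""
    x: float
    y: float
    step: int


class BallRenderer:
    """Sketch of a renderer that caches its background and keeps a fixed shape."""

    def __init__(self, size: int = 64):
        self.size = size
        # Static asset: build the background once and reuse it every frame.
        background = np.full((size, size, 3), 20, dtype=np.uint8)
        background[::8, :] = 60  # horizontal grid lines
        background[:, ::8] = 60  # vertical grid lines
        self._background = background

    def render(self, state: BallState) -> np.ndarray:
        # Pull every scalar to the host once, at the top.
        x, y, step = float(state.x), float(state.y), int(state.step)

        frame = self._background.copy()  # never draw on the cached asset
        col = int(round(x * (self.size - 1)))
        row = int(round(y * (self.size - 1)))
        frame[max(row - 2, 0):row + 3, max(col - 2, 0):col + 3] = (255, 80, 80)

        # HUD drawn inside the fixed canvas: a progress bar along the top row.
        frame[0, :min(step, self.size)] = (80, 255, 80)
        return frame  # always (size, size, 3), uint8
```

Copying the cached background costs one array copy per frame, which is typically far cheaper than rebuilding a maze or tileset on every call.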
Rendering as MP4¶
Once render is implemented, you can wrap your environment in a RecordVideo wrapper.
This automatically captures episode rollouts to MP4 for you:
See the RecordVideo section of the wrappers tutorial for the full set of trigger options.
Common Pitfalls¶
Be wary of the following "gotchas":
- Returning jnp.ndarray. Image libraries expect NumPy, and as part of the API standard, you should always return an np.ndarray from the render() method.
- Wrong dtype. Rendered frames must be uint8, not float32. If you normalised them, multiply them by 255 and cast them to uint8 before returning the frames.
- Trying to render from inside jax.jit. render reads JAX scalars Python-side, which triggers a ConcretizationTypeError under tracing. Always use render() outside the JIT boundary.
- NotImplementedError when wrapped. Oops! You've forgotten to build a custom render() method. Simply override the inherited render() method on your environment class.
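The first two pitfalls can be guarded against with a small conversion helper called at the end of render(). The sketch below is illustrative only: to_frame is a hypothetical name, not part of Envrax, and it assumes that float frames are normalised to [0, 1].

```python
import numpy as np


def to_frame(image) -> np.ndarray:
    """Coerce an image-like array into the (H, W, 3) uint8 NumPy contract."""
    frame = np.asarray(image)  # also brings a JAX array back as a NumPy array
    if frame.ndim != 3 or frame.shape[-1] != 3:
        raise ValueError(f"expected shape (H, W, 3), got {frame.shape}")
    if np.issubdtype(frame.dtype, np.floating):
        # Assumption: float frames are normalised to [0, 1].
        frame = np.round(np.clip(frame, 0.0, 1.0) * 255.0)
    return frame.astype(np.uint8)
```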
Recap¶
And there we have it! That's how to use visual rendering in your environments. To recap:
- render(state) must return a uint8 np.ndarray of shape (H, W, 3)
- render() runs Python-side and should not be used inside JIT so that it can support any CPU drawing library
- Wrappers forward render() by default
- VecEnv.render(state, index=[int]) picks one env from a batch
- Implementing a render() method enables video recording, episode logging, and visual debugging
Essentials Series: Summary¶
Excellent work!
It's official, that was the last of the Essentials tutorials. You now know all the core details needed to use Envrax to your heart's content!
Here's a final recap of what we've covered in this series:
- Environment State: You learned how to model environment data as immutable, JAX-traceable PyTrees by extending EnvState, and how to thread a PRNG key through the episode so randomness stays deterministic across every reset and step call.
- Spaces: You explored how to describe what your agent sees and how it acts using Box, Discrete, and MultiDiscrete as pure metadata contracts that let downstream code shape policy networks and catch shape mismatches before they hit your training loop.
- Environment Configuration: You learned the static vs. traceable split: EnvConfig for one-time settings declared at construction, EnvState for everything that changes through an episode. Keeping them separate is what stops JIT from silently re-compiling on you.
- Your First Environment: You built an end-to-end JaxEnv from scratch by pinning the four type generics, defining your spaces, and implementing the reset/step contract that drives every Envrax environment.
- Vectorising with VecEnv: You learned how to spin up hundreds of parallel environments in a single line via jax.vmap, with auto-reset on done=True handled inside the vmapped body so you never have to branch on episode boundaries yourself.
- Multiple Environments: You've seen how to compose M heterogeneous environments into a single managed fleet using MultiEnv/MultiVecEnv, making multi-task training, meta-learning, and heterogeneous evaluation suites that tiny bit easier.
- Environment Registry: You learned how to expose environments under canonical names through register()/register_suite() so your training, evaluation, and analysis scripts all share one source of truth instead of importing classes directly.
- Make Methods: You learned how to use Envrax's single-line builders (make(), make_vec(), make_multi(), and make_multi_vec()), which handle wrappers, JIT compilation, and resolved configs automatically for you.
- Available Wrappers: You learned the difference between pass-through and stateful wrappers, when to reach for each, and how to chain them together to build classic preprocessing pipelines like the Atari image stack.
- Rendering: And with this tutorial, you learned how to turn any environment state into a viewable RGB frame with render(state) and feed it straight into video recording, training-time visualisations, or quick debugging sessions.
Next Steps¶
So, where to next?
You've now got enough knowledge to use the full extent of the Envrax package. At this point, you should really start building!
If you need extra help with more specialised topics, we highly recommend checking out our Advanced tutorials. Alternatively, you can use our API reference to look up specific classes and methods.
Happy building!
- Advanced Recipes: Task-focused walkthroughs for more advanced topics and customization options.
- API Reference: Look up specific classes, methods, and parameter signatures that Envrax uses.