Home
Envrax, a Gymnasium-style API standard for Reinforcement Learning environment creation in JAX.
Envrax is a lightweight, open-source, JAX-native Reinforcement Learning (RL) environment API standard for single-agent environments, equivalent to the Gymnasium package. It includes base classes, spaces, wrappers, and a shared registry for building and using RL environments with ease.
All environment logic follows a stateless functional design that builds on top of the JAX and Chex packages to benefit from JAX accelerator efficiency.
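To make "stateless functional design" concrete, here is a minimal sketch in plain JAX. It is not Envrax's actual API: `CounterState`, `reset`, and `step` are hypothetical names invented for illustration, and a `NamedTuple` stands in for whatever state container the library uses. The idea it shows is that all environment data lives in an explicit state value that is passed in and returned, so the functions stay pure.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CounterState(NamedTuple):
    """Hypothetical environment state: all data lives here, not on an object."""
    position: jax.Array    # scalar int32
    step_count: jax.Array  # scalar int32


def reset(key: jax.Array) -> CounterState:
    # Pure function: the same PRNG key always yields the same initial state.
    position = jax.random.randint(key, (), -2, 3)
    return CounterState(position=position, step_count=jnp.zeros((), dtype=jnp.int32))


def step(state: CounterState, action: jax.Array):
    # Pure function: returns a new state instead of mutating the old one.
    # Action 1 moves right; any other action moves left.
    position = state.position + jnp.where(action == 1, 1, -1)
    step_count = state.step_count + 1
    reward = -jnp.abs(position).astype(jnp.float32)
    done = step_count >= 10
    return CounterState(position, step_count), reward, done


# Because reset and step are pure, they compose with jax.jit and jax.vmap,
# so many environment copies can run in one batched call on the accelerator.
batched_reset = jax.jit(jax.vmap(reset))
batched_step = jax.jit(jax.vmap(step))

num_envs = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
states = batched_reset(keys)
actions = jnp.ones((num_envs,), dtype=jnp.int32)
states, rewards, dones = batched_step(states, actions)
```

Since nothing is mutated in place, the same `step` function serves one environment or thousands; only the `vmap` wrapper changes.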
Why Envrax?
One of the downsides of RL research is poor sample efficiency. The environment often becomes the main bottleneck for model training because it is built around, and restricted to, CPU execution.
For example, the Atari suite is CPU-constrained and, in our experience, increasing the number of environments running in parallel drastically increases the wall-clock time of a single training step. Gradient computations on a GPU could take ~30 seconds, but sample retrieval takes over 2 minutes (at least four times as long) because of the CPU bottleneck, and that's with efficiency tricks!
This raised a much deeper question:
What if we could eliminate the CPU bottleneck by loading the environment onto the same accelerator as the model?
Packages like Brax and Gymnax have shown the incredible benefits of JAX-based environments. However, each follows its own approach, and there is no unified API standard across them. Gymnasium has always been a personal favourite of mine because of its API simplicity, but there is no JAX equivalent. Thus, Envrax was born.
- Getting Started: What are you waiting for?!
- Open Source, MIT: Envrax is licensed under the MIT License.
