import chex
import jax
import jax.numpy as jnp
from envrax import JaxEnv, EnvState, EnvConfig
from envrax.spaces import Box, Discrete
@chex.dataclass
class BallState(EnvState):
    """State carried between steps of the ball environment.

    Adds the ball's two position coordinates to whatever fields the
    ``EnvState`` base class provides.
    """

    # Horizontal coordinate of the ball.
    ball_x: chex.Array
    # Vertical coordinate of the ball.
    ball_y: chex.Array
@chex.dataclass
class BallConfig(EnvConfig):
    """Configuration values for the ball environment, on top of ``EnvConfig``."""

    # Factor multiplied into the per-step displacement of the ball.
    friction: float = 0.98
    # Factor multiplied into the reward returned by each step.
    reward_scale: float = 1.0
class BallEnv(JaxEnv[Box, Discrete, BallState, BallConfig]):
    """Environment in which a ball is nudged around a bounded 2-D area.

    The observation is the ball position ``[ball_x, ball_y]``. Each of the
    four discrete actions shifts the ball along one axis by a fixed amount
    scaled by ``config.friction``, and the position is clipped to the
    bounds of the observation space.
    """

    @property
    def observation_space(self) -> Box:
        """Return the float32 box of shape ``(2,)`` bounded by 0.0 and 1.0."""
        return Box(low=0.0, high=1.0, shape=(2,), dtype=jnp.float32)

    @property
    def action_space(self) -> Discrete:
        """Return the discrete space holding the four movement actions."""
        return Discrete(n=4)

    def reset(self, rng: chex.PRNGKey):
        """Begin a new episode with the ball at a randomly drawn position.

        Args:
            rng: PRNG key. It is split once into a key stored in the state
                and a key that is split again to draw the two coordinates.

        Returns:
            A tuple ``(obs, state)``: the observation ``[ball_x, ball_y]``
            and the freshly built ``BallState``.
        """
        carry_key, spawn_key = jax.random.split(rng)
        x_key, y_key = jax.random.split(spawn_key)

        start_x = jax.random.uniform(x_key)
        start_y = jax.random.uniform(y_key)

        initial = BallState(
            rng=carry_key,
            step=jnp.int32(0),
            done=jnp.bool_(False),
            ball_x=start_x,
            ball_y=start_y,
        )
        return jnp.array([initial.ball_x, initial.ball_y]), initial

    def step(self, state: BallState, action: chex.Array):
        """Move the ball according to ``action`` and advance the episode.

        Args:
            state: Current environment state.
            action: Index of the move to apply
                (0=left, 1=right, 2=up, 3=down).

        Returns:
            A tuple ``(obs, new_state, reward, done, info)``. ``reward`` is
            ``1.0 * config.reward_scale``; ``done`` becomes true once the
            incremented step counter reaches ``config.max_steps``
            (presumably declared on ``EnvConfig``; it is not defined in
            ``BallConfig``); ``info`` holds the new step counter under
            ``"current_step"``.
        """
        cfg = self.config
        next_key, _ = jax.random.split(state.rng)

        # One (dx, dy) row per action, in the order left, right, up, down.
        moves = jnp.array(
            [
                [-0.01, 0.0],
                [0.01, 0.0],
                [0.0, -0.01],
                [0.0, 0.01],
            ]
        )
        shift = moves[action] * cfg.friction

        # Keep the ball inside the bounds of the observation space.
        bounds = self.observation_space
        moved_x = jnp.clip(state.ball_x + shift[..., 0], bounds.low, bounds.high)
        moved_y = jnp.clip(state.ball_y + shift[..., 1], bounds.low, bounds.high)

        step_count = state.step + 1
        finished = step_count >= cfg.max_steps

        next_state = state.replace(
            rng=next_key,
            step=step_count,
            ball_x=moved_x,
            ball_y=moved_y,
            done=finished,
        )

        observation = jnp.array([next_state.ball_x, next_state.ball_y])
        reward = jnp.float32(1.0) * cfg.reward_scale
        info = {"current_step": step_count}
        return observation, next_state, reward, finished, info
if __name__ == "__main__":
    # Demo rollout: drive the environment with randomly sampled actions.

    # Init the environment
    env = BallEnv()

    # Set its initial state. The master key is never consumed directly;
    # every random operation below gets its own freshly split sub-key.
    key = jax.random.key(42)
    key, reset_key = jax.random.split(key)
    obs, state = env.reset(reset_key)

    # Iterate through 1000 timesteps
    for _ in range(1000):
        # JAX random functions are deterministic in their key, so sampling
        # with an unchanged key would yield the same action on every step.
        key, action_key = jax.random.split(key)
        action = env.action_space.sample(action_key)
        obs, state, reward, done, info = env.step(state, action)

        # If episode has ended, reset to start a new one
        if done:
            key, reset_key = jax.random.split(key)
            obs, state = env.reset(reset_key)