Stateful Wrappers
Wrappers that carry their own unique state across steps.
envrax.wrappers.frame_stack.FrameStackObservation
Bases: StatefulWrapper[Box, ActSpaceT, FrameStackState[InnerStateT], ConfigT, InnerStateT]
Maintain a sliding window of the last n_stack observations.
Expects the inner environment to produce uint8[H, W] observations.
The stacked observation has shape uint8[H, W, n_stack].
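The sliding window can be pictured as dropping the oldest frame and appending the newest one along the last axis. The sketch below is illustrative only: `push_frame` is a hypothetical helper, and the "newest frame last" ordering is an assumption, not something this page states about the wrapper.

```python
import jax.numpy as jnp


def push_frame(stack, frame):
    """Drop the oldest frame and append `frame` as the last channel.

    stack: uint8[H, W, n_stack], frame: uint8[H, W].
    Assumption: the newest frame lives at index -1.
    """
    return jnp.concatenate([stack[..., 1:], frame[..., None]], axis=-1)


H, W, n_stack = 4, 4, 3
stack = jnp.zeros((H, W, n_stack), dtype=jnp.uint8)

# Push four constant frames; only the last n_stack of them survive.
for value in (1, 2, 3, 4):
    frame = jnp.full((H, W), value, dtype=jnp.uint8)
    stack = push_frame(stack, frame)
```

The shape and dtype of the stack never change, which keeps the update compatible with `jax.jit` and `jax.lax.scan`.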
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Inner environment returning 2-D uint8[H, W] observations | required |
| n_stack | int | Number of frames to stack. Default is … | required |
Source code in envrax/wrappers/frame_stack.py
reset(rng)
Reset the inner environment and initialise the frame stack.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |
Returns:
| Name | Type | Description |
|---|---|---|
| obs | Array | Initial stacked observation |
| state | FrameStackState | Wrapper state containing the inner state and the stack |
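This page does not say how the stack is filled at reset time. One common choice, shown here purely as an assumption, is to repeat the first observation in every slot so the initial stacked observation already has shape `uint8[H, W, n_stack]`. `init_stack` is a hypothetical helper, not part of envrax.

```python
import jax.numpy as jnp


def init_stack(first_obs, n_stack):
    """Build an initial stack by repeating the first frame.

    Assumption: every slot starts as a copy of the first observation.
    Zero-padding the older slots is an equally plausible alternative.
    """
    return jnp.repeat(first_obs[..., None], n_stack, axis=-1)


first_obs = jnp.arange(6, dtype=jnp.uint8).reshape(2, 3)
stack0 = init_stack(first_obs, n_stack=4)
```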
step(state, action)
Step the inner environment and roll the frame stack.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | FrameStackState | Current wrapper state | required |
| action | Array | Action to take in the environment | required |
Returns:
| Name | Type | Description |
|---|---|---|
| obs | Array | Updated stacked observation |
| new_state | FrameStackState | Updated wrapper state |
| reward | Array | Reward from the inner step |
| done | Array | Terminal flag from the inner step |
| info | Dict[str, Any] | Info dict from the inner step |
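Because `step` takes the wrapper state and returns a new one, a rollout can be threaded through `jax.lax.scan`. The sketch below mimics the `(obs, new_state, reward, done, info)` return order with a toy environment. `ToyState`, `StackState`, `toy_step`, and `stacked_step` are hypothetical stand-ins written for this example, not the envrax classes.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

H, W, N_STACK = 2, 2, 3


class ToyState(NamedTuple):  # hypothetical inner-environment state
    t: jnp.ndarray  # step counter


class StackState(NamedTuple):  # hypothetical stand-in for FrameStackState
    inner: ToyState
    stack: jnp.ndarray  # uint8[H, W, N_STACK]


def toy_step(state, action):
    """Toy inner step: every pixel of the frame equals the step counter."""
    t = state.t + 1
    obs = jnp.broadcast_to(t.astype(jnp.uint8), (H, W))
    return obs, ToyState(t=t), jnp.float32(1.0), t >= 5, {}


def stacked_step(state, action):
    """Step the inner env, then roll the stack (newest frame last)."""
    obs, inner, reward, done, info = toy_step(state.inner, action)
    stack = jnp.concatenate([state.stack[..., 1:], obs[..., None]], axis=-1)
    return stack, StackState(inner=inner, stack=stack), reward, done, info


def scan_body(state, action):
    # The carry is the wrapper state; per-step outputs are collected by scan.
    _, new_state, reward, done, _ = stacked_step(state, action)
    return new_state, (reward, done)


state0 = StackState(
    inner=ToyState(t=jnp.zeros((), jnp.int32)),
    stack=jnp.zeros((H, W, N_STACK), jnp.uint8),
)
actions = jnp.zeros((4,), jnp.int32)
final_state, (rewards, dones) = jax.lax.scan(scan_body, state0, actions)
```

Keeping the inner state and the stack together in one state object is what lets the whole rollout stay a pure function of `state0` and `actions`.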
envrax.wrappers.record_episode_statistics.RecordEpisodeStatistics
Bases: StatefulWrapper[ObsSpaceT, ActSpaceT, EpisodeStatisticsState[InnerStateT], ConfigT, InnerStateT]
Records episode return and length.
Accumulates reward and step count in EpisodeStatisticsState.
Episode statistics are written to info["episode"] on every step() call; values are non-zero only when done=True.
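The accumulate-then-report behaviour can be sketched with `jnp.where`, which keeps the output structure identical on every step. Everything named here is an assumption made for illustration: the `EpisodeStats` fields, the `"r"` and `"l"` keys, and the choice to zero the accumulators once an episode ends are not taken from the envrax source.

```python
from typing import NamedTuple

import jax.numpy as jnp


class EpisodeStats(NamedTuple):  # hypothetical stand-in for EpisodeStatisticsState
    episode_return: jnp.ndarray  # float32 scalar
    episode_length: jnp.ndarray  # int32 scalar


def update_stats(stats, reward, done):
    """Accumulate one step and build the reported statistics."""
    ret = stats.episode_return + reward
    length = stats.episode_length + 1
    # Reported values are the totals when done, zero otherwise.
    episode_info = {
        "r": jnp.where(done, ret, 0.0),  # hypothetical key name
        "l": jnp.where(done, length, 0),  # hypothetical key name
    }
    # Assumption: accumulators restart from zero after a finished episode.
    new_stats = EpisodeStats(
        episode_return=jnp.where(done, 0.0, ret),
        episode_length=jnp.where(done, 0, length),
    )
    return new_stats, episode_info


stats = EpisodeStats(jnp.zeros((), jnp.float32), jnp.zeros((), jnp.int32))
rewards = jnp.array([1.0, 2.0, 0.5], dtype=jnp.float32)
dones = jnp.array([False, False, True])

reported = []
for r, d in zip(rewards, dones):
    stats, episode_info = update_stats(stats, r, d)
    reported.append(float(episode_info["r"]))
```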
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | JaxEnv | Environment to wrap. | required |
Source code in envrax/wrappers/record_episode_statistics.py
reset(rng)
Reset the environment and episode accumulators.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX PRNG key | required |
Returns:
| Name | Type | Description |
|---|---|---|
| obs | Array | Initial observation |
| state | EpisodeStatisticsState | Initial state with zeroed accumulators |
step(state, action)
Step the environment and update episode accumulators.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | EpisodeStatisticsState | Current state | required |
| action | Array | Action to take in the environment | required |
Returns:
| Name | Type | Description |
|---|---|---|
| obs | Array | Next observation |
| new_state | EpisodeStatisticsState | Updated state |
| reward | Array | Step reward |
| done | Array | Episode terminal flag |
| info | Dict[str, Any] | Environment metadata extended with the "episode" entry |
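Since the statistics are present on every step, downstream logging typically masks them with `done` rather than filtering out zeros, because an episode can legitimately finish with a return of zero. The sketch below assumes a hypothetical log of per-step `done` flags and reported returns; `mean_finished_return` is not an envrax function.

```python
import jax.numpy as jnp

# Hypothetical logged rollout: per-step done flags and the reported returns.
dones = jnp.array([False, False, True, False, True])
reported_returns = jnp.array([0.0, 0.0, 3.5, 0.0, 1.5])


def mean_finished_return(returns, dones):
    """Average the reported return over steps where an episode ended."""
    n_done = jnp.sum(dones)
    total = jnp.sum(jnp.where(dones, returns, 0.0))
    # Guard against division by zero when no episode finished.
    return total / jnp.maximum(n_done, 1)


mean_ret = mean_finished_return(reported_returns, dones)
no_episode = mean_finished_return(reported_returns, jnp.zeros(5, dtype=bool))
```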