[BUG] Incompatible with gymnasium's NormalizeObservation wrapper due to missing "num_envs", "is_vector_env" and "single_observation_space" attributes
#258 · Open · 3 tasks done
CloudyDory opened this issue on Apr 6, 2023 · 3 comments
It seems that envpool's vectorized environment is not compatible with gymnasium's NormalizeObservation wrapper due to missing "num_envs", "is_vector_env" and "single_observation_space" attributes in the environments returned by envpool.
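A quick way to check which of these attributes an environment object lacks is a `hasattr` test. The helper `missing_vector_attrs` and the `DummyBatchedEnv` class below are hypothetical and exist only for illustration:

```python
# Attribute names that gymnasium's NormalizeObservation looks up on the wrapped env.
VECTOR_ENV_ATTRS = ("num_envs", "is_vector_env", "single_observation_space")


def missing_vector_attrs(env: object) -> list:
    """Return the vector-env attribute names that `env` does not define (hypothetical helper)."""
    return [name for name in VECTOR_ENV_ATTRS if not hasattr(env, name)]


class DummyBatchedEnv:
    """Hypothetical stand-in for a batched env that only exposes observation_space."""

    observation_space = None  # placeholder; only the attribute's presence matters here


print(missing_vector_attrs(DummyBatchedEnv()))
# → ['num_envs', 'is_vector_env', 'single_observation_space']
```

Passing an actual envpool environment (for example one created with `envpool.make(..., env_type="gymnasium")`, assuming that entry point in the installed version) would list whichever of the three attributes it does not define.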
Here is the code for gymnasium's normalization wrapper:
```python
import gymnasium as gym
import numpy as np
# RunningMeanStd is defined next to this wrapper in gymnasium.wrappers.normalize
# (the module layout of the gymnasium releases current when this issue was filed).
from gymnasium.wrappers.normalize import RunningMeanStd


class NormalizeObservation(gym.Wrapper):
    """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.

    Note:
        The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
        newly instantiated or the policy was changed recently.
    """

    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
        """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.

        Args:
            env (Env): The environment to apply the wrapper
            epsilon: A stability parameter that is used when scaling the observations.
        """
        super().__init__(env)
        # Both lookups silently fall back to single-env defaults if the attribute is missing.
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        if self.is_vector_env:
            self.obs_rms = RunningMeanStd(shape=self.single_observation_space.shape)
        else:
            self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
        self.epsilon = epsilon

    def step(self, action):
        """Steps through the environment and normalizes the observation."""
        obs, rews, terminateds, truncateds, infos = self.env.step(action)
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:
            # Single-env path: add a batch axis of size 1, then strip it again.
            obs = self.normalize(np.array([obs]))[0]
        return obs, rews, terminateds, truncateds, infos

    def reset(self, **kwargs):
        """Resets the environment and normalizes the observation."""
        obs, info = self.env.reset(**kwargs)
        if self.is_vector_env:
            return self.normalize(obs), info
        else:
            return self.normalize(np.array([obs]))[0], info

    def normalize(self, obs):
        """Normalises the observation using the running mean and variance of the observations."""
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
```
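The `__init__()` method needs the three attributes above to work correctly, but the environment object returned by envpool does not define them. `__init__()` therefore falls back to the defaults (`num_envs=1`, `is_vector_env=False`), which are wrong for a batched environment: `step()` and `reset()` then take the single-env branch and wrap the whole batch in `np.array([obs])`, treating it as one observation.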
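The effect of that fallback can be reproduced with plain NumPy. This sketch uses a made-up batch of 4 environments with 3-dimensional observations, and it assumes, as gymnasium's `RunningMeanStd.update` does in the versions we are aware of, that statistics are reduced over axis 0 of the array passed to `normalize()`. `NoVectorAttrs` is a hypothetical stand-in for an env without the attributes:

```python
import numpy as np

num_envs, obs_dim = 4, 3
# Hypothetical batch, shaped like what a vectorized env returns from step().
batch = np.arange(num_envs * obs_dim, dtype=np.float64).reshape(num_envs, obs_dim)


class NoVectorAttrs:
    """Hypothetical env object that lacks num_envs / is_vector_env."""


env = NoVectorAttrs()
# Same lookups as NormalizeObservation.__init__: both hit their defaults.
num_envs_seen = getattr(env, "num_envs", 1)
is_vector_env = getattr(env, "is_vector_env", False)

# Single-env branch: the whole batch becomes ONE sample with a leading axis of 1.
wrapped = np.array([batch])
per_env_mean = wrapped.mean(axis=0)  # one mean per environment, shape (num_envs, obs_dim)

# Vector branch: each row is a sample, so statistics are pooled across envs.
pooled_mean = batch.mean(axis=0)  # shape (obs_dim,)

print(wrapped.shape)       # → (1, 4, 3)
print(per_env_mean.shape)  # → (4, 3)
print(pooled_mean.shape)   # → (3,)
```

In this sketch, the single-env branch yields statistics indexed per environment (each updated with one sample per step) instead of the per-coordinate statistics pooled over the batch that the vector branch computes.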
To Reproduce
Expected behavior
Actual behavior
Screenshots
No screenshots.
System info
Describe the characteristic of your environment:
Additional context
No context.
Reason and Possible fixes
If you know or suspect the reason for this bug, paste the code lines and suggest modifications.
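One possible user-side workaround, not part of the original report and not tested against envpool, is to attach the three attributes to the environment object before wrapping it. The helper `add_vector_attrs` is hypothetical. It assumes the env object accepts new instance attributes and that its own `observation_space` already describes a single environment; if that is not the case, pass `single_observation_space` explicitly:

```python
def add_vector_attrs(env, num_envs, single_observation_space=None):
    """Attach the attributes NormalizeObservation looks up (hypothetical helper)."""
    env.num_envs = num_envs
    env.is_vector_env = True
    # Assumption: without an explicit argument, env.observation_space is the per-env space.
    env.single_observation_space = (
        single_observation_space
        if single_observation_space is not None
        else env.observation_space
    )
    return env


class DummyBatchedEnv:
    """Hypothetical stand-in for an envpool environment."""

    observation_space = "single-env space"  # placeholder value


env = add_vector_attrs(DummyBatchedEnv(), num_envs=4)
print(env.num_envs, env.is_vector_env, env.single_observation_space)
# → 4 True single-env space

# Hypothetical real usage (not run here):
# env = add_vector_attrs(envpool.make("CartPole-v1", env_type="gymnasium", num_envs=4), num_envs=4)
# env = gymnasium.wrappers.NormalizeObservation(env)
```

With `is_vector_env=True`, the wrapper takes its vector branch and sizes `RunningMeanStd` from `single_observation_space`. Defining these attributes inside envpool itself would avoid the need for such patching.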
Checklist