Subsume part of System inside State; EDIT: Or add Options to reset #446

Open
joeryjoery opened this issue Jan 30, 2024 · 11 comments

@joeryjoery

For domain randomization, it is not particularly easy to vmap over different System values, for example gravity or elasticity. Ideally you could do this in env.reset, but right now this is not possible because self.sys is effectively global state: it is an attribute of the Env instance rather than an argument to reset or step.

Right now my hacky workaround is to mock the Brax environment with my own PyTree-like dataclass, so that I can modify the env.sys values in a functionally pure way inside the reset function.

It would be nice if Brax could expose part of the sys dict/namespace as a pure argument to env.reset and env.step (e.g., as part of the state).
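A minimal plain-JAX sketch of what this request enables: once the physics parameters are an explicit argument instead of `self.sys`, a single `vmap` gives every environment its own gravity. `ToySys` and `toy_reset` are hypothetical stand-ins for Brax's `System` and `env.reset`, not Brax code.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySys(NamedTuple):
    """Hypothetical stand-in for brax.base.System (NamedTuples are JAX pytrees)."""
    gravity: jax.Array
    elasticity: jax.Array


def toy_reset(sys: ToySys, rng: jax.Array) -> dict:
    """Pure reset: the physics parameters are an argument, not `self.sys`."""
    height = 1.0 + 0.1 * jax.random.uniform(rng)
    # Toy "observation": the free-fall time from the initial height to the floor.
    fall_time = jnp.sqrt(2.0 * height / sys.gravity)
    return {"height": height, "fall_time": fall_time, "sys": sys}


batch = 4
sys_batch = ToySys(
    gravity=jnp.array([1.6, 3.7, 9.8, 24.8]),  # one gravity value per env
    elasticity=jnp.full((batch,), 0.5),
)
keys = jax.random.split(jax.random.PRNGKey(0), batch)

# Both arguments carry a leading batch axis, so a plain vmap is enough.
states = jax.jit(jax.vmap(toy_reset))(sys_batch, keys)
print(states["fall_time"])  # larger gravity -> shorter fall time
```

Because `sys` is returned with the state, a later step function could read it back without touching any attribute on the env.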

@joeryjoery
Author

Wanted to add an example of another workaround: https://github.com/automl/CARL.

In this library for meta-RL, instead of batching environments on the GPU, which Brax should support, the CARL-brax environments create VectorizedWrappers from Gymnasium to run multiple System variations simultaneously, which kind of defeats the purpose of GPU parallelization.

@btaba
Collaborator

btaba commented Jan 31, 2024

Hi @joeryjoery, I believe we considered passing around sys as part of the env state, but IIRC we managed to squeeze out better performance using the current implementation.

class DomainRandomizationVmapWrapper(Wrapper):

Feel free to implement a version of the base env class and wrapper which passes the sys in a functional way (e.g. as part of the state.info). If you manage to get the same training performance out of it, please send it our way!
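A minimal sketch of the state.info variant, under the assumption that the env simply reads `state.info['sys']` in `step` instead of `self.sys`. `FallingBallEnv`, `ToySys` and `ToyState` are hypothetical toys (a NamedTuple in place of Brax's `State`), and the sketch says nothing about the training-throughput question raised above.

```python
from typing import Any, NamedTuple, Optional

import jax
import jax.numpy as jnp


class ToySys(NamedTuple):
    """Hypothetical stand-in for brax.base.System."""
    gravity: jax.Array


class ToyState(NamedTuple):
    """Hypothetical stand-in for brax.envs.State; `info` is a free-form dict."""
    height: jax.Array
    velocity: jax.Array
    info: dict[str, Any]


class FallingBallEnv:
    """Toy env that reads its physics from state.info['sys'] instead of self.sys."""

    def __init__(self, dt: float = 0.01):
        self.dt = dt
        self.default_sys = ToySys(gravity=jnp.array(9.81))

    def reset(self, rng: jax.Array, sys: Optional[ToySys] = None) -> ToyState:
        sys = self.default_sys if sys is None else sys
        height = 1.0 + 0.1 * jax.random.uniform(rng)
        return ToyState(height, jnp.zeros(()), info={"sys": sys})

    def step(self, state: ToyState) -> ToyState:
        sys = state.info["sys"]  # per-env parameters carried through the state
        velocity = state.velocity - sys.gravity * self.dt
        height = state.height + velocity * self.dt
        return state._replace(height=height, velocity=velocity)


env = FallingBallEnv()
keys = jax.random.split(jax.random.PRNGKey(1), 3)
sys_batch = ToySys(gravity=jnp.array([1.6, 9.81, 24.8]))

state = jax.vmap(env.reset)(keys, sys_batch)
step = jax.jit(jax.vmap(env.step))
for _ in range(10):
    state = step(state)
print(state.velocity)  # -(10 * dt) * gravity for each env
```

The cost of this design is that the varied parameters now travel with every state, which is exactly what a training-loop benchmark would need to measure.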

@joeryjoery
Author

Hey, thanks for the reply. A big obstacle right now in trying to implement something like this is that the pipeline.init and pipeline.step functions are quite rigid: they only receive self, q, qd and _debug as arguments.

So I'm trying to work around this with dependency injection for self, converting it into a PyTree so that I can apply JAX transforms to pipeline.init, etc. But mocking this object is causing quite a few problems, since I keep running into unforeseen dependencies. For this reason I don't think this approach is a good one; it will almost certainly lead to problems later on.

@btaba Could the pipeline.init and pipeline.step functions perhaps be extended to receive an optional options dictionary? This would require the API to propagate options through reset and step from the wrappers down to the base env (similar to the Gymnasium implementation).

In principle, if options is None, performance stays the same, and if I do want to provide options, I can wrap the pipeline module with a custom function that modifies self.sys.
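The options-forwarding idea can be sketched without Brax: a wrapper that only forwards `options`, and a base env that falls back to its own parameters when `options` is None. All class names and the `'gravity'` key are hypothetical; this is not Brax's API.

```python
from typing import Any, Optional

import jax
import jax.numpy as jnp

Options = Optional[dict[str, Any]]


class ToyBaseEnv:
    """Hypothetical base env whose reset accepts an optional `options` dict."""

    def __init__(self, gravity: float = 9.81):
        self.gravity = jnp.asarray(gravity)

    def pipeline_init(self, q: jax.Array, *, options: Options = None) -> dict:
        gravity = self.gravity
        if options is not None and "gravity" in options:
            gravity = options["gravity"]  # per-call override; self is not mutated
        return {"q": q, "gravity": gravity}

    def reset(self, rng: jax.Array, *, options: Options = None) -> dict:
        q = jax.random.normal(rng, (2,))
        return self.pipeline_init(q, options=options)  # propagate options


class ToyWrapper:
    """Hypothetical wrapper: it forwards `options` without interpreting them."""

    def __init__(self, env):
        self.env = env

    def reset(self, rng: jax.Array, *, options: Options = None) -> dict:
        return self.env.reset(rng, options=options)


env = ToyWrapper(ToyBaseEnv())
rng = jax.random.PRNGKey(0)

default_state = env.reset(rng)  # options=None: the unchanged code path
moon_state = env.reset(rng, options={"gravity": jnp.array(1.62)})

# Batched: a different gravity per env, passed through the wrapper chain.
gravities = jnp.array([1.62, 3.71, 9.81])
keys = jax.random.split(rng, 3)
batched = jax.vmap(lambda k, g: env.reset(k, options={"gravity": g}))(keys, gravities)
print(batched["gravity"])
```

Since the `None` check happens in Python at trace time, the default path compiles to the same computation as before.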

What do you think?

@btaba
Collaborator

btaba commented Feb 1, 2024

Hi @joeryjoery,

I'm not quite following why you want to add extra args to pipeline.init and pipeline.step. Does something like this not work: jax.vmap(pipeline.init, in_axes=[custom_in_axes, None, None])(sys, q, qd) ?
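This suggestion works because `in_axes` may be a pytree prefix of the arguments: `custom_in_axes` can mirror the structure of `sys`, with `0` on the fields that carry a batch axis and `None` on the shared ones. A toy sketch, where `ToySys` and `toy_init` are hypothetical stand-ins for `System` and `pipeline.init`:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySys(NamedTuple):
    """Hypothetical stand-in for brax.base.System."""
    gravity: jax.Array  # varied per env -> batched
    masses: jax.Array   # shared across envs -> not batched, shape (3,)


def toy_init(sys: ToySys, q: jax.Array, qd: jax.Array) -> jax.Array:
    """Stand-in for pipeline.init: total energy of three point masses."""
    potential = jnp.sum(sys.masses * sys.gravity * q)
    kinetic = 0.5 * jnp.sum(sys.masses * qd**2)
    return potential + kinetic


shared = ToySys(gravity=jnp.array(9.81), masses=jnp.array([1.0, 2.0, 3.0]))
gravities = jnp.array([1.62, 3.71, 9.81, 24.79])
sys = shared._replace(gravity=gravities)  # only `gravity` gains a batch axis

# Same structure as `sys`: 0 = mapped over, None = broadcast to every env.
custom_in_axes = ToySys(gravity=0, masses=None)
q = jnp.ones(3)
qd = jnp.zeros(3)

energies = jax.vmap(toy_init, in_axes=(custom_in_axes, None, None))(sys, q, qd)
print(energies)  # 6 * gravity per env, since sum(masses * q) == 6 and qd == 0
```

With a Brax `System`, such a `custom_in_axes` could be built by mapping every leaf of `sys` to `None` and then setting the randomized fields to `0`.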

@joeryjoery
Author

Hey, yes, this works, but it isn't the problem.

The issue is that I have no easy way to propagate sys variations to that point (at least not in a way that is jittable). For example, the Ant environment's reset looks something like this:

def reset(self, rng: jax.Array) -> State:
  """Resets the environment to an initial state."""
  rng, rng1, rng2 = jax.random.split(rng, 3)
  
  ...
  
  pipeline_state = self.pipeline_init(q, qd)
  obs = self._get_obs(pipeline_state)
  ...

Now suppose I want to wrap Ant: I do not have direct access to the self.pipeline_init call, so I cannot modularly apply jax.vmap(pipeline.init, ...) from a wrapper.

A way to solve this is to allow options, for example,

def reset(self, rng: jax.Array, *, options: dict | None = None) -> State:
  """Resets the environment to an initial state."""
  rng, rng1, rng2 = jax.random.split(rng, 3)
  
  ...
  
  pipeline_state = self.pipeline_init(q, qd, options=options)  # Pass along here
  obs = self._get_obs(pipeline_state)
  ...

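In this way, I can replace env.pipeline_init with a function like:

import functools

def my_wrapped_init(self, q, qd, *, options: dict | None = None):
  if options is None:
    # Unchanged default path: no extra cost when no options are given.
    return self._pipeline.init(self.sys, q, qd, self._debug)

  # some_sampling_function is a placeholder returning a dict of batched fields.
  variations = some_sampling_function(options)
  sys = self.sys.replace(**variations)
  # Batch only the varied fields; every other field of sys is shared (None).
  sys_axes = jax.tree_util.tree_map(lambda _: None, self.sys).replace(
      **{k: 0 for k in variations})
  return jax.vmap(self._pipeline.init, in_axes=(sys_axes, None, None, None))(
      sys, q, qd, self._debug)

# Bind explicitly so that `self` refers to my_env (a plain function attribute is not bound).
my_env.pipeline_init = functools.partial(my_wrapped_init, my_env)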

@joeryjoery joeryjoery changed the title Subsume part of System inside State Subsume part of System inside State; EDIT: Or add Options to reset Feb 1, 2024
@btaba
Collaborator

btaba commented Feb 2, 2024

Comments and questions on the proposed changes:

[1] Subsume part of System inside State: You can do this already by adding System to state.info, and re-writing your env code to use state.info['sys'] instead of self.sys. How performant is that implementation for RL workloads? Then we can discuss a potential API change.

[2] Add Options to reset: Strong preference here to add your logic to a wrapper, and to split out the vmap case from the non-vmap case into distinct wrappers. It looks like your proposal is similar to the DomainRandomizationVmapWrapper except you want to do the sys.replace at pipeline.init/pipeline.step time? Does this mean that the env.reset and env.step logic won't be accessing the same randomized version of sys?

@joeryjoery
Author

Hey thanks a lot for continuing the discussion.

TL;DR: I was overthinking this, and the easy solution is indeed a slight modification of DomainRandomizationVmapWrapper.

  1. The problem with the current DomainRandomizationVmapWrapper is that the randomization is done in __init__, not in reset. If I want to resample variations at every call to reset, I instead have to re-instantiate the class, which means recompiling reset and step, which is costly.

  2. What I did instead is make randomization_fn depend on a random key and call it inside reset. The sampled variations are then substituted into System and stored in State.info. They only contain the varied fields, so we don't pass around redundant data.

In my implementation I also do not include vmap, as I think it is much easier to just vmap over the DomainRandomization wrapper. I have not tested performance, but the code is much more readable.
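Point 1 can be checked directly: when the sampler runs inside a jitted reset and takes the key as an input, every call draws fresh values from a single compilation. The sketch below counts traces with a Python side effect; `randomization_fn` here is a hypothetical toy sampler, not Brax's.

```python
import jax

trace_count = 0  # incremented only when JAX (re)traces `reset`


def randomization_fn(key: jax.Array) -> dict:
    """Hypothetical sampler: returns only the varied fields."""
    return {"viscosity": jax.random.uniform(key, (), minval=0.0, maxval=1.0)}


@jax.jit
def reset(rng: jax.Array) -> dict:
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on every call
    key_reset, key_var = jax.random.split(rng)
    variations = randomization_fn(key_var)
    obs = jax.random.normal(key_reset, (3,))
    return {"obs": obs, "info": {"sys_var": variations}}


samples = [float(reset(jax.random.PRNGKey(i))["info"]["sys_var"]["viscosity"])
           for i in range(5)]
print(samples)      # five different viscosities
print(trace_count)  # 1: one compilation serves every new sample
```

Sampling in `__init__` instead bakes the values into the traced closure, so each new sample needs a new jitted function.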

This is what I propose:

from typing import Any, Callable

import jax
import brax.envs
from brax.base import System
from brax.envs.base import Env, State


class DomainRandomization(brax.envs.Wrapper):
    """Wrapper for Procedural Domain Randomization."""
    
    def __init__(
        self, 
        env: Env, 
        randomization_fn: Callable[[System, jax.Array], dict[str, Any]]
    ):
        super().__init__(env)
        self.randomization_fn = randomization_fn
        # Keep the unrandomized System so every reset/step starts from it.
        self.default_sys = env.unwrapped.sys

    def env_fn(self, sys: System) -> Env:
        # NOTE: mutates the wrapped env in place; under jit this leaves a
        # traced System stored on env.unwrapped.sys after tracing.
        env = self.env
        env.unwrapped.sys = sys
        return env
    
    def reset(self, rng: jax.Array) -> State:
        key_reset, key_var = jax.random.split(rng)
        
        # randomization_fn returns only the varied fields, as a dict.
        variations = self.randomization_fn(self.default_sys, key_var)

        new_sys = self.default_sys.replace(**variations)
        new_env = self.env_fn(new_sys)
        
        state = new_env.reset(key_reset)
        state = state.replace(info=state.info | {'sys_var': variations})
        
        return state
        
    def step(self, state: State, action: jax.Array) -> State:
        # Rebuild the randomized System from the variations carried in the state.
        variations = state.info['sys_var']
        new_sys = self.default_sys.replace(**variations)

        new_env = self.env_fn(new_sys)
        state = new_env.step(state, action)
        
        state = state.replace(info=state.info | {'sys_var': variations})
        
        return state

Example usage:

from typing import Any

import jax
from brax import envs
from brax.base import System

def viscosity_randomizer(system: System, key: jax.Array) -> dict[str, Any]:
    return {'viscosity': jax.random.uniform(key, system.viscosity.shape)}

env = envs.create(
    env_name='ant',
    episode_length=1000,
    action_repeat=1,
    auto_reset=True,
    batch_size=None,
)

wrap = DomainRandomization(env, viscosity_randomizer)

s0 = jax.jit(wrap.reset)(jax.random.key(0))
s1 = jax.jit(wrap.reset)(jax.random.key(321))

print(s0.info['sys_var'], s1.info['sys_var'])
>> {'viscosity': Array(0.10536897, dtype=float32)} {'viscosity': Array(0.3906865, dtype=float32)}


print(wrap.unwrapped.sys.viscosity)
>> Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
print(wrap.default_sys.viscosity)
>> 0.0

Or composing with the VmapWrapper,

sbatch = jax.jit(brax.envs.wrappers.training.VmapWrapper(wrap).reset)(
    jax.random.split(jax.random.key(0), 5)
)
print(sbatch.info['sys_var'])
>> {'viscosity': Array([0.6306313 , 0.5778805 , 0.64515114, 0.95315635, 0.24741197],      dtype=float32)}

@joeryjoery
Author

It's not really easy to show that this implementation works here, but if you visualize the results using the code shown in the Colab, you can see that it indeed randomizes the System variables per random key.

https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb#scrollTo=4hHuDp53e4VJ

I also haven't tested performance for RL training. But if your goal is to randomize at every reset call, it should be faster than the current DomainRandomizationVmapWrapper, whose randomization_fn runs once in __init__ and therefore requires re-instantiating the wrapper (and recompiling reset and step) for every new set of variations.

@btaba
Collaborator

btaba commented Feb 2, 2024

Hi @joeryjoery

I think we tried a version of this implementation. A few comments:

[1] Can you update your impl to make it work for nested fields in sys? You can probably use tree_replace.
[2] IIRC passing these extra vars in the info was costly for an RL workload. Can you compare the performance of your current version vs. the version at HEAD to see where we're at, and randomize a few more parameters (esp. ones that scale with nv, nq, ngeom)? Maybe try this on humanoid. You'd then potentially be passing (batch_size, ngeom) parameters in state.info.
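A back-of-the-envelope helper for the size of the extra state.info payload in [2]: each randomized field contributes batch_size × (scalars per env) × 4 bytes in float32. The field sizes below are hypothetical placeholders, not Humanoid's actual nv, nq or ngeom.

```python
def info_payload_bytes(batch_size: int, fields: dict[str, int],
                       dtype_bytes: int = 4) -> int:
    """Bytes of randomized parameters carried in state.info for a whole batch.

    `fields` maps a field name to its number of scalars per environment
    (e.g. ngeom for a per-geom parameter); float32 uses 4 bytes per scalar.
    """
    per_env = sum(fields.values())
    return batch_size * per_env * dtype_bytes


# Hypothetical sizes, for illustration only.
print(info_payload_bytes(1, {"viscosity": 1}))                # 4
print(info_payload_bytes(4096, {"per_geom": 20, "per_dof": 27}))  # 770048 (~0.73 MiB)
```

The per-batch figure is small in absolute terms, but it is carried through every environment step of a rollout, which is where the overhead btaba mentions would show up.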

FWIW, the impl at HEAD, despite creating a static batch of sys, is enough for sim2real transfer on a quadruped. You can also do multiple resets in training like here (if you're concerned about the static part):

for _ in range(max(num_resets_per_eval, 1)):
  # optimization
  epoch_key, local_key = jax.random.split(local_key)
  epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
  (training_state, env_state, training_metrics) = (
      training_epoch_with_timing(training_state, env_state, epoch_keys)
  )
  current_step = int(_unpmap(training_state.env_steps))
  key_envs = jax.vmap(
      lambda x, s: jax.random.split(x[0], s),
      in_axes=(0, None))(key_envs, key_envs.shape[1])
  # TODO: move extra reset logic to the AutoResetWrapper.
  env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state

@joeryjoery
Author

Hey!

For 1), I was working on something like this but didn't quite finish today; I'll update it later. What do you mean by tree_replace, is it a private Brax API? I was thinking more along the lines of mocking System with a nested dictionary.

For 2), I don't think there is a way around this: we are passing around more data. If the variations are small (like just viscosity or gravity), I'd imagine the overhead is negligible, but yes, it can grow for something like Humanoid with mass or geom variations. I'd imagine there are some optimizations possible here, though.

I'm not suggesting that the other DomainRandomizationVmapWrapper is wrong, if this works well for sim2real that's amazing.

However, I'm specifically trying to satisfy my research assumptions as closely as I can. These assume a new random environment for every sampled trajectory, which also makes learning a good policy much more difficult. Also, in my experiments I've found that data collection is rarely the bottleneck; the learner is, at least for my very specific use case (PPO with a recurrent network architecture that also does internal matrix inversions).

If I find the time, I'll try to run the default agent with both the current domain randomization and the one I posted.

@btaba
Collaborator

btaba commented Feb 8, 2024

Hi @joeryjoery, tree_replace can be found here:

def tree_replace(
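For [1] (nested fields), the idea behind a dotted-path replace can be sketched with plain frozen dataclasses. `nested_replace` and the toy `ToySys`/`Link`/`Inertia` classes are hypothetical; Brax's own `tree_replace` (linked above) is a method on its dataclasses and may differ in details.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Inertia:
    mass: Any


@dataclasses.dataclass(frozen=True)
class Link:
    inertia: Inertia


@dataclasses.dataclass(frozen=True)
class ToySys:
    """Hypothetical nested stand-in for brax.base.System."""
    link: Link
    viscosity: Any


def _replace_path(obj: Any, keys: list[str], value: Any) -> Any:
    """Rebuild `obj` with the field at `keys` set to `value` (no mutation)."""
    head, *rest = keys
    child = value if not rest else _replace_path(getattr(obj, head), rest, value)
    return dataclasses.replace(obj, **{head: child})


def nested_replace(obj: Any, params: dict[str, Any]) -> Any:
    """Return a copy of `obj` with dotted-path fields (e.g. 'link.inertia.mass') replaced."""
    for path, value in params.items():
        obj = _replace_path(obj, path.split("."), value)
    return obj


sys = ToySys(link=Link(inertia=Inertia(mass=[1.0, 2.0])), viscosity=0.0)
new_sys = nested_replace(sys, {"link.inertia.mass": [1.5, 2.5], "viscosity": 0.1})
print(new_sys.link.inertia.mass, new_sys.viscosity)  # [1.5, 2.5] 0.1
print(sys.link.inertia.mass)                         # [1.0, 2.0] (original untouched)
```

With this shape, the `sys_var` dict stored in `state.info` could use dotted keys, which are ordinary dict keys and therefore valid pytree entries.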

Thanks for the context on [2], I recommend using your own wrapper (for ensuring sampling a new system for every trajectory), looks like you're pretty close to a more general version with the implementation above! Let us know if you have any trouble and please feel free to share any findings (or open a PR)
