
Increase modularization of training wrappers #422

Open
vyeevani opened this issue Nov 12, 2023 · 4 comments


vyeevani commented Nov 12, 2023

The auto-reset and episode training wrappers leak information into each other.

This becomes a bigger problem when people want to stack their own wrappers. For example, I'd like to write a meta-episode wrapper that aggregates multiple episodes into a single meta episode for use in a meta-RL setting. The wrapper needs to track the number of completed episodes so that it can perform a meta-episode reset when the count reaches some watermark. However, auto reset breaks this, because it does not reset the meta-episode wrapper underneath it.
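To make the use case concrete, here is a hypothetical sketch of such a meta-episode wrapper in plain JAX. `State`, `OneStepEnv`, and `MetaEpisodeWrapper` are illustrative stand-ins, not Brax classes, and the environment is unbatched. The wrapper restarts inner episodes itself and keeps its episode counter in `info`, so whatever sits above it has to reset that counter when the meta episode ends. An auto-reset wrapper that only restores the pipeline state and observation would leave `episode_count` at (and then past) the watermark.

```python
from typing import NamedTuple

import jax
import jax.numpy as jp


class State(NamedTuple):
  """Minimal stand-in for a Brax State (hypothetical, for illustration)."""
  pipeline_state: jax.Array
  obs: jax.Array
  reward: jax.Array
  done: jax.Array
  info: dict


class OneStepEnv:
  """Hypothetical toy env: every episode ends after a single step."""

  def reset(self, rng: jax.Array) -> State:
    del rng  # deterministic toy env
    zero = jp.array(0.0)
    return State(zero, zero, zero, jp.array(False), {})

  def step(self, state: State, action: jax.Array) -> State:
    pos = state.pipeline_state + action
    return State(pos, pos, jp.array(1.0), jp.array(True), state.info)


class MetaEpisodeWrapper:
  """Hypothetical: aggregates `num_episodes` inner episodes into one meta episode."""

  def __init__(self, env, num_episodes: int):
    self.env = env
    self.num_episodes = num_episodes

  def reset(self, rng: jax.Array) -> State:
    inner = self.env.reset(rng)
    info = {
        'inner_initial': inner,  # inner episodes restart from here
        'inner_current': inner,
        'episode_count': jp.array(0, dtype=jp.int32),
    }
    return inner._replace(info=info)

  def step(self, state: State, action: jax.Array) -> State:
    initial = state.info['inner_initial']
    nxt = self.env.step(state.info['inner_current'], action)
    # Count finished inner episodes; the meta episode ends at the watermark.
    count = state.info['episode_count'] + nxt.done.astype(jp.int32)
    meta_done = count >= self.num_episodes
    # Finished inner episodes restart from the cached inner initial state.
    current = jax.tree_util.tree_map(
        lambda x, y: jp.where(nxt.done, x, y), initial, nxt)
    info = {'inner_initial': initial, 'inner_current': current,
            'episode_count': count}
    return nxt._replace(done=meta_done, info=info)
```

With `num_episodes=3`, `done` is set on the third step; stepping again without a full `reset` reports `episode_count == 4`, which is the failure mode described above.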

At a high level, I'd like to separate the state of the auto-reset wrapper from the state of the environments it wraps. I propose doing this by caching both the initial and the current state of the wrapped environment in `info`, and having the auto-reset wrapper operate only on those cached states.

  1. In reset, get the base state and store it in info under two separate keys: `initial_base_state` and `current_base_state`.
  2. In reset, take the observation, reward, and done of `initial_base_state` and package them together with the info from step 1.
  3. In step, if the episode is done, return the same aggregated state as in step 2.
  4. In step, if it is not done, step the current base state, update `current_base_state`, and return the evolved state.
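The four steps above can be sketched in plain JAX against a toy environment. `State`, `CounterEnv`, and `CachedAutoReset` are hypothetical stand-ins rather than Brax classes, and the environment is unbatched. One deliberate deviation from step 3: reward and done are taken from the transition just taken rather than from the initial state, so the reward earned on the terminal step is not discarded.

```python
from typing import NamedTuple

import jax
import jax.numpy as jp


class State(NamedTuple):
  """Minimal stand-in for a Brax State (hypothetical, for illustration)."""
  pipeline_state: jax.Array
  obs: jax.Array
  reward: jax.Array
  done: jax.Array
  info: dict


class CounterEnv:
  """Hypothetical toy env: position += action; done once position >= 3."""

  def reset(self, rng: jax.Array) -> State:
    del rng  # deterministic toy env
    zero = jp.array(0.0)
    return State(zero, zero, zero, jp.array(False), {})

  def step(self, state: State, action: jax.Array) -> State:
    pos = state.pipeline_state + action
    return State(pos, pos, jp.array(1.0), pos >= 3.0, state.info)


class CachedAutoReset:
  """Auto-reset that works only on base states cached in info."""

  def __init__(self, env):
    self.env = env

  def reset(self, rng: jax.Array) -> State:
    base = self.env.reset(rng)
    # Step 1: cache the base state under two keys.
    info = dict(base.info, initial_base_state=base, current_base_state=base)
    # Step 2: expose the initial obs/reward/done together with that info.
    return base._replace(info=info)

  def step(self, state: State, action: jax.Array) -> State:
    initial = state.info['initial_base_state']
    nxt = self.env.step(state.info['current_base_state'], action)
    where_done = lambda x, y: jp.where(nxt.done, x, y)
    # Steps 3 and 4: done episodes continue from the cached initial state,
    # the others from the evolved state.
    current = jax.tree_util.tree_map(where_done, initial, nxt)
    info = dict(nxt.info, initial_base_state=initial, current_base_state=current)
    # Reward and done are kept from the transition just taken (deviation from
    # step 3) so the terminal reward is not lost.
    return nxt._replace(
        pipeline_state=where_done(initial.pipeline_state, nxt.pipeline_state),
        obs=where_done(initial.obs, nxt.obs),
        info=info,
    )
```

Because `CachedAutoReset` only reads and writes its two cached entries, anything the wrapped environment keeps in `info` is restored along with the rest of the initial state.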

Note that with this process, you never need to return the pipeline state through the state itself.

vyeevani (author) commented Nov 12, 2023

I haven't tested this, but I'm thinking of something like this:

```python
import jax
from jax import numpy as jp
from brax.envs.base import State, Wrapper


class AutoResetWrapper2(Wrapper):
  """Automatically resets Brax envs that are done.

  The wrapped env's initial and current states are cached in `info`, so this
  wrapper never touches the bookkeeping of the wrappers underneath it.
  """

  def reset(self, rng: jax.Array) -> State:
    base_state = self.env.reset(rng)
    info = base_state.info.copy()
    info.update({
        'initial_base_state': base_state,  # restored whenever an episode ends
        'current_base_state': base_state,  # what the wrapped env steps from
    })
    return base_state.replace(info=info)

  def step(self, state: State, action: jax.Array) -> State:
    initial_base_state = state.info['initial_base_state']
    current_base_state = state.info['current_base_state']
    next_base_state = self.env.step(current_base_state, action)

    done = next_base_state.done

    def where_done(x, y):
      # For batched envs `done` has shape (batch,); broadcast it over the
      # trailing dimensions of each leaf.
      d = done
      if d.shape:
        d = jp.reshape(d, [x.shape[0]] + [1] * (x.ndim - 1))
      return jp.where(d, x, y)

    info = next_base_state.info.copy()
    info.update({
        'initial_base_state': initial_base_state,
        # Finished episodes continue from the cached initial state.
        'current_base_state': jax.tree_util.tree_map(
            where_done, initial_base_state, next_base_state),
    })

    # Reward, done, metrics and info (e.g. the episode wrapper's truncation
    # flag) describe the transition just taken, so they come from
    # next_base_state; only pipeline_state and obs are swapped for the reset
    # values, which keeps the terminal reward from being overwritten.
    return next_base_state.replace(
        pipeline_state=jax.tree_util.tree_map(
            where_done, initial_base_state.pipeline_state,
            next_base_state.pipeline_state),
        obs=jax.tree_util.tree_map(
            where_done, initial_base_state.obs, next_base_state.obs),
        info=info,
    )
```

vyeevani (author) commented:

This lets us pass information from the wrapped environments back up to the clients that expect it (for example, the episode wrapper's truncation flag), while also caching the initial state so that it can be reused on reset.
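A small check of that claim, again in plain JAX with hypothetical stand-ins: `ToyEpisodeWrapper` is loosely modeled on an episode wrapper that records a `truncation` flag in `info` when a time limit is hit, and `CachedAutoReset` follows the caching scheme described in this issue. None of these are Brax classes. On the terminal step the caller sees `truncation == 1`, while the cached current state already has `steps` back at 0, without the auto-reset wrapper knowing that `steps` exists.

```python
from typing import NamedTuple

import jax
import jax.numpy as jp


class State(NamedTuple):
  """Minimal stand-in for a Brax State (hypothetical)."""
  pipeline_state: jax.Array
  obs: jax.Array
  reward: jax.Array
  done: jax.Array
  info: dict


class NeverDoneEnv:
  """Hypothetical toy env that never terminates on its own."""

  def reset(self, rng: jax.Array) -> State:
    del rng
    zero = jp.array(0.0)
    return State(zero, zero, zero, jp.array(False), {})

  def step(self, state: State, action: jax.Array) -> State:
    pos = state.pipeline_state + action
    return State(pos, pos, jp.array(1.0), jp.array(False), state.info)


class ToyEpisodeWrapper:
  """Hypothetical time-limit wrapper that records truncation in info."""

  def __init__(self, env, episode_length: int):
    self.env = env
    self.episode_length = episode_length

  def reset(self, rng: jax.Array) -> State:
    s = self.env.reset(rng)
    info = dict(s.info, steps=jp.array(0, dtype=jp.int32),
                truncation=jp.array(0.0))
    return s._replace(info=info)

  def step(self, state: State, action: jax.Array) -> State:
    s = self.env.step(state, action)
    steps = state.info['steps'] + 1
    truncated = steps >= self.episode_length
    # Truncation means the time limit, not the env itself, ended the episode.
    truncation = jp.where(truncated & ~s.done, 1.0, 0.0)
    info = dict(s.info, steps=steps, truncation=truncation)
    return s._replace(done=s.done | truncated, info=info)


class CachedAutoReset:
  """Auto-reset that works only on base states cached in info."""

  def __init__(self, env):
    self.env = env

  def reset(self, rng: jax.Array) -> State:
    base = self.env.reset(rng)
    return base._replace(
        info=dict(base.info, initial_base_state=base, current_base_state=base))

  def step(self, state: State, action: jax.Array) -> State:
    initial = state.info['initial_base_state']
    nxt = self.env.step(state.info['current_base_state'], action)
    where_done = lambda x, y: jp.where(nxt.done, x, y)
    # The wrapped env continues from the initial state; the caller still
    # sees nxt.info (including truncation) for the step just taken.
    current = jax.tree_util.tree_map(where_done, initial, nxt)
    return nxt._replace(
        pipeline_state=where_done(initial.pipeline_state, nxt.pipeline_state),
        obs=where_done(initial.obs, nxt.obs),
        info=dict(nxt.info, initial_base_state=initial,
                  current_base_state=current))
```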

vyeevani (author) commented:

I tested the above (not very rigorously), and it seems to work.

btaba (collaborator) commented Dec 8, 2023

Hi @vyeevani

Thanks for the proposal; indeed, there is logic leaking into reset from the episode wrapper.

If you changed `first_pipeline_state` to `first_state` (which would contain the first `State` object) in the implementation at HEAD, would that suffice for your use case? Why do you need to store `current_base_state` in the info as well?
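One possible reading of this suggestion, sketched with hypothetical stand-ins (`State`, `NeverDoneEnv`, `ToyEpisodeWrapper`, and `FirstStateAutoReset` are illustrative, not Brax code): cache only the first `State` and, on done, swap in everything except reward and done. Because the returned state is also the state the next step continues from, the terminal step's `info` (here, `truncation`) gets overwritten by the reset values. A separate `current_base_state` is one way to let the caller see the terminal info while the wrapped env continues from a fresh state; other designs may avoid this differently.

```python
from typing import NamedTuple

import jax
import jax.numpy as jp


class State(NamedTuple):
  """Minimal stand-in for a Brax State (hypothetical)."""
  pipeline_state: jax.Array
  obs: jax.Array
  reward: jax.Array
  done: jax.Array
  info: dict


class NeverDoneEnv:
  """Hypothetical toy env that never terminates on its own."""

  def reset(self, rng: jax.Array) -> State:
    del rng
    zero = jp.array(0.0)
    return State(zero, zero, zero, jp.array(False), {})

  def step(self, state: State, action: jax.Array) -> State:
    pos = state.pipeline_state + action
    return State(pos, pos, jp.array(1.0), jp.array(False), state.info)


class ToyEpisodeWrapper:
  """Hypothetical time-limit wrapper that records truncation in info."""

  def __init__(self, env, episode_length: int):
    self.env = env
    self.episode_length = episode_length

  def reset(self, rng: jax.Array) -> State:
    s = self.env.reset(rng)
    return s._replace(info=dict(s.info, steps=jp.array(0, dtype=jp.int32),
                                truncation=jp.array(0.0)))

  def step(self, state: State, action: jax.Array) -> State:
    s = self.env.step(state, action)
    steps = state.info['steps'] + 1
    truncated = steps >= self.episode_length
    truncation = jp.where(truncated & ~s.done, 1.0, 0.0)
    info = dict(s.info, steps=steps, truncation=truncation)
    return s._replace(done=s.done | truncated, info=info)


class FirstStateAutoReset:
  """Hypothetical variant that caches only the first State."""

  def __init__(self, env):
    self.env = env

  def reset(self, rng: jax.Array) -> State:
    state = self.env.reset(rng)
    return state._replace(info=dict(state.info, first_state=state))

  def step(self, state: State, action: jax.Array) -> State:
    info = dict(state.info)
    first = info.pop('first_state')  # hide the cache from the wrapped env
    nxt = self.env.step(state._replace(info=info), action)
    where_done = lambda x, y: jp.where(nxt.done, x, y)
    # The returned State is also what the next step continues from, so on
    # done everything except reward and done has to come from `first`.
    reset = jax.tree_util.tree_map(where_done, first, nxt)
    return reset._replace(reward=nxt.reward, done=nxt.done,
                          info=dict(reset.info, first_state=first))
```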

Beyond cleaner semantics, my only concern is how this affects performance.
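One rough way to gauge the memory side of this concern is to compare the size of the pytree each design carries through `jit`/`scan`. The helper and the batch shapes below are hypothetical; the sketch only counts array bytes and says nothing about the extra `jnp.where` work per step, which would need to be profiled separately. Under the proposal, `info` holds two extra copies of the wrapped state, compared with one for a cached `first_state`.

```python
import jax
import jax.numpy as jp


def pytree_nbytes(tree) -> int:
  """Total bytes of the array leaves of a pytree (a rough memory proxy)."""
  return sum(leaf.nbytes for leaf in jax.tree_util.tree_leaves(tree))


# Hypothetical batched state: 2048 envs, 100-dim pipeline state, 50-dim obs.
base = {'pipeline_state': jp.zeros((2048, 100)), 'obs': jp.zeros((2048, 50))}

first_state_design = dict(base, first_state=base)
proposal_design = dict(base, initial_base_state=base, current_base_state=base)

for name, tree in [('base', base), ('first_state', first_state_design),
                   ('proposal', proposal_design)]:
  print(f'{name}: {pytree_nbytes(tree) / 2**20:.2f} MiB')
```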
