
PPO train code refactor for checkpointing and curriculum compatibility #211

Conversation

btnorman

@btnorman btnorman commented Jul 26, 2022

I refactored the PPO train code to increase its compatibility with checkpointing and environment variation, while keeping backwards compatibility.

This splits the train function into multiple functions:

  • make_train_space: a function that creates the training functions (e.g. training_epoch_with_timing, the evaluator, etc.) and wraps them in a returned simple namespace
  • init_training_state: a function that initializes the training state
  • init_env_state: a function that initializes the environment state
  • and run_train: a function that performs a training run when passed a training state, environment state, and train space

This enables the training functions to be run multiple times with jit compilation only occurring once.
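To make the intended call pattern concrete, here is a pure-Python toy sketch. The four function names follow the PR description, but the signatures and bodies are stand-ins of my own, not the actual Brax code:

```python
from types import SimpleNamespace

def make_train_space(step_fn):
    # Stand-in for the PR's make_train_space: build the training
    # functions once and bundle them in a simple namespace.
    def training_epoch(training_state):
        for _ in range(10):  # one "epoch" = 10 steps of the prebuilt step_fn
            training_state = step_fn(training_state)
        return training_state
    return SimpleNamespace(training_epoch=training_epoch)

def init_training_state():
    return {"params": 0.0, "steps": 0}

def init_env_state():
    return {"obs": 0.0}

def run_train(train_space, training_state, env_state, num_epochs=3):
    # Reuses the functions built by make_train_space, so in the real
    # (jitted) code compilation would happen only on the first call.
    for _ in range(num_epochs):
        training_state = train_space.training_epoch(training_state)
    return training_state, env_state

def step(s):  # toy update rule
    return {"params": s["params"] + 0.1, "steps": s["steps"] + 1}

space = make_train_space(step)
ts, es = run_train(space, init_training_state(), init_env_state())
print(ts["steps"])  # 30 steps: 3 epochs x 10 steps each
```

The point of the split is visible in the last three lines: `run_train` can be called again and again with new states while `space` (and, in the real code, its jit-compiled functions) is built only once.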
This also adds:

  • train: a function that works identically to the previous train function, by calling the above described functions
  • checkpoint_train: a simple checkpointing version of train that should be robust to preemption.
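A preemption-safe loop of the kind checkpoint_train describes could be sketched roughly as follows. This is a hypothetical pickle-based illustration of the idea, not the PR's actual implementation; the `train_space.training_epoch` interface is borrowed from the description above:

```python
import os
import pickle
import tempfile
from types import SimpleNamespace

def checkpoint_train(ckpt_path, train_space, num_epochs=5):
    # Resume from the latest checkpoint if one exists; otherwise start fresh.
    if os.path.exists(ckpt_path):
        with open(ckpt_path, "rb") as f:
            epoch, training_state = pickle.load(f)
    else:
        epoch, training_state = 0, {"params": 0.0}
    while epoch < num_epochs:
        training_state = train_space.training_epoch(training_state)
        epoch += 1
        # Write to a temp file and rename atomically, so a preemption
        # mid-write cannot leave a corrupt checkpoint behind.
        tmp_path = ckpt_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump((epoch, training_state), f)
        os.replace(tmp_path, ckpt_path)
    return training_state

# Toy train space: each epoch bumps params by 1.0.
space = SimpleNamespace(training_epoch=lambda s: {"params": s["params"] + 1.0})
path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")
state = checkpoint_train(path, space, num_epochs=5)
# Re-running (as after a preemption) resumes from the file; here all
# 5 epochs already completed, so it is a no-op.
state = checkpoint_train(path, space, num_epochs=5)
print(state["params"])  # 5.0
```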

The general aim is that run_train should enable people to easily build their own checkpointing and environment variation / curriculum generation on top of the Brax PPO training code, without having to modify the Brax PPO code internally.
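The curriculum / environment-variation use case might then look like the following toy sketch. Apart from the function names taken from the description, everything here (the `difficulty` parameter, the update rule) is a hypothetical stand-in:

```python
from types import SimpleNamespace

def init_env_state(difficulty):
    # Hypothetical: environment parameters vary per curriculum stage.
    return {"difficulty": difficulty, "obs": 0.0}

def make_train_space():
    # Built once; the (notionally jit-compiled) epoch function is then
    # reused across curriculum stages without recompilation.
    def training_epoch(training_state, env_state):
        training_state = dict(
            training_state,
            steps=training_state["steps"] + env_state["difficulty"])
        return training_state, env_state
    return SimpleNamespace(training_epoch=training_epoch)

def run_train(space, training_state, env_state, num_epochs=2):
    for _ in range(num_epochs):
        training_state, env_state = space.training_epoch(
            training_state, env_state)
    return training_state

space = make_train_space()
ts = {"steps": 0}
for difficulty in (1, 2, 3):  # curriculum: a harder environment each stage
    ts = run_train(space, ts, init_env_state(difficulty))
print(ts["steps"])  # 12 = 2*1 + 2*2 + 2*3
```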

This is my first pull request, so let me know if I have made any rookie mistakes! And thanks for the great physics engine!

FYI, I have not been able to test in an environment with multiple processors, e.g. a TPU slice.

Ensured that max_devices_per_host is used for all train functions.
@m-orsini
Copy link

Hi btnorman,

I'm very glad to see you're showing interest in Brax!

The idea of the agents directory is to show example implementations of popular algorithms with Brax. We know it doesn't cover all usages, and so it's expected that people will fork those examples to get the algorithm to do what they want.

We expect different users will have different opinions regarding checkpointing, for example; as a result, we prefer to leave it unimplemented.

We also want the interface for all algorithms to be as close as possible, so I don't believe introducing the abstractions you propose for PPO alone can go through.

For those reasons, I think it makes sense this PR stays out of the main branch.

Have fun with Brax!

@btnorman
Author

Hi! Thanks for looking over it.

I want to make sure I have not misrepresented the intention of the proposed contribution!

The contribution comes in two separate parts.

  1. Modifying the structure of the PPO code to more easily allow Brax users to build on top of it, without having to modify the internals of the Brax PPO code.

    To illustrate the value, imagine a user who solely wants to implement checkpointing and is otherwise happy with the out-of-the-box PPO code. At the moment, this user has two options. A) they can copy the Brax code and introduce checkpointing to the internals, or B) they can use a different PPO implementation, either their own or e.g. PyTorch.

    Both of these options, A) and B), come with significant overhead. Copying the Brax code and modifying it involves parsing the logic and structure of the code, which is a lot of work for someone unfamiliar with Jax; it also involves consulting multiple Brax files to understand how the code fits together. Using a different PPO implementation is likewise a lot of work, since there is the overhead of working out how Brax and that implementation fit together. Both options are also error-prone: if one introduces an error or typo, the RL training might subtly break.

    By making the proposed change, we introduce a third option for this user: C) they can take the abstractions provided and build on top of them. This is significantly less work, as it requires minimal understanding of Jax and minimal consultation of other Brax files.

  2. Introducing a very simple example implementation of checkpointing that demonstrates how someone could use the components introduced to easily make their own checkpointing code, building on top of the Brax code, without having to write their own PPO logic.

    This relates to the hypothetical user who wants their own custom checkpointing code. They can copy the example implementation of checkpointing, which only involves file loading, and modify it to suit their checkpointing needs, without any understanding of Jax or the need to consult any other Brax files.

    This second part is only meant to be illustrative. If the changes went ahead then a checkpointing example might be better placed in a notebook instead.

If the aim of making it easier to build on top of existing Brax training code seems valuable, but the proposed implementation seems lacking, perhaps there is something else we can do, and I would be happy to help!

If you decide the change does seem valuable then I would be happy to modify the other algorithms so that they are all consistent.

FYI, to allay a potential concern: using the checkpointing code built on these abstractions adds very little overhead.
