Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prioritized experience replay #1622

Open
wants to merge 25 commits into
base: master
Choose a base branch
from

Conversation

AlexPasqua
Copy link
Contributor

@AlexPasqua AlexPasqua commented Jul 23, 2023

Description

Implementation of prioritized replay buffer for DQN.
Closes #1242

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

In accordance with #1242

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have opened an associated PR on the SB3-Contrib repository (if necessary)
  • I have opened an associated PR on the RL-Zoo3 repository (if necessary)
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

@AlexPasqua
Copy link
Contributor Author

@araffin could you (or anyone) please have a look at the 2 pytype errors? I don't quite understand how to fix them

@araffin araffin added the Maintainers on vacation Maintainers are on vacation so they can recharge their batteries, we will be back soon ;) label Aug 10, 2023
@araffin araffin removed the Maintainers on vacation Maintainers are on vacation so they can recharge their batteries, we will be back soon ;) label Sep 4, 2023
@araffin araffin self-requested a review September 4, 2023 08:52
@AlexPasqua AlexPasqua marked this pull request as ready for review September 29, 2023 09:58
@AlexPasqua
Copy link
Contributor Author

Thanks @araffin !
Out of curiosity, may I ask why the switch between torch and numpy for the backend?

@araffin
Copy link
Member

araffin commented Sep 29, 2023

Thanks @araffin ! Out of curiosity, may I ask why the switch between torch and numpy for the backend?

to be consistent with the rest of the buffers and because PyTorch is not needed here (no gpu computation needed).

@AlexPasqua
Copy link
Contributor Author

AlexPasqua commented Sep 30, 2023

Hello @araffin ,
as you moved the code to "common", I suppose you plan to make it usable in algorithms other than DQN. At this point, wouldn't it be clearer to put the code into common/buffers.py? Let me know, and in case, I will move it there.

AlexPasqua and others added 5 commits September 30, 2023 19:40
@araffin
Copy link
Member

araffin commented Oct 2, 2023

At this point, wouldn't it be clearer to put the code into common/buffers.py?

yes probably, but the most important thing for now is to test the implementation (performance test, check we can reproduce the results from the paper), document it and add additional tests/doc (for sumtree for instance).

@araffin
Copy link
Member

araffin commented Oct 4, 2023

performance test, check we can reproduce the results from the paper

After some initial test on Breakout following hyperparameters from the paper, the run didn't improve or worsen DQN performance so far...
I will try on other envs (it would be nice if you could help).

@AlexPasqua
Copy link
Contributor Author

After some initial test on Breakout following hyperparameters from the paper, the run didn't improve or worsen DQN performance so far... I will try on other envs (it would be nice if you could help).

Thanks for starting to test it!
These days I'm travelling, and also writing a paper after work, but I'll try to squeeze some tests in

@AlexPasqua
Copy link
Contributor Author

AlexPasqua commented Nov 2, 2023

@araffin I've also done some initial tests and it looks like PER might lead to a slightly faster convergence, for example on cartpole, but nothing super evident unfortunately.
Next I'd like to properly reproduce some of the paper's experiment, but computational power could become a bit of an issue for me

Comment on lines +212 to +225
# Special case when using PrioritizedReplayBuffer (PER)
if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
# TD error in absolute value
td_error = th.abs(current_q_values - target_q_values)
# Weighted Huber loss using importance sampling weights
loss = (replay_data.weights * th.where(td_error < 1.0, 0.5 * td_error**2, td_error - 0.5)).mean()
# Update priorities, they will be proportional to the td error
assert replay_data.leaf_nodes_indices is not None, "Node leaf node indices provided"
self.replay_buffer.update_priorities(
replay_data.leaf_nodes_indices, td_error, self._current_progress_remaining
)
else:
# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q_values, target_q_values)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexPasqua Ideally, we'd like to be able to associate it with all off-policy algo's without adaptation, but I don't see a simple way of doing it at this stage.
Also related, we had discussed not modifying DQN: Stable-Baselines-Team/stable-baselines3-contrib#127 (comment)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm interested in this PR. Since every algo-specific train method includes a replay_buffer.sample line, couldn't we just additionally add a replay_buffer.update line? The update function could take in the current and target q values whenever a value function is present or maybe even all the local variables. It would do nothing for the vanilla replay buffer. Would this be an acceptable modification?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment!
How do you handle the loss in your proposal?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want this to work for general off-policy algorithms, we could update the ReplayBufferSample-like classes to additionally include an importance_sampling_weight attribute which would be updated from the replay_buffer.update method.

Then I see two ways to handle the loss under this interface:

  1. Estimate TD error from the loss as such:
losses = loss_fn(current_q_values, target_q_values, reduction='none')

# e.g. If loss is L2, then it's basically th.sqrt(loss). If loss is L1, td_error = loss
td_error = importance_sampling_weight * function_to_approx_td_error(losses)  

loss = losses.mean()

Obviously the downside of this is that it requires hand engineering for the different types of loss functions or priority metrics.

  1. Make any value-based train methods "td-error" centric in the sense that we always compute td_error = importance_sampling_weight * th.abs(current_q_values - target_q_values) first, then the loss loss = loss_fn(td_error). The downsides of this approach is that we cant use the pytorch api for computing the loss, and would have to write functions for those.

Either approach requires computing a td_error variable which unfortunately requires somewhat intrusive code changes. What do you think?

Copy link
Member

@araffin araffin May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe to make things clearer: my plan is not to have PER for all algorithms, mainly for two reasons:

  1. Keep the code concise (in fact, I would like to have RAINBOW and keep vanilla DQN, see [Feature Request] RAINBOW #622)
  2. I don't think it works for entropy-RL algorithms (SAC and derivates), so it would be limited to DQN/QR-DQN and TD3

If the users really want PER in other algo, they would take inspiration from a reference implementation in SB3 and integrate it (the same way we don't provide maskable + recurrent PPO at the same time).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"just" yes, I would be happy to receive such PR =)
the main thing is to benchmark the implementation and reproduce the published results.
This PR is also still open because I was not satisfied by the result of DQN + PER (I couldn't see significant different with respect to DQN).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I had in mind was to implement CNN for SBX (https://github.com/araffin/sbx) in order to iterate faster and check the PER, but I had no time to do so until now...

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we implement the toy environment from figure 1 of https://arxiv.org/pdf/1511.05952 as the PER benchmark? It would be a simpler initial check for correctness than the Atari environments

Copy link
Member

@araffin araffin May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The toy environment can be a start for fast iteration and debugging, but what we learned in the past is that subtle bugs only show up when doing more complex task (see #48 and #47 where we found bugs like PyTorch and TF RMSProp are not the same)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, will definitely work towards it!

@richardjozsa
Copy link

Just a comment, I've tested this implementation with QR-DQN with Vecenv multiple environment but it fails because of the missing part.

But good job to start the work on it! I hope it will be merged soon! 👍

@araffin araffin mentioned this pull request May 24, 2024
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Prioritized Experience Replay for DQN
6 participants