
[Algorithm] CrossQ #2033

Open · wants to merge 11 commits into main
Conversation

@BY571 (Contributor) commented Mar 21, 2024:

Description

Adding CrossQ
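For readers landing here without context: CrossQ (Bhatt et al., ICLR 2024) is a SAC-style algorithm that removes target networks and instead adds batch normalization to the critic, evaluating current and next state-action pairs in a single joint forward pass so the BatchNorm statistics cover both batches. A minimal sketch of that critic update, with illustrative names and the ensemble/min trick omitted (this is not the PR's implementation):

import torch
import torch.nn.functional as F

def crossq_qvalue_loss(qnet, actor, obs, act, next_obs, reward, done, gamma, alpha):
    # Sample next actions from the current policy (no target actor).
    with torch.no_grad():
        next_act, next_log_prob = actor(next_obs)
    # Joint forward pass: concatenating (s, a) with (s', a') makes the
    # critic's BatchNorm layers compute statistics over the mixture of
    # both batches -- the core CrossQ trick that replaces target networks.
    q_cat = qnet(torch.cat([obs, next_obs], 0), torch.cat([act, next_act], 0))
    q, next_q = q_cat.chunk(2, dim=0)
    # Gradients are stopped through the next-state values.
    target = reward + gamma * (1.0 - done) * (next_q.detach() - alpha * next_log_prob)
    return F.mse_loss(q, target)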

Motivation and Context


  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@pytorch-bot (bot) commented Mar 21, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2033

Note: Links to docs will display an error until the docs builds have been completed.

❌ 12 New Failures, 2 Unrelated Failures

As of commit 75d4cee with merge base f613eef:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Mar 21, 2024. (Authors need to sign the CLA before a PR can be reviewed.)
Review thread on torchrl/objectives/crossq.py (outdated, resolved).
@BY571 (Author) commented Mar 21, 2024:

Performance with separate target computation looks good:

[performance plot]

But we still need to check speed; it should be comparable to our SAC implementation.

@BY571 marked this pull request as ready for review on March 26, 2024.
# self.qvalue_network_params,
# ).get(self.tensor_keys.state_action_value)

combined = torch.cat(
@BY571 (Author):
Solved the previous issue, but the sequential forward pass for current-state and next-state values is still faster than the combined one. The cat-and-split operations might be slowing down the computation.
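A quick way to quantify that claim is to time the two variants directly; a rough sketch with torch.utils.benchmark, where the network shape and batch sizes are made up:

import torch
from torch import nn
from torch.utils import benchmark

qnet = nn.Sequential(nn.Linear(16, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, 1))
x, x_next = torch.randn(256, 16), torch.randn(256, 16)

t_combined = benchmark.Timer(
    stmt="qnet(torch.cat([x, x_next], 0)).chunk(2, 0)",
    globals={"torch": torch, "qnet": qnet, "x": x, "x_next": x_next},
)
t_sequential = benchmark.Timer(
    stmt="qnet(x); qnet(x_next)",
    globals={"qnet": qnet, "x": x, "x_next": x_next},
)
print(t_combined.timeit(1000))
print(t_sequential.timeit(1000))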

@vmoens:
Let's keep the commented vmaps; we can make them run faster eventually.
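For reference, the usual torch.func pattern for vectorizing an ensemble of Q-networks, which is presumably what the commented-out vmaps did (the module and shapes here are toy illustrations, not the PR's code):

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

qnets = [nn.Linear(8, 1) for _ in range(2)]  # toy two-member ensemble
params, buffers = stack_module_state(qnets)  # stack weights along a new dim 0
base = qnets[0]                              # structure template for functional_call

def qvalue(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(256, 8)
# vmap over the stacked parameter dimension; the input batch is shared.
qvals = vmap(qvalue, in_dims=(0, 0, None))(params, buffers, x)  # shape (2, 256, 1)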

@vmoens added the new algo (new algorithm request or PR) label on Apr 8, 2024.
Merge commit (conflicts resolved in .github/unittest/linux_examples/scripts/run_test.sh).
@vmoens (Contributor) left a review:

Thanks for this! There are just a couple of things to fix before merging.

sampled_tensordict = sampled_tensordict.clone()

# Compute loss
q_loss, *_ = loss_module._qvalue_loss(sampled_tensordict)
@vmoens:
We should not use private attributes in examples. Let's make qvalue_loss a public method if it's needed there.

sampled_tensordict
)
actor_loss = actor_loss.mean()
alpha_loss = loss_module._alpha_loss(
@vmoens:
ditto

)
actor_loss = actor_loss.mean()
alpha_loss = loss_module._alpha_loss(
log_prob=metadata_actor["log_prob"]
@vmoens:
The fact that the example requires this much knowledge about how the loss works is a bit worrying; the script should be immediately readable.
Is there a version of this where alpha_loss just takes the metadata dict?
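Something along these lines would keep the training script free of loss internals (the public method names here are hypothetical, following the suggestion above):

# Hypothetical public API: the actor loss returns its metadata dict, and
# alpha_loss consumes it directly instead of the caller unpacking "log_prob".
actor_loss, metadata_actor = loss_module.actor_loss(sampled_tensordict)
alpha_loss = loss_module.alpha_loss(metadata_actor)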

"num_cells": cfg.network.actor_hidden_sizes,
"out_features": 2 * action_spec.shape[-1],
"activation_class": get_activation(cfg.network.actor_activation),
"norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275) not sure if added to torch
@vmoens suggested:

- "norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275) not sure if added to torch
+ "norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275)

"num_cells": cfg.network.critic_hidden_sizes,
"out_features": 1,
"activation_class": get_activation(cfg.network.critic_activation),
"norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275) not sure if added to torch
@vmoens suggested:

- "norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275) not sure if added to torch
+ "norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275)

qvalue_network (TensorDictModule): Q(s, a) parametric model.
This module typically outputs a ``"state_action_value"`` entry.

num_qvalue_nets (integer, optional): number of Q-Value networks used.
@vmoens suggested:

- num_qvalue_nets (integer, optional): number of Q-Value networks used.
+ Keyword Args:
+     num_qvalue_nets (integer, optional): number of Q-Value networks used.


action: NestedKey = "action"
state_action_value: NestedKey = "state_action_value"
log_prob: NestedKey = "_log_prob"
@vmoens:
Not used, I think.

def _cached_detached_qvalue_params(self):
return self.qvalue_network_params.detach()

def _actor_loss(
@vmoens:
make public


return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()}

def _qvalue_loss(
@vmoens:
make public


@vmoens (Contributor) commented Jun 12, 2024:

@BY571 we should also add it to the SOTA benchmarks.
