Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NOMERG] Prototyping tensorclass for losses #1892

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6567,7 +6567,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
else:
raise NotImplementedError

loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = A2CLoss(
actor,
value,
loss_critic_type="l2",
functional=functional,
return_tensorclass=False,
)

# Check error is raised when actions require grads
td["action"].requires_grad = True
Expand Down
22 changes: 20 additions & 2 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import contextlib
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey
from torch import distributions as d
Expand All @@ -31,6 +33,18 @@
)


@tensorclass
class A2CLosses:
    """Structured container for the loss terms produced by the A2C loss.

    Only ``loss_objective`` is required. The remaining fields default to
    ``None`` because the loss module does not always produce them (for
    example, ``loss_critic`` is only written when the critic coefficient
    is non-zero).
    """

    # Actor (policy) objective term; always present.
    loss_objective: torch.Tensor
    # Critic (value) loss term; ``None`` when it was not computed.
    loss_critic: torch.Tensor | None = None
    # Entropy loss term; ``None`` when it was not computed.
    loss_entropy: torch.Tensor | None = None
    # Entropy value carried alongside the losses; it is not part of
    # ``aggregate_loss``.
    entropy: torch.Tensor | None = None

    @property
    def aggregate_loss(self) -> torch.Tensor:
        """Return the sum of every loss term that is present.

        Optional terms left at ``None`` are skipped, so the property never
        attempts ``Tensor + None`` (which raises ``TypeError``).
        """
        # Keep the original summation order (critic, objective, entropy) so
        # results match the previous expression when all terms are set.
        present = [
            term
            for term in (self.loss_critic, self.loss_objective, self.loss_entropy)
            if term is not None
        ]
        # ``loss_objective`` is a required field, so ``present`` is non-empty.
        total = present[0]
        for term in present[1:]:
            total = total + term
        return total


class A2CLoss(LossModule):
"""TorchRL implementation of the A2C loss.

Expand Down Expand Up @@ -234,6 +248,7 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
return_tensorclass: bool = False,
):
if actor is not None:
actor_network = actor
Expand Down Expand Up @@ -289,6 +304,7 @@ def __init__(
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.loss_critic_type = loss_critic_type
self.return_tensorclass = return_tensorclass

@property
def functional(self):
Expand Down Expand Up @@ -444,7 +460,7 @@ def _cached_detach_critic_network_params(self):
return self.critic_network_params.detach()

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def forward(self, tensordict: TensorDictBase) -> A2CLosses:
tensordict = tensordict.clone(False)
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
Expand All @@ -465,6 +481,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.critic_coef:
loss_critic = self.loss_critic(tensordict).mean()
td_out.set("loss_critic", loss_critic.mean())
if self.return_tensorclass:
return A2CLosses._from_tensordict(td_out)
return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
Expand Down