diff --git a/test/test_cost.py b/test/test_cost.py
index dae1fa5f70c..d39a1238ea7 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -6567,7 +6567,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
         else:
             raise NotImplementedError
 
-        loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)
+        loss_fn = A2CLoss(
+            actor,
+            value,
+            loss_critic_type="l2",
+            functional=functional,
+            return_tensorclass=False,
+        )
 
         # Check error is raised when actions require grads
         td["action"].requires_grad = True
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 6edcda5c800..df02f05fb6f 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import contextlib
 import warnings
 from copy import deepcopy
@@ -9,7 +11,7 @@
 from typing import Tuple
 
 import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
 from tensordict.utils import NestedKey
 from torch import distributions as d
@@ -31,6 +33,31 @@
 )
 
 
+@tensorclass
+class A2CLosses:
+    """Losses returned by :class:`A2CLoss` when ``return_tensorclass=True``.
+
+    Only ``loss_objective`` is required; the other fields default to ``None``.
+    """
+
+    loss_objective: torch.Tensor
+    loss_critic: torch.Tensor | None = None
+    loss_entropy: torch.Tensor | None = None
+    entropy: torch.Tensor | None = None
+
+    @property
+    def aggregate_loss(self) -> torch.Tensor:
+        """Sum of the loss terms that are set, skipping ``None`` fields."""
+        # ``loss_critic`` is only written when ``critic_coef`` is truthy, so
+        # the optional terms cannot be added unconditionally.
+        total = self.loss_objective
+        if self.loss_critic is not None:
+            total = self.loss_critic + total
+        if self.loss_entropy is not None:
+            total = total + self.loss_entropy
+        return total
+
+
 class A2CLoss(LossModule):
     """TorchRL implementation of the A2C loss.
 
@@ -234,6 +261,7 @@ def __init__(
         functional: bool = True,
         actor: ProbabilisticTensorDictSequential = None,
         critic: ProbabilisticTensorDictSequential = None,
+        return_tensorclass: bool = False,
     ):
         if actor is not None:
             actor_network = actor
@@ -289,6 +317,7 @@ def __init__(
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self.loss_critic_type = loss_critic_type
+        self.return_tensorclass = return_tensorclass
 
     @property
     def functional(self):
@@ -444,7 +473,7 @@ def _cached_detach_critic_network_params(self):
         return self.critic_network_params.detach()
 
     @dispatch()
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase | A2CLosses:
         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
@@ -465,6 +494,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict).mean()
             td_out.set("loss_critic", loss_critic.mean())
+        if self.return_tensorclass:
+            return A2CLosses._from_tensordict(td_out)
         return td_out
 
     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):