[Feature] Enable parameter reset in loss module #2017

Draft

BY571 wants to merge 1 commit into main

Conversation

@BY571 (Contributor) commented Mar 18, 2024

Description

Allows resetting the parameters in the loss module.
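
A minimal usage sketch of the proposed API, assuming the `reset_parameters(module_names, init_func)` signature shown in the review diff below; the exact semantics are still under discussion in this PR and the call itself is hypothetical:

from torch import nn
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")

# Hypothetical call into the API added by this PR: reset the "value_network_params"
# tensordict, re-initializing each parameter tensor with init_func.
loss.reset_parameters(
    module_names=["value_network"],             # only the value-network params
    init_func=lambda t: t.uniform_(-1.0, 1.0),  # matches the documented default init
)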

@pytorch-bot (bot) commented Mar 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2017

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 1 Unrelated Failure

As of commit 4b29473 with merge base 87f3437:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Mar 18, 2024
@BY571 changed the title [Feature] enable parameter reset in loss module → [Feature] Enable parameter reset in loss module on Mar 18, 2024
@vmoens added the enhancement label on Mar 18, 2024
@vmoens (Contributor) left a comment


Thanks for this!

We'll need tests for the feature.

How do we handle the target parameters?

Wouldn't something like this be a bit more robust?

from torchrl.objectives import DQNLoss
from torchrl.modules import QValueActor
from torch import nn

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))

value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")

with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda module: module.reset_parameters() if hasattr(module, "reset_parameters") else None)

module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
If None, all modules with names ending in "_params" will be reset.
init_func (Optional[Callable]): A function to initialize the parameters.
If None, the parameters will be initialized with uniform random values between -1 and 1.
@vmoens (Contributor):

Seems very unlikely that anyone would want to use that init IMO. Shouldn't we use the init method from the corresponding nn.Module if there is one?

def reset_parameters(
self,
module_names: Optional[List[Parameter]] = None,
init_func: Optional[Callable] = None,
@vmoens (Contributor):

Suggested change
init_func: Optional[Callable] = None,
init_func: Callable[[torch.Tensor], None] | None = None,

@@ -363,6 +364,35 @@ def reset(self) -> None:
# mainly used for PPO with KL target
pass

def reset_parameters(
self,
module_names: Optional[List[Parameter]] = None,
@vmoens (Contributor):

Suggested change
module_names: Optional[List[Parameter]] = None,
module_names: List[Parameter] | None = None,

"""Reset the parameters of the specified modules.

Args:
module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
@vmoens (Contributor):

Suggested change
module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
module_names (list of nn.Parameter, optional): A list of module names to reset the parameters for.

Args:
module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
If None, all modules with names ending in "_params" will be reset.
init_func (Optional[Callable]): A function to initialize the parameters.
@vmoens (Contributor):

Suggested change
init_func (Optional[Callable]): A function to initialize the parameters.
init_func (Callable[[torch.Tensor], None]): A function to initialize the parameters.

else:
params_2_reset = [name + "_params" for name in module_names]

def _reset_params(param):
@vmoens (Contributor):

Having one single reset function will be hard to handle, we need a way to tie the reset function and the module.
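
A rough sketch (not part of this PR) of what tying the reset function to the module could look like: a hypothetical mapping from module type to a dedicated init function, falling back to the module's own reset_parameters when it has one:

from torch import nn

# Hypothetical per-type registry: each layer kind gets its own reset rule.
RESET_FNS = {
    nn.Linear: lambda m: (nn.init.xavier_uniform_(m.weight), nn.init.zeros_(m.bias)),
    nn.Conv2d: lambda m: nn.init.kaiming_normal_(m.weight, nonlinearity="relu"),
}

def reset_by_type(module: nn.Module) -> None:
    fn = RESET_FNS.get(type(module))
    if fn is not None:
        fn(module)
    elif hasattr(module, "reset_parameters"):
        # Fall back to the module's own initializer when one exists.
        module.reset_parameters()

net = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
net.apply(reset_by_type)  # nn.Module.apply visits every submodule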


@BY571 (Contributor, Author) commented Mar 20, 2024

Thanks for this!

We'll need tests for the feature.

How do we handle the target parameters?

Wouldn't something like this be a bit more robust?

from torchrl.objectives import DQNLoss
from torchrl.modules import QValueActor
from torch import nn

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))

value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")

with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda module: module.reset_parameters() if hasattr(module, "reset_parameters") else None)

I like the solution! But since we are accessing the parameters directly in your example, we would need to define a reset function manually, which I think is perfectly fine because then the user gets to decide how to reset weights and biases:

def reset_parameters(params):
    """ User specified resetting function depending on their needs for initialization """
    if len(params.shape) > 1:
        # weights
        nn.init.xavier_uniform_(params)
    elif len(params.shape) == 1:
        # biases
        nn.init.zeros_(params)
    else:
        raise ValueError("Unknown parameter shape: {}".format(params.shape))
  
with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda x: reset_parameters(x.data) if hasattr(x, "data") else None)

And for handling the target_network_params I think we could simply do something like:

loss.target_value_network_params.update(loss.value_network_params)

What do you think? I think we can close the draft, but we might want to document this way of resetting parameters somewhere in the docs.

@vmoens (Contributor) commented Mar 20, 2024

loss.target_value_network_params.update(loss.value_network_params)

This won't work because the target params are locked (you can't update them). They're locked because we want to avoid this kind of operation :)
You should update the data in place:

loss.target_value_network_params.apply(lambda dest, src: dest.data.copy_(src), loss.value_network_params)

@vmoens (Contributor) commented Mar 20, 2024

def reset_parameters(params):
    """ User specified resetting function depending on their needs for initialization """
    if len(params.shape) > 1:
        # weights
        nn.init.xavier_uniform_(params)
    elif len(params.shape) == 1:
        # biases
        nn.init.zeros_(params)
    else:
        raise ValueError("Unknown parameter shape: {}".format(params.shape))
  
with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda x: reset_parameters(x.data) if hasattr(x, "data") else None)

Unfortunately this isn't very generic:
(1) All tensors have a data attribute, even buffers. Doing this, you would also apply Xavier init to batch-norm buffers if they happen to be 2d.
(2) If the model has a mixture of linear, conv and other layers, it's going to be hard to have fine-grained control over which params get updated.

Not all parameters are "weights" and "biases", and "biases" can be 2d (my point is: dimensionality is a very indirect indicator of a tensor's role in a model).

The way I usually see this done is to use the module's reset_parameters if there is one, which gives better control over differences in initialization methods.

Maybe we could allow the user to pass a reset function, but in that case we don't even need to re-populate the module (we can just do tensordict.apply(reset)). Note that you could also do

def reset(name, tensor):
    if name == "bias":
        tensor.data.zero_()
    if name == "weight":
        nn.init.xavier_uniform_(tensor)
tensordict.apply(reset, named=True)

which is more straightforward IMO
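
Putting the two suggestions from this thread together, a sketch only (not merged code) that resets the live params through each submodule's own `reset_parameters()` and then refreshes the locked target params in place; `delay_value=True` is an assumption here so that the loss holds a separate copy of the target parameters:

from torch import nn
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
value_net = QValueActor(module=module, action_space="categorical")
# delay_value=True assumed so that target params are a distinct copy
loss = DQNLoss(value_network=value_net, action_space="categorical", delay_value=True)

# 1) Re-populate the functional module with its params and call each submodule's
#    own reset_parameters(), as suggested in the review above.
with loss.value_network_params.to_module(loss.value_network):
    loss.apply(
        lambda m: m.reset_parameters() if hasattr(m, "reset_parameters") else None
    )

# 2) The target params are locked, so copy the freshly reset values into them
#    in place, as suggested above.
loss.target_value_network_params.apply(
    lambda dest, src: dest.data.copy_(src), loss.value_network_params
)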

Labels: CLA Signed, enhancement
Projects: None yet
Linked issues: None yet
3 participants