
torch.load without weights_only parameter is unsafe #1852

Open
kit1980 opened this issue Feb 27, 2024 · 18 comments
Labels
enhancement

Comments

@kit1980

kit1980 commented Feb 27, 2024

This is found via https://github.com/pytorch-labs/torchfix/

torch.load without weights_only parameter is unsafe. Explicitly set weights_only to False only if you trust the data you load and full pickle functionality is needed, otherwise set weights_only=True.

stable_baselines3/common/policies.py:176:27

--- /home/sdym/repos/stable-baselines3/stable_baselines3/common/policies.py
+++ /home/sdym/repos/stable-baselines3/stable_baselines3/common/policies.py
@@ -171,11 +171,11 @@
         :param path:
         :param device: Device on which the policy should be loaded.
         :return:
         """
         device = get_device(device)
-        saved_variables = th.load(path, map_location=device)
+        saved_variables = th.load(path, map_location=device, weights_only=True)
 
         # Create policy object
         model = cls(**saved_variables["data"])
         # Load weights
         model.load_state_dict(saved_variables["state_dict"])

stable_baselines3/common/save_util.py:450:33

--- /home/sdym/repos/stable-baselines3/stable_baselines3/common/save_util.py
+++ /home/sdym/repos/stable-baselines3/stable_baselines3/common/save_util.py
@@ -445,11 +445,11 @@
                     file_content.write(param_file.read())
                     # go to start of file
                     file_content.seek(0)
                     # Load the parameters with the right ``map_location``.
                     # Remove ".pth" ending with splitext
-                    th_object = th.load(file_content, map_location=device)
+                    th_object = th.load(file_content, map_location=device, weights_only=True)
                     # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
                     if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
                         # PyTorch variables (not state_dicts)
                         pytorch_variables = th_object
                     else:
@araffin araffin added the duplicate label Feb 27, 2024
@araffin
Member

araffin commented Feb 28, 2024

Duplicate of #1831

@araffin araffin marked this as a duplicate of #1831 Feb 28, 2024
@araffin araffin closed this as not planned Feb 28, 2024
@kit1980
Author

kit1980 commented Feb 28, 2024

@araffin you should specify weights_only=False if you need pickle.
Otherwise, when PyTorch soon changes the default of weights_only to True, your code will break.
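
A minimal sketch of that advice (the file name is just a placeholder):

import torch as th

# Pin the behaviour explicitly so the code keeps working when PyTorch flips
# the default of weights_only to True. weights_only=False keeps full (unsafe)
# pickle semantics, so only use it on files from a trusted source.
checkpoint = th.load("model.pth", map_location="cpu", weights_only=False)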

@araffin
Copy link
Member

araffin commented Feb 28, 2024

Could you elaborate a bit?
Where was that change announced?

@kit1980
Author

kit1980 commented Feb 28, 2024

@araffin I don't think there has been an announcement, but we're definitely considering it.
See this comment pytorch/pytorch#111806 (comment)

@araffin araffin reopened this Feb 28, 2024
@araffin araffin added the enhancement, help wanted, and good first issue labels and removed the duplicate label Feb 28, 2024
@araffin
Member

araffin commented Feb 28, 2024

Thanks, btw what is the minimum pytorch version to be able to set weights_only=True?

@kit1980
Author

kit1980 commented Feb 29, 2024

@araffin The PR that added the option is pytorch/pytorch#86812, first release with it is PyTorch 1.13.0
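
For reference, a sketch of a compatibility guard, assuming one wanted to keep supporting older PyTorch (the helper name is made up for illustration):

import torch as th

def safe_torch_load(path_or_buffer, device):
    # Pass weights_only only on torch >= 1.13.0, the first release with the
    # parameter (see above); older versions would raise a TypeError.
    major, minor = (int(x) for x in th.__version__.split(".")[:2])
    kwargs = {"weights_only": True} if (major, minor) >= (1, 13) else {}
    return th.load(path_or_buffer, map_location=device, **kwargs)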

@araffin
Member

araffin commented Mar 11, 2024

After trying it out, we cannot use weights_only=True in SB3, as it breaks some functionality; see #1866.
It would be nice to be able to extend _get_allowed_globals() for the unpickler.
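
A sketch of how such an allowlist could look, assuming a PyTorch version that exposes torch.serialization.add_safe_globals (this API shipped in PyTorch 2.4, after this comment; the listed globals are illustrative and should match the actual "Unsupported class ..." names the WeightsUnpickler reports):

import numpy as np
import torch as th

# Whitelist extra classes for the weights_only unpickler, then load safely.
th.serialization.add_safe_globals([np.dtype, np.core.multiarray.scalar])
obj = th.load("model.pth", weights_only=True)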

@araffin araffin removed the good first issue and help wanted labels Mar 11, 2024
markscsmith pushed a commit to markscsmith/OneFiveOne that referenced this issue Apr 18, 2024
…his version now attempts to use torch.load weights_only=True related to DLR-RM/stable-baselines3#1852 😮‍💨
@markscsmith
Contributor

So glad I found this!

If you're getting the error:

Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you get the file from a trusted source. WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar

when loading a model, this is a possible cause. stable_baselines3==2.2.1 doesn't throw the error, presumably because it doesn't use weights_only=True; stable_baselines3==2.3.0 throws it if I do a model.save and model.load of a PPO model.

I think it's this line here?

th_object = th.load(file_content, map_location=device, weights_only=True)

Maybe a param of some kind that passes weights_only through to the underlying torch.load and lets the dev/user decide if they trust the source?
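
For instance, a hypothetical helper (name and signature made up, not existing SB3 API):

import torch as th

def load_torch_object(file_content, device, weights_only: bool = True):
    # Forward weights_only to torch.load, defaulting to the safe value while
    # letting the caller opt out for sources they trust.
    return th.load(file_content, map_location=device, weights_only=weights_only)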

I'll drop a PR when my testing comes out OK! I promise it'll be cleaner than the heap of late night hacking I've been trying to get to play Pokémon ;)

@araffin
Member

araffin commented Apr 18, 2024

If you're getting the error:

Please provide a minimal example to reproduce the error.

throws this if I do a model.save and model.load of a PPO model.

I guess you are doing something custom, because the tests pass on the CI server.

I think it's this line here?

yes, probably.

@markscsmith
Contributor

Please provide a minimal example to reproduce the error.

For sure! Right now the only example I have is my 16 MB model blob, so I'm trying to find the minimal reproduction of that here.

I guess you are doing something custom because the tests passes on the CI server.

Yup! I too think it's something strange I'm doing with my model in particular, because I don't think I'm doing anything fancy with model.save and model.load themselves.

yes, probably.

Excellent! Hopefully my weights_only idea is overkill and it's just a weird quirk of my model that I can adjust, and the PR will just be a test and maybe a warning if someone makes the same mistake as I'm making.

If not, the pull request I'm working on still defaults to weights_only=True and throws a warning if it's overridden to False. My theory is that this won't disrupt existing users but will allow people doing weird stuff to load models they trust.

@markscsmith
Contributor

@araffin Figured it out! My learning_rate_schedule was using np.pi and np.sin. I've got a test to reproduce now and a pull request ready if my approach is OK! How would you feel about me adding an enhancement to warn if model.save() is called with objects that won't unpickle with weights_only=True?

@araffin
Member

araffin commented Apr 18, 2024

Thanks for figuring that out.
I can reproduce with:

from stable_baselines3 import PPO
import numpy as np

model = PPO("MlpPolicy", "CartPole-v1", learning_rate=lambda _: np.sin(1))
model.save("demo")
model = PPO.load("demo")

It comes from policy.optimizer, although I'm a bit confused about why, because the optimizer should only receive a float from the learning schedule. I guess the value has type np.float64 (a numpy scalar, per the "Unsupported class numpy.core.multiarray.scalar" error above) instead of a builtin float, and that crashes everything.
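
A quick check of that guess:

import numpy as np

# np.sin on a Python int returns a numpy scalar, not a builtin float.
print(type(np.sin(1)))               # <class 'numpy.float64'>
# np.float64 subclasses float, so the optimizer accepts it, but it still
# pickles via numpy.core.multiarray.scalar, which the weights_only
# unpickler rejects.
print(isinstance(np.sin(1), float))  # True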

I have a simpler fix in your case: cast to float, as that is the required type for the lr schedule:

from stable_baselines3 import PPO
import numpy as np

model = PPO("MlpPolicy", "CartPole-v1", learning_rate=lambda _: float(np.sin(1)))
model.save("demo")
model = PPO.load("demo")

EDIT: a better PR would be to cast any call to learning_rate() to float
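
A sketch of that idea, assuming a wrapper applied wherever SB3 stores the user-provided schedule (names are illustrative):

from typing import Callable

def wrap_schedule(schedule: Callable[[float], float]) -> Callable[[float], float]:
    # Cast every schedule output to a builtin float so numpy scalars (e.g.
    # from np.sin) never end up in the pickled optimizer state.
    def wrapped(progress_remaining: float) -> float:
        return float(schedule(progress_remaining))
    return wrapped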

@markscsmith
Contributor

EDIT: a better PR would be to cast any call to learning_rate() to float

Oooh, good call! I'll start on that a bit later today! I'm curious about the gymnasium loading issue mentioned earlier as well. Maybe something similar where it's using fancy numpy types?

@araffin
Member

araffin commented Apr 18, 2024

Maybe something similar where it's using fancy numpy types?

It's different: the problem occurs because we want to save the complete nn.Module object, which contains types (from gymnasium, potentially from numpy) that are not on the PyTorch whitelist.

@Franziac

I also ran into the same error as @markscsmith. I'm storing some values (probably including np.nan) in the env, and those are probably causing the error. I worked around it by setting weights_only=False.

I understand that the issue originates from me doing something that I probably shouldn't, but I don't really see the harm in doing this:

Maybe a param of some kind that passes weights_only through to the underlying torch.load and lets the dev/user decide if they trust the source?

@markscsmith
Contributor

markscsmith commented Apr 19, 2024

@araffin I took a look at the logic around the learning_rate solution and am running tests on a fix now. I'll open a new issue and PR for that fix in particular. Thank you again for the help figuring this out!

@Franziac do you have more detail about what you mean by values in the env? Based on araffin's previous comments it might be as simple as converting the types into safely-unpickleable types in the right spot.

That said, araffin, if you think you're going to get a bunch of bug reports like this from people doing odd stuff with models, I've cleaned up my original changes for weights_only a bit, and am noodling how to do a warning on save of "hey, when you try to unpickle this you're going to get an error!"

My first instinct was "how do I get weights_only=False?" but you led me to ask "why is weights_only=False suddenly necessary?" I imagine this would help @Franziac identify when the special case is needed as well. Given that torch is the one creating that option in the first place, maybe something to do upstream in pytorch?

They seem to be getting a fair number of issues around this as well, and a proactive "hey, this object contains weird stuff, weights_only=False will be necessary; here are the objects that are weird:" might be a nudge to devs to cast to safer types? It aligns well with the advice you gave me, which worked for me too!
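
Something like this dry-run check at save time could produce that nudge (the helper name is made up; it round-trips the object through an in-memory buffer):

import io
import warnings

import torch as th

def warn_if_weights_only_unsafe(obj) -> None:
    # Save the object to an in-memory buffer and attempt a weights_only load;
    # warn with the offending class name if the load would fail for users.
    buffer = io.BytesIO()
    th.save(obj, buffer)
    buffer.seek(0)
    try:
        th.load(buffer, weights_only=True)
    except Exception as exc:
        warnings.warn(f"Loading this object will require weights_only=False: {exc}")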

@araffin
Member

araffin commented Apr 19, 2024

@Franziac please provide a minimal working example to reproduce the issue.

@araffin
Member

araffin commented Apr 27, 2024

@Franziac please have a look at #1911; it seems to be due to an old version of PyTorch. In the meantime, I will revert the change and release SB3 2.3.2, which should solve the issue.
