Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable Pipeline to get device from model #30534

Merged
merged 12 commits into from May 13, 2024
Merged

Conversation

faaany
Copy link
Contributor

@faaany faaany commented Apr 29, 2024

What does this PR do?

import torch 
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, pipeline

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
print(model.device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe.model.device)
results = pipe("He's a dreadful magician and")

Currently, the code above will give an output of

cuda:0
cpu

But this is not OK: when users have moved the model to CUDA, Pipeline should not move the model back to CPU without showing any message. This PR makes it possible to let the model stay on its original device. Below is the results after this PR:

cuda:0
cuda:0

@Narsil and @muellerzr

@faaany
Copy link
Contributor Author

faaany commented Apr 29, 2024

@yao-matrix

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this!

Could you add a test?

@muellerzr
Copy link
Contributor

@faaany are we sure that model.device is a thing across all these frameworks?

At most I see ModuleUtilsMixin has device which is PyTorch specific (it gets added to AutoModel, but I'd like to verify the locations of TF and Flax backends having these capabilities to grab the model device. Otherwise we don't really want just None here IMO

@faaany
Copy link
Contributor Author

faaany commented Apr 30, 2024

Thanks for adding this!

Could you add a test?

sure, in which test file should I put this test?

@faaany
Copy link
Contributor Author

faaany commented Apr 30, 2024

@faaany are we sure that model.device is a thing across all these frameworks?

At most I see ModuleUtilsMixin has device which is PyTorch specific (it gets added to AutoModel, but I'd like to verify the locations of TF and Flax backends having these capabilities to grab the model device. Otherwise we don't really want just None here IMO

Good point! Yes, I know that Flax model doesn't have "device". How about moving it inside if is_torch_available() and self.framework == "pt": ? I have updated my code.

Furthermore, I removed the self.device is not None check, because it will never be None. And I also added the logic that model shouldn't be moved, if the model is already on device.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating and handling the torch case!

Only request is to add a test.

@muellerzr could you give a quick review as you correctly spotted and highlighted the torch vs. other frameworks case?

@faaany
Copy link
Contributor Author

faaany commented May 11, 2024

Hi @amyeroberts, sorry for the late response. We had a long holiday here in China. Unit tests are added. Let me explain more about in detail:

There are 3 possibilities for model.device:
a1. user passes device_map to from_pretrained
a2. user doesn't pass device_map to from_pretrained
a3. user manually moves the model to a certain device with to(device) after model is loaded with from_pretrained

There are 2 possibilities for pipeline.device:
b1. user passes device to pipeline
b2. user doesn't pass device to pipeline

Sincea2&b2 is trivial, my unit tests cover the cases a1&b1, a1&b2, a3&b1 and a3&b2. Pls have a review, thx!

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great - thanks for adding the tests and the explanation!

cc @muellerzr For a final double check to make sure this makes sense with accelerate

tests/pipelines/test_pipelines_common.py Outdated Show resolved Hide resolved
Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better, thanks! Agreed post Amy's nit :)

@faaany
Copy link
Contributor Author

faaany commented May 13, 2024

Thanks for the review! @amyeroberts @muellerzr

@amyeroberts amyeroberts merged commit 69d9bca into huggingface:main May 13, 2024
20 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants