Llama Attention Call should not pass **kwargs #30523
Comments
Kwargs should indeed not be passed. I would need a reproducer, but feel free to open a PR for a fix! 😉
I will open a PR after cataloguing all the models that have this issue. GPT-NeoX also has it. Reproducer: wrap a model in FSDP and then run a forward pass on any data.
Yep, I can confirm that I see the same issue with Llama-3-8B-Instruct under FSDP + gradient checkpointing. I just checked, and the Yi series of models has it too, which makes sense since they follow the Llama architecture.
We'll remove the kwargs! cc @zhenglongjiepheonix who is working on something related! |
System Info
transformers version: 4.40.1

Who can help?
@ArthurZucker
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Wrapping a LlamaModel with FSDP raises a TypeError (unexpected keyword argument) during the forward pass.
This occurs because **kwargs is passed (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L749) to a function that does not accept **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L608).
With another model, e.g. Mistral, this issue does not occur, because **kwargs is not passed: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L757C63-L757C77
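A minimal reproducer sketch along those lines; the checkpoint name, single-process NCCL setup, and input shape are illustrative assumptions, not part of the original report:

```python
# Minimal reproducer sketch: wrap a Llama model in FSDP and run any forward pass.
# Assumes one CUDA device and a single-process group; checkpoint name is illustrative.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

dist.init_process_group("nccl", init_method="tcp://localhost:29500", rank=0, world_size=1)
torch.cuda.set_device(0)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
vocab_size = model.config.vocab_size
model = FSDP(model, device_id=0)

# Per the report, any forward through the FSDP-wrapped model triggers the TypeError,
# because extra kwargs reach LlamaAttention.forward, which does not accept them.
input_ids = torch.randint(0, vocab_size, (1, 8), device="cuda")
out = model(input_ids=input_ids)
```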
Expected behavior
Either stop passing **kwargs in the attention call on line 749, or add **kwargs to LlamaAttention.forward(). A toy illustration of both options follows.
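A self-contained toy that mirrors the failure mode and both fixes; the class names and the `offload` kwarg are hypothetical stand-ins, not the actual transformers code:

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    # Mirrors the problem: forward() does not accept **kwargs.
    def forward(self, hidden_states: torch.Tensor):
        return hidden_states

class DecoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = Attention()

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        # Mirrors modeling_llama.py#L749: stray kwargs are forwarded to attention.
        return self.self_attn(hidden_states, **kwargs)

layer = DecoderLayer()
x = torch.zeros(1, 4)

layer(x)                     # fine: kwargs is empty
try:
    layer(x, offload="cpu")  # hypothetical extra kwarg, as injected by a wrapper
except TypeError as e:
    print(e)                 # ...got an unexpected keyword argument 'offload'

# Fix A: drop **kwargs from the self_attn call so nothing stray is forwarded.
# Fix B: add **kwargs to Attention.forward so stray kwargs are swallowed.
```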