Llama3 with LlamaForSequenceClassification - Shape mismatch error #30548

Open

parasurama opened this issue Apr 29, 2024 · 7 comments

@parasurama

System Info

  • transformers version: 4.40.0
  • Platform: Linux-5.14.0-284.40.1.el9_2.x86_64-x86_64-with-glibc2.17
  • Python version: 3.8.11
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.3
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm getting a shape mismatch error when loading the Llama 3 model with LlamaForSequenceClassification:

model_path = "meta-llama/Meta-Llama-3-8B"
    
  max_length = 4096
  quant_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_quant_type="nf4",
      bnb_4bit_compute_dtype=torch.bfloat16,
      bnb_4bit_use_double_quant=False,
  )

  peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, 
                           inference_mode=False, 
                           lora_alpha=LORA_ALPHA,
                           lora_dropout=LORA_DROPOUT,
                           r=LORA_R,
                           bias="none",
                           target_modules=LORA_TARGET_MODULES)
  
 llama_config = LlamaConfig(max_position_embeddings=max_length)

model = LlamaForSequenceClassification.from_pretrained(model_path,
                                                       config=llama_config,
                                                       quantization_config=quant_config,
                                                       ignore_mismatched_sizes=True)

Expected behavior

Returns the following error

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B and are newly initialized because the shapes did not match
@ArthurZucker
Collaborator

Hey! Pretty sure that is expected: there are no pretrained checkpoints for sequence classification, no?
This does look like a warning, not an error.
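
For context, the `score` classification head in LlamaForSequenceClassification has no pretrained weights in the base checkpoint, so a "newly initialized" warning is expected even on a plain load; a minimal sketch (num_labels=2 is an arbitrary value for illustration):

from transformers import LlamaForSequenceClassification

# The transformer weights come from the checkpoint, while the `score` head
# is randomly initialized, which is what triggers the warning.
model = LlamaForSequenceClassification.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    num_labels=2,  # arbitrary example value
)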

@parasurama
Author

You're right -- this was a warning, not an error.

The warning is followed by this error:

ValueError: weight is on the meta device, we need a `value` to put in on 0.

It looks like there are similar open issues, something to do with tied weights not being detected correctly:

huggingface/accelerate#2059

@ArthurZucker
Collaborator

cc @SunMarc

@SunMarc
Member

SunMarc commented May 3, 2024

Hi @parasurama, thanks for reporting! I'll have a look ASAP.

@SunMarc
Member

SunMarc commented May 6, 2024

Hi @parasurama, this happens because you changed the max_position_embeddings attribute. This modifies a lot of weights, and the whole model needs to be retrained. For now, we don't support loading mismatched weights with device_map="auto". I will work on adding that feature, but it is not very useful here, since device_map is mainly meant for inference and the modified model needs retraining anyway.
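
If the goal is only a longer context window, a hedged workaround (not verified in this thread) is to start from the checkpoint's own config, so that the other attributes still match the saved weights, and only override max_position_embeddings:

from transformers import AutoConfig, LlamaForSequenceClassification

model_path = "meta-llama/Meta-Llama-3-8B"

# Load the config shipped with the checkpoint so attributes such as
# vocab_size match the saved weights, then override the context length.
config = AutoConfig.from_pretrained(model_path)
config.max_position_embeddings = 4096

model = LlamaForSequenceClassification.from_pretrained(model_path, config=config)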

@parasurama
Author

parasurama commented May 6, 2024

@SunMarc the error is raised even if I use the default max_position_embeddings.
The error is not raised when using a Llama 2 model.

Here's a minimal example:

from transformers import (
    LlamaForSequenceClassification,
    LlamaConfig,
    BitsAndBytesConfig,
)
from peft import LoraConfig, TaskType
import torch

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules=["q_proj", "v_proj"],
)

llama_config = LlamaConfig()

model = LlamaForSequenceClassification.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    config=llama_config,
    quantization_config=quant_config,
    ignore_mismatched_sizes=True,
)

@SunMarc
Member

SunMarc commented May 7, 2024

This happens because the default vocab_size of LlamaConfig is 32000: the Llama 2 checkpoints have a vocab_size of 32000, but the Llama 3 checkpoints have a vocab_size of 128256. So by passing a default LlamaConfig() together with "meta-llama/Meta-Llama-3-8B", you are modifying the model.
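
One way to see the mismatch, and a hedged fix, is to compare the checkpoint's config with a default LlamaConfig and then drop the custom config so that from_pretrained falls back to the one stored with the checkpoint; a sketch:

from transformers import AutoConfig, LlamaConfig, LlamaForSequenceClassification, BitsAndBytesConfig
import torch

print(LlamaConfig().vocab_size)  # 32000 (library default, matches Llama 2)
print(AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B").vocab_size)  # 128256 (Llama 3)

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Without the config=LlamaConfig() argument, from_pretrained uses the
# checkpoint's own config, so no weight shapes are mismatched.
model = LlamaForSequenceClassification.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantization_config=quant_config,
)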
