Error when Using 8-bit Quantization #1616

Open
JhonDan1999 opened this issue May 3, 2024 · 5 comments

Comments

@JhonDan1999

JhonDan1999 commented May 3, 2024

I am encountering a data type mismatch error when using 8-bit quantization with the PEFT library and SFTTrainer for fine-tuning a language model. The error occurs during the generation phase after loading the fine-tuned model.

Here's an overview of my workflow:

  1. I fine-tuned a base model using the SFTTrainer from the TRL library (a rough sketch of steps 1 and 2 is shown after this list).
  2. After fine-tuning, I saved the adapter using PEFT.
  3. I loaded the fine-tuned model using PEFT and the BitsAndBytesConfig for 8-bit quantization.
  4. I merged the adapter with the base model using merge_and_unload().
  5. During the generation phase, I encountered a data type mismatch error.
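
For reference, steps 1 and 2 looked roughly like this; the dataset name and hyperparameters below are simplified placeholders rather than my exact setup, and argument names may differ slightly depending on the trl version:

from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

# Placeholders: the actual base model, dataset, and adapter path differ
base_model_name = "base-model-name"
peft_model_output_dir = "path/to/adapter"
dataset = load_dataset("my-dataset", split="train")

peft_config = LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32, lora_dropout=0.05)

trainer = SFTTrainer(
    model=base_model_name,  # SFTTrainer also accepts an already-loaded model instance
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    args=TrainingArguments(output_dir=peft_model_output_dir, num_train_epochs=1),
)
trainer.train()

# Step 2: save only the LoRA adapter weights
trainer.model.save_pretrained(peft_model_output_dir)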

Here's the code snippet for loading the fine-tuned model:

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

config = PeftConfig.from_pretrained(peft_model_output_dir)

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4"
)

# Load the base model in 8-bit, then attach and merge the fine-tuned adapter
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map='auto',
)

ft_model = PeftModel.from_pretrained(base_model, peft_model_output_dir)
ft_model = ft_model.merge_and_unload()
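
The generation step that then triggers the error looks roughly like this (the tokenizer loading and the prompt below are simplified placeholders):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
inputs = tokenizer("Example prompt", return_tensors="pt").to(ft_model.device)

# The RuntimeError below is raised inside this call
outputs = ft_model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))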

The error message I'm getting is:

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != signed char

I have also printed the data types of the model parameters, and they appear to be a mix of torch.float16 and torch.int8, which is expected when using 8-bit quantization.
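
The check I used for this is essentially the following:

# Print the dtype of every parameter of the merged model
for name, param in ft_model.named_parameters():
    print(name, param.dtype)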

I would appreciate any guidance on why I am facing this issue.

NOTE: this issue did not appear when I loaded the model in 4-bit, but after fine-tuning I want to use the model in 8-bit to get better accuracy (please correct me if my hypothesis is not correct).
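
For comparison, the 4-bit loading setup that works for me looks roughly like this (exact values are illustrative):

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
)

base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config_4bit,
    device_map="auto",
)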


github-actions bot commented Jun 2, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@JhonDan1999
Author

this still needs to be addressed

@younesbelkada
Collaborator

Hi,
Thanks for the issue and apologies for the delay. What peft version are you using?

@JhonDan1999
Author

Hi @younesbelkada, it is:
peft version: 0.11.1

@younesbelkada
Collaborator

Thanks! Can you print the model after merging it? Alternatively, can you share a model on the Hub that we can look into?
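
For example, something along these lines right after merging, so we can see which layer classes remain:

ft_model = ft_model.merge_and_unload()
# Shows the module tree, e.g. whether Linear8bitLt layers are still present
print(ft_model)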
