
4-bit quant? #28

Open
nmandic78 opened this issue Mar 14, 2024 · 5 comments

Comments


nmandic78 commented Mar 14, 2024

Hi! Thank you for releasing this multimodal model. First tests are impressive. Even the 1.3B is good for its size.
It is just that the 7B version in full precision is still taxing on the personal hardware we have at home.
Would it be possible to quantize it to int4, like Qwen did with their Qwen-VL-Chat-Int4?
I think it would be best if you could do it and put it in your HF repo so the community can use it.
If not, maybe you could give us some guidelines on how to do it.


jucamohedano commented Mar 19, 2024

Hey @nmandic78! I have applied 8-bit quantization to deepseek-vl-1.3b-chat, which is the smallest model. I have a dummy repository where I throw all of my experiments. Unfortunately, I haven't been able to push the model to the Hub because serialization of the weights is not fully supported, but I think you can do it a different way. Here is a link to the notebook where I do the quantization; it's quite short.
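
For anyone who doesn't want to open the notebook, weight-only 8-bit quantization with quanto boils down to roughly the following (a rough sketch, not the notebook verbatim; the model path and dtype here are assumptions):

import torch
from transformers import AutoModelForCausalLM
from quanto import quantize, freeze, qint8

from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM

model_path = "deepseek-ai/deepseek-vl-1.3b-chat"  # assumed HF repo id; adjust if you use a local clone
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.float16
)

# weight-only int8: weights are quantized, activations stay in half precision
quantize(vl_gpt, weights=qint8, activations=None)
freeze(vl_gpt)

vl_gpt = vl_gpt.to("cuda").eval()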

nmandic78 (Author)

@jucamohedano, thank you for the info. I'll take a look at quanto. At first glance, I see they still have some issues with optimization ('latency: all models are at least 2x slower than the 16-bit models due to the lack of optimized kernels (for now).').


RandomGitUser321 commented Mar 27, 2024

@nmandic78 You can edit a few things, use BitsAndBytesConfig from transformers, and load it in 4-bit mode. Since my GPU is a 2080 and doesn't support bfloat16, I had to edit the other deepseek *.py files and change every bfloat16 -> float16. If you're on a 3XXX or 4XXX NVIDIA card, you don't have to and can stick with bfloat16.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images

# if you're on a 3XXX or 4XXX card, you should be able to change torch.float16 to torch.bfloat16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)


# specify the path to the model
model_path = "DeepSeek-VL/deepseek-vl-7b-chat" #or "c:/local/path/to/deepseek-vl-7b-chat"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
    device_map="cuda",
)

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>Describe each stage of this image.",
        "images": ["./PATH/TO/YOUR/IMAGE/HERE.png"]
    },
    {
        "role": "Assistant",
        "content": ""
    }
]

# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation,
    images=pil_images,
    force_batchify=True
).to(vl_gpt.device)

# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

# run the model to get the response
outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True
)

answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(answer)

This works just fine for me and stays within the 8GB limit of my GPU. The results are still accurate from what I can tell (you can ignore the flash attention warning).

nmandic78 (Author)

@RandomGitUser321, thank you!


nmandic78 commented Apr 8, 2024

@RandomGitUser321, just a short update, in case anyone else stumbles on this. I have a 3090 so I can use bfloat16, but I had to convert every layer with .to(torch.bfloat16) to make it work. It was tedious chasing the errors, but after converting them all it works just fine.
I have very limited knowledge of LLM architecture, so I expect all of this could be done more elegantly at a global level rather than converting layer by layer; see the sketch below.
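
For anyone trying the same thing, passing torch_dtype to from_pretrained may be the cleaner, global way to do it (an untested sketch; it should set the dtype of the modules bitsandbytes leaves unquantized, such as the vision encoder, norms, and embeddings, so no per-layer .to() calls would be needed):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# torch_dtype should apply bfloat16 to everything bitsandbytes doesn't quantize
vl_gpt = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-vl-7b-chat",  # or your local path
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)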

My VRAM usage is higher (8.8 GB, 370 MB before model load). Ubuntu 22.04, NVIDIA driver 545.23.

And one more observation: I tried both the FP4 and NF4 bnb_4bit_quant_type, and for the same inputs (image + prompt) it looks like the FP4 answer is more detailed.
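
For reference, switching between the two is a one-line change to the config from the snippet above:

# same as before, but FP4 instead of NF4
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="fp4",  # vs. "nf4"
    bnb_4bit_compute_dtype=torch.bfloat16,
)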
