Add NPU support for Llava #1446

Open · wants to merge 1 commit into base: main
8 changes: 5 additions & 3 deletions README.md
@@ -71,6 +71,8 @@

If you are not using Linux, do *NOT* proceed, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md).

If you are using an Ascend NPU, see the instructions for [Ascend NPU support](docs/AscendNPU_Support.md).

1. Clone this repository and navigate to LLaVA folder
```bash
git clone https://github.com/haotian-liu/LLaVA.git
@@ -180,7 +182,7 @@ flowchart BT
subgraph Demo Connections
direction BT
c<-->gws

mw7b<-->c
mw13b<-->c
lsglw13b<-->c
@@ -431,14 +433,14 @@ If you find LLaVA useful for your research and applications, please cite using t
}

@misc{liu2023improvedllava,
title={Improved Baselines with Visual Instruction Tuning},
title={Improved Baselines with Visual Instruction Tuning},
author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae},
publisher={arXiv:2310.03744},
year={2023},
}

@misc{liu2023llava,
title={Visual Instruction Tuning},
title={Visual Instruction Tuning},
author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
publisher={NeurIPS},
year={2023},
83 changes: 83 additions & 0 deletions docs/AscendNPU_Support.md
@@ -0,0 +1,83 @@
# Run LLaVA on Ascend NPU



## Installation
1. Clone this repository and navigate to LLaVA folder
```bash
git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA
```

2. Install Package
```Shell
conda create -n llava python=3.10 -y
conda activate llava
pip install --upgrade pip # enable PEP 660 support
pip install -e .
```

3. Install additional packages for training cases
```
pip install -e ".[train]"
```

4. Install Ascend Extension for PyTorch

You can follow this [guide](https://www.hiascend.com/document/detail/en/ModelZoo/pytorchframework/ptes/ptes_00001.html) to download and install the Ascend NPU Firmware, Ascend NPU Driver, and CANN. Afterwards, you need to install additional Python packages.
```shell
pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu  # for x86_64
pip3 install torch==2.1.0  # for AArch64
pip3 install accelerate==0.28.0 decorator==5.1.1 scipy==1.13.0 attrs==23.2.0 openpyxl
```
After installing the above Python packages, follow this [README](https://github.com/Ascend/pytorch/blob/master/README.md) to set up the torch_npu environment. You can then run LLaVA on Ascend NPU.
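
After installation, it is worth verifying that PyTorch can see the NPU before running any LLaVA script. A minimal sanity check (illustrative only; assumes a single visible NPU device):
```python
# Quick sanity check for the Ascend NPU environment (illustrative).
import torch
import torch_npu  # registers the "npu" device type with PyTorch

print(torch.__version__)          # expect 2.1.0
print(torch.npu.is_available())   # True once firmware, driver, and CANN are set up
x = torch.randn(2, 3).to("npu")   # place a tensor on the first NPU
print(x.device)                   # npu:0
```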




## Pretrain/Finetune LLaVA on Ascend NPU
To pretrain or finetune LLaVA on Ascend NPU, you only need to modify two lines in the pretrain/finetune shell script.

As shown below:
```shell
# First, set up the Ascend environment variables via the 'source' command.
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# Disable TF32 mode
--tf32 False
```
Here is a [finetune shell script](scripts/v1_5/finetune_npu.sh) example for Ascend NPU, and a sketch of the training-side meaning of the flag follows below.
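
For context, `--tf32` is passed through to the Hugging Face `TrainingArguments` that back LLaVA's trainer, and TF32 is an NVIDIA-specific matmul mode that Ascend NPU does not provide. A rough Python-side sketch of what the flag corresponds to (illustrative; the output path is a placeholder and the real script passes many more arguments):
```python
# Illustrative only: the training-side meaning of --tf32 False.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints/llava-v1.5-13b-npu",  # hypothetical path
    bf16=True,    # as in the stock LLaVA finetune script
    tf32=False,   # TF32 is an NVIDIA Ampere matmul mode; keep it off on NPU
)
```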


## Inference/Evaluate LLaVA on Ascend NPU
If you want to perform inference/evaluation, a small modification to your shell script is all that's needed.


As shown below, you only need to add a 'source' command to your shell script, and the usage for inference remains the same.
```shell
# textvqa.sh
source /usr/local/Ascend/ascend-toolkit/set_env.sh # Add this
python -m llava.eval.model_vqa_loader \
--model-path liuhaotian/llava-v1.5-13b \
--question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \
--image-folder ./playground/data/eval/textvqa/train_images \
--answers-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl \
--temperature 0 \
--conv-mode vicuna_v1

python -m llava.eval.eval_textvqa \
--annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \
--result-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl

# inference.sh
source /usr/local/Ascend/ascend-toolkit/set_env.sh # Add this
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file "https://llava-vl.github.io/static/images/view.jpg" \

```
*NOTE: Ascend NPU doesn't support all quantization methods. If you encounter issues during inference, remove the quantization flags.*
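
If a quantized load fails on the NPU, the simplest workaround is to load the model in half precision with quantization disabled. A minimal sketch using LLaVA's loader (flag names assumed from the current `load_pretrained_model` signature):
```python
# Illustrative: load LLaVA without 4-bit/8-bit quantization for NPU inference.
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.utils import is_npu_available

model_path = "liuhaotian/llava-v1.5-7b"
device = "npu" if is_npu_available() else "cuda"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    load_8bit=False,   # keep quantization off on Ascend NPU
    load_4bit=False,
    device=device,
)
```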



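The patched eval scripts below guard the `torch_npu` import behind `is_npu_available()` from `llava.utils`, and `torch_npu.contrib.transfer_to_npu` then redirects CUDA-style calls to the NPU. The helper itself is not visible in this diff; a minimal sketch of what it could look like (an assumption, not the author's implementation):
```python
# llava/utils.py (hypothetical sketch, not part of this diff):
# report whether the Ascend NPU backend is installed and usable.
import importlib.util


def is_npu_available() -> bool:
    """Return True if torch_npu is importable and an NPU device is present."""
    if importlib.util.find_spec("torch_npu") is None:
        return False
    try:
        import torch
        import torch_npu  # noqa: F401  # patches torch with the `npu` device
        return torch.npu.is_available()
    except (ImportError, AttributeError, RuntimeError):
        return False
```
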
10 changes: 7 additions & 3 deletions llava/eval/model_qa.py
@@ -7,8 +7,11 @@
import shortuuid

from llava.conversation import default_conversation
from llava.utils import disable_torch_init
from llava.utils import disable_torch_init, is_npu_available

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu

@torch.inference_mode()
def eval_model(model_name, questions_file, answers_file):
@@ -17,7 +20,8 @@ def eval_model(model_name, questions_file, answers_file):
model_name = os.path.expanduser(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name,
torch_dtype=torch.float16).cuda()
torch_dtype=torch.float16).to("npu" if is_npu_available() else "cuda")



ques_file = open(os.path.expanduser(questions_file), "r")
@@ -30,7 +34,7 @@ def eval_model(model_name, questions_file, answers_file):
conv.append_message(conv.roles[0], qs)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
input_ids = torch.as_tensor(inputs.input_ids).cuda()
input_ids = torch.as_tensor(inputs.input_ids).to("npu" if is_npu_available() else "cuda")
output_ids = model.generate(
input_ids,
do_sample=True,
9 changes: 6 additions & 3 deletions llava/eval/model_vqa.py
@@ -8,12 +8,15 @@
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.utils import disable_torch_init, is_npu_available
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path

from PIL import Image
import math

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu

def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
@@ -53,15 +56,15 @@ def eval_model(args):
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to("npu" if is_npu_available() else "cuda")

image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
image_tensor = process_images([image], image_processor, model.config)[0]

with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.unsqueeze(0).half().cuda(),
images=image_tensor.unsqueeze(0).half().to("npu" if is_npu_available() else "cuda"),
image_sizes=[image.size],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
9 changes: 6 additions & 3 deletions llava/eval/model_vqa_loader.py
@@ -8,13 +8,16 @@
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.utils import disable_torch_init, is_npu_available
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import math

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu

def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
@@ -99,12 +102,12 @@ def eval_model(args):
idx = line["question_id"]
cur_prompt = line["text"]

input_ids = input_ids.to(device='cuda', non_blocking=True)
input_ids = input_ids.to(device="npu" if is_npu_available() else "cuda", non_blocking=True)

with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
images=image_tensor.to(dtype=torch.float16, device="npu" if is_npu_available() else "cuda", non_blocking=True),
image_sizes=image_sizes,
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
9 changes: 6 additions & 3 deletions llava/eval/model_vqa_mmbench.py
@@ -9,12 +9,15 @@
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.utils import disable_torch_init, is_npu_available
from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path

from PIL import Image
import math

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu

all_options = ['A', 'B', 'C', 'D']

@@ -103,14 +106,14 @@ def eval_model(args):
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device="npu" if is_npu_available() else "cuda")

image_tensor = process_images([image], image_processor, model.config)[0]

with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.unsqueeze(0).half().cuda(),
images=image_tensor.unsqueeze(0).half().to(device="npu" if is_npu_available() else "cuda"),
image_sizes=[image.size],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
9 changes: 6 additions & 3 deletions llava/eval/model_vqa_science.py
@@ -8,12 +8,15 @@
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.utils import disable_torch_init, is_npu_available
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path

from PIL import Image
import math

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu

def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
@@ -48,7 +51,7 @@ def eval_model(args):
image_file = line["image"]
image = Image.open(os.path.join(args.image_folder, image_file))
image_tensor = process_images([image], image_processor, model.config)[0]
images = image_tensor.unsqueeze(0).half().cuda()
images = image_tensor.unsqueeze(0).half().to(device="npu" if is_npu_available() else "cuda")
image_sizes = [image.size]
if getattr(model.config, 'mm_use_im_start_end', False):
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
@@ -68,7 +71,7 @@ def eval_model(args):
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device="npu" if is_npu_available() else "cuda")

with torch.inference_mode():
output_ids = model.generate(
7 changes: 5 additions & 2 deletions llava/eval/run_llava.py
@@ -10,7 +10,7 @@
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.utils import disable_torch_init, is_npu_available
from llava.mm_utils import (
process_images,
tokenizer_image_token,
@@ -24,6 +24,9 @@
from io import BytesIO
import re

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu

def image_parser(args):
out = args.image_file.split(args.sep)
@@ -108,7 +111,7 @@ def eval_model(args):
input_ids = (
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
.unsqueeze(0)
.cuda()
.to(device="npu" if is_npu_available() else "cuda")
)

with torch.inference_mode():
2 changes: 1 addition & 1 deletion llava/serve/cli.py
@@ -82,7 +82,7 @@ def main(args):
else:
inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
image = None

conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()