Refactor run_llava to avoid duplicated model loading when multiple evaluations are required #1160

Open · wants to merge 3 commits into base: main
17 changes: 17 additions & 0 deletions README.md
@@ -144,6 +144,23 @@ args = type('Args', (), {

eval_model(args)
```

You can use `get_model_and_processor` together with `run_for_outputs` to avoid loading the model more than once when running multiple evaluations:

```python
from types import SimpleNamespace

from llava.eval.run_llava import get_model_and_processor, run_for_outputs

model_path = "liuhaotian/llava-v1.5-7b"
prompt = "What are the things I should be cautious about when I visit here?"
image_file = "https://llava-vl.github.io/static/images/view.jpg"

model, image_processor, tokenizer, p_conv_mode = get_model_and_processor(model_path)
cfig = {"model": model, "image_processor": image_processor, "tokenizer": tokenizer, "p_conv_mode": p_conv_mode}

args = SimpleNamespace(query=prompt, image_file=image_file,
                       sep=",", temperature=0, top_p=None, num_beams=1, max_new_tokens=512)
run_for_outputs(cfig, args)
```
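
Because the model is loaded only once, the same `cfig` can be reused across repeated calls. A minimal sketch (the second prompt below is illustrative, not from the original example):

```python
questions = [
    "What are the things I should be cautious about when I visit here?",
    "Describe the weather in this image.",  # illustrative second prompt
]

# The model stays loaded; only the lightweight args object changes per query.
for q in questions:
    args = SimpleNamespace(query=q, image_file=image_file,
                           sep=",", temperature=0, top_p=None, num_beams=1, max_new_tokens=512)
    outputs = run_for_outputs(cfig, args)
```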
</details>

## LLaVA Weights
76 changes: 43 additions & 33 deletions llava/eval/run_llava.py
@@ -46,16 +46,9 @@ def load_images(image_files):
out.append(image)
return out


def eval_model(args):
# Model
disable_torch_init()

model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
args.model_path, args.model_base, model_name
)

def run_for_outputs(cfig, args):
model, image_processor, tokenizer, p_conv_mode = cfig["model"], cfig["image_processor"], cfig["tokenizer"], cfig["p_conv_mode"]

qs = args.query
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
if IMAGE_PLACEHOLDER in qs:
Expand All @@ -69,29 +62,7 @@ def eval_model(args):
else:
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

if "llama-2" in model_name.lower():
conv_mode = "llava_llama_2"
elif "mistral" in model_name.lower():
conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name.lower():
conv_mode = "chatml_direct"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
else:
conv_mode = "llava_v0"

if args.conv_mode is not None and conv_mode != args.conv_mode:
print(
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
conv_mode, args.conv_mode, args.conv_mode
)
)
else:
args.conv_mode = conv_mode

conv = conv_templates[args.conv_mode].copy()
conv = conv_templates[p_conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
@@ -126,7 +97,46 @@

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(outputs)
    return outputs  # return for further processing

def get_model_and_processor(model_path, model_base=None, p_conv_mode=None):
# Model
disable_torch_init()

model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(
model_path, model_base, model_name
)

if "llama-2" in model_name.lower():
conv_mode = "llava_llama_2"
elif "mistral" in model_name.lower():
conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name.lower():
conv_mode = "chatml_direct"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
else:
conv_mode = "llava_v0"

if p_conv_mode is not None and conv_mode != p_conv_mode:
print(
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
conv_mode, p_conv_mode, p_conv_mode
)
)
else:
p_conv_mode = conv_mode


return model, image_processor, tokenizer, p_conv_mode

def eval_model(args):
model, image_processor, tokenizer, p_conv_mode = get_model_and_processor(args.model_path, args.model_base, args.conv_mode)
cfig = {"model": model, "image_processor": image_processor, "tokenizer": tokenizer, "p_conv_mode": p_conv_mode}
return run_for_outputs(cfig, args)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
3 changes: 2 additions & 1 deletion llava/model/__init__.py
@@ -2,5 +2,6 @@
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
except:
except Exception as e:
print(f"Fail to import files: {e}")
pass