Add llama3 and distributed checkpoint support in NeVA (#9101)

* temp save

Signed-off-by: yaoyu-33 <[email protected]>

* temp save 2

Signed-off-by: yaoyu-33 <[email protected]>

* update code

Signed-off-by: yaoyu-33 <[email protected]>

* enable seq packing

Signed-off-by: yaoyu-33 <[email protected]>

* fix neva and clip

Signed-off-by: yaoyu-33 <[email protected]>

* Enable parallel seq packing algo and a few other fixes

Signed-off-by: yaoyu-33 <[email protected]>

* Pipeline parallel support

Signed-off-by: yaoyu-33 <[email protected]>

* Update data preprocess

Signed-off-by: yaoyu-33 <[email protected]>

* Fix a few PP issues

Signed-off-by: yaoyu-33 <[email protected]>

* enable sequence packing w/ PP

Signed-off-by: yaoyu-33 <[email protected]>

* Fix cu_seqlens in inputs

Signed-off-by: yaoyu-33 <[email protected]>

* add assert

Signed-off-by: yaoyu-33 <[email protected]>

* Depend on PP to decide whether to do padding

Signed-off-by: yaoyu-33 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add docstring

Signed-off-by: yaoyu-33 <[email protected]>

* Fix few evaluation issues

Signed-off-by: yaoyu-33 <[email protected]>

* Fix few PP evaluation issues

Signed-off-by: yaoyu-33 <[email protected]>

* Address comments

Signed-off-by: yaoyu-33 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add llama3 template

Signed-off-by: yaoyu-33 <[email protected]>

* address comments

Signed-off-by: yaoyu-33 <[email protected]>

* Fix license

Signed-off-by: yaoyu-33 <[email protected]>

* Fix llama3

Signed-off-by: yaoyu-33 <[email protected]>

* Few fixes

Signed-off-by: yaoyu-33 <[email protected]>

* Few neva bugs

Signed-off-by: yaoyu-33 <[email protected]>

* Few neva bugs

Signed-off-by: yaoyu-33 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Few neva bugs

Signed-off-by: yaoyu-33 <[email protected]>

* llama3 inference fix

Signed-off-by: yaoyu-33 <[email protected]>

* Force vision encoder to run in fp32

Signed-off-by: yaoyu-33 <[email protected]>

* Revert "Force vision encoder to run in fp32"

This reverts commit 9d2160d.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Try adding distributed format of checkpoint

Signed-off-by: yaoyu-33 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Allow dist checkpoint to be non-strict

Signed-off-by: yaoyu-33 <[email protected]>

* Fix

Signed-off-by: yaoyu-33 <[email protected]>

* Some fixes for PP + dist ckpt in Neva

Signed-off-by: yaoyu-33 <[email protected]>

* fix peft

Signed-off-by: yaoyu-33 <[email protected]>

* A few fixes for LoRA

Signed-off-by: yaoyu-33 <[email protected]>

* checkpoint updates

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* bug fix

Signed-off-by: yaoyu-33 <[email protected]>

* Add neva dist checkpoint converter

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* resolve comments

Signed-off-by: yaoyu-33 <[email protected]>

* update neva dist ckpt apis

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix return

Signed-off-by: yaoyu-33 <[email protected]>

---------

Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: yaoyu-33 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: yaoyu-33 <[email protected]>
3 people committed May 22, 2024
1 parent c7bf46e commit d7bb403
Showing 14 changed files with 690 additions and 172 deletions.
@@ -11,7 +11,7 @@ inference:
  compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False
  end_strings: ["<extra_id_1>","<extra_id_7>",] # generation will stop when one of these tokens is generated
  media_base_path: /pwd/images # /path/to/images or /path/to/videos
  insert_media_token: left # `left` or `right` or `null`
  insert_media_token: null # `left` or `right` or `null`
  media_type: image # `image` or `video`

trainer:
@@ -20,6 +20,7 @@
from omegaconf import OmegaConf

from nemo.collections.multimodal.parts.utils import create_neva_model_and_processor
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam

CFG_STRING = """
trainer:
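For context, the newly imported LengthParam and SamplingParam are the typed dictionaries that NeMo's text-generation API consumes. A minimal sketch of how they can be populated to mirror the inference config shown above; the values are illustrative only and are not taken from this commit:

```python
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam

# Illustrative values only; real runs take these from the inference config.
length_params: LengthParam = {
    "max_length": 30,  # maximum number of tokens to generate
    "min_length": 0,   # minimum number of tokens to generate
}

sampling_params: SamplingParam = {
    "use_greedy": True,  # greedy decoding; set False to sample
    "temperature": 0.2,
    "top_k": 0,
    "top_p": 0.9,
    "repetition_penalty": 1.2,
    "add_BOS": False,
    "all_probs": False,
    "compute_logprob": False,  # same flag that appears in the config diff above
    "end_strings": ["<extra_id_1>", "<extra_id_7>"],  # stop strings, as in the config
}
```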
@@ -169,7 +169,6 @@ def eval_model(args):
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.json")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
parser.add_argument("--conv-mode", type=str, default="llava_v0")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--num-chunks", type=int, default=1)
@@ -0,0 +1,89 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentParser
from omegaconf.omegaconf import OmegaConf

from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.utils import logging


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--input_path",
        type=str,
        default=None,
        required=True,
        help="Path to NeMo legacy checkpoints",
    )
    parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.")
    parser.add_argument("--gpus_per_node", type=int, required=False, default=8)
    parser.add_argument("--num_nodes", type=int, required=False, default=1)
    parser.add_argument(
        "--precision",
        type=str,
        required=False,
        default='bf16-mixed',
        choices=['32-true', '16-mixed', 'bf16-mixed'],
        help="Precision value for the trainer that matches with precision of the ckpt",
    )
    args = parser.parse_args()
    return args


def main() -> None:
    args = get_args()
    cfg = {
        'trainer': {
            'devices': args.gpus_per_node,
            'num_nodes': args.num_nodes,
            'accelerator': 'gpu',
            'precision': args.precision,
        },
        'model': {
            'native_amp_init_scale': 2**32,
            'native_amp_growth_interval': 1000,
            'hysteresis': 2,
            'gradient_as_bucket_view': True,
        },
        'cluster_type': 'BCP',
    }
    cfg = OmegaConf.create(cfg)

    # Set precision None after precision plugins are created as PTL >= 2.1 does not allow both
    # precision plugins and precision to exist
    cfg.trainer.precision = None
    trainer = MegatronTrainerBuilder(cfg).create_trainer()

    save_restore_connector = NLPSaveRestoreConnector()
    if os.path.isdir(args.input_path):
        save_restore_connector.model_extracted_dir = args.input_path

    model = MegatronNevaModel.restore_from(
        restore_path=args.input_path,
        trainer=trainer,
        save_restore_connector=save_restore_connector,
        strict=False,
    )

    model.save_to(args.output_path)
    logging.info(f'NeMo model saved to: {args.output_path}')


if __name__ == '__main__':
    main()
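A hedged usage sketch (not code from this commit): once the script above has written the distributed-format .nemo file, it can be restored with the same non-strict restore_from call, using a trainer built the same way main() builds one. The path and device counts below are placeholders:

```python
# Sketch only: restore the converted, distributed-format checkpoint written by the script above.
# The path is a placeholder; the trainer config mirrors the cfg dict built in main().
from omegaconf import OmegaConf

from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector

cfg = OmegaConf.create(
    {
        'trainer': {'devices': 1, 'num_nodes': 1, 'accelerator': 'gpu', 'precision': 'bf16-mixed'},
        'model': {
            'native_amp_init_scale': 2**32,
            'native_amp_growth_interval': 1000,
            'hysteresis': 2,
            'gradient_as_bucket_view': True,
        },
        'cluster_type': 'BCP',
    }
)
cfg.trainer.precision = None  # same PTL >= 2.1 workaround used in main()
trainer = MegatronTrainerBuilder(cfg).create_trainer()

model = MegatronNevaModel.restore_from(
    restore_path="/path/to/converted_neva.nemo",  # placeholder for the --output_path written above
    trainer=trainer,
    save_restore_connector=NLPSaveRestoreConnector(),
    strict=False,  # non-strict loading, matching the converter
)
```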
61 changes: 56 additions & 5 deletions nemo/collections/multimodal/data/neva/conversation.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from collections import defaultdict
from enum import Enum, auto
from typing import List

@@ -24,9 +25,14 @@
DEFAULT_SYSTEM_TOKEN = "<extra_id_0>"
DEFAULT_SEPARATOR_TOKEN = "<extra_id_1>"
DEFAULT_LABELS_TOKEN = "<extra_id_2>"
DEFAULT_IMAGE_PATCH_TOKEN = "<extra_id_3>"
DEFAULT_IM_START_TOKEN = "<extra_id_4>"
DEFAULT_IM_END_TOKEN = "<extra_id_5>"
DEFAULT_IMAGE_PATCH_TOKEN = defaultdict(lambda: "<extra_id_3>")
DEFAULT_IM_START_TOKEN = defaultdict(lambda: "<extra_id_4>")
DEFAULT_IM_END_TOKEN = defaultdict(lambda: "<extra_id_5>")

# Update llama3 default
DEFAULT_IMAGE_PATCH_TOKEN["llama_3"] = "<|reserved_special_token_3|>"
DEFAULT_IM_START_TOKEN["llama_3"] = "<|reserved_special_token_4|>"
DEFAULT_IM_END_TOKEN["llama_3"] = "<|reserved_special_token_5|>"


class SeparatorStyle(Enum):
@@ -36,6 +42,7 @@ class SeparatorStyle(Enum):
    TWO = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()
    NVGPT = auto()


@@ -109,6 +116,34 @@ def get_prompt(self):
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            """
            <|begin_of_text|><|start_header_id|>system<|end_header_id|>

            {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>

            {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

            {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>

            {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
            """
            wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}"
            wrap_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>\n\n{msg}"
            wrap_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>\n\n{msg}"

            ret = "<|begin_of_text|>" + wrap_sys(self.system) + self.sep
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if type(message) is tuple:
                    message, _, _ = message
                elif i % 2 == 0:
                    ret += wrap_user(message) + self.sep
                else:
                    ret += wrap_assistant(message) + (self.sep if message else "")

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
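To make the new branch concrete, here is a small self-contained sketch (not part of the diff; the conversation content is invented) of the string the LLAMA_3 separator style assembles with sep="<|eot_id|>". The tuple/image handling and first-message asserts are omitted for brevity:

```python
# Standalone illustration of the LLAMA_3 prompt layout assembled above (simplified).
system = "You are a helpful language and vision assistant."
sep = "<|eot_id|>"

wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}"
wrap_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>\n\n{msg}"
wrap_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>\n\n{msg}"

messages = [("user", "What is in the image?"), ("assistant", "A cat sitting on a laptop.")]

ret = "<|begin_of_text|>" + wrap_sys(system) + sep
for i, (_, message) in enumerate(messages):
    if i % 2 == 0:
        ret += wrap_user(message) + sep
    else:
        ret += wrap_assistant(message) + (sep if message else "")

print(ret)
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
# You are a helpful language and vision assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
# What is in the image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# A cat sitting on a laptop.<|eot_id|>
```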
@@ -346,8 +381,25 @@ def dict(self):
    sep2=DEFAULT_EOS_TOKEN,
)

conv_llava_llama_3 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("user", "assistant"),
    version="llama_v3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="<|eot_id|>",
)

conv_llava_plain = Conversation(
    system="", roles=("", ""), messages=(), offset=0, sep_style=SeparatorStyle.PLAIN, sep="\n",
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
@@ -416,6 +468,5 @@ def dict(self):
"nv_dpo": conv_nv_dpo,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
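Putting the new pieces together, a hedged sketch of how the llama_3 media tokens and the new template behave; this relies on the LLaVA-style copy/append_message/get_prompt helpers that this Conversation class already provides and is not code from this commit:

```python
# Hedged usage sketch: per-model media-token lookup and the new Llama 3 template.
from nemo.collections.multimodal.data.neva.conversation import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    conv_llava_llama_3,
)

# The defaultdicts return the reserved special tokens for "llama_3"
# and fall back to the legacy <extra_id_*> tokens for every other model type.
assert DEFAULT_IMAGE_PATCH_TOKEN["llama_3"] == "<|reserved_special_token_3|>"
assert DEFAULT_IMAGE_PATCH_TOKEN["nvgpt"] == "<extra_id_3>"
print(DEFAULT_IM_START_TOKEN["llama_3"], DEFAULT_IM_END_TOKEN["llama_3"])

# Build a one-turn prompt with the new template (copy() first so the shared
# template object is not mutated).
conv = conv_llava_llama_3.copy()
conv.append_message(conv.roles[0], "<image>\nDescribe the image.")
print(conv.get_prompt())
```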
