Add llama3 and distributed checkpoint support in NeVA #9101

Merged 67 commits on May 22, 2024
Changes shown below are from 59 of the 67 commits.
f2f267d
temp save
yaoyu-33 Mar 16, 2024
b532d1b
Merge branch 'main' into yuya/neva_seq_pack
yaoyu-33 Mar 18, 2024
b0ace6b
temp save 2
yaoyu-33 Mar 18, 2024
0020fe3
update code
yaoyu-33 Mar 19, 2024
76fb748
Merge branch 'main' into yuya/neva_seq_pack
yaoyu-33 Mar 19, 2024
a8f2248
enable seq packing
yaoyu-33 Mar 19, 2024
9fab5a5
fix neva and clip
yaoyu-33 Mar 18, 2024
d8474fb
Enable parallel seq packing algo and few other fixes
yaoyu-33 Mar 21, 2024
c56ec9b
Merge branch 'main' into yuya/neva_seq_pack
yaoyu-33 Mar 22, 2024
e8a9a6d
Pipeline parallel support
yaoyu-33 Mar 25, 2024
c5ffa83
Update data preprocess
yaoyu-33 Mar 25, 2024
e11e260
Merge branch 'main' into yuya/neva_pp_support
yaoyu-33 Apr 2, 2024
2bc5d66
fix few pp issues
yaoyu-33 Apr 2, 2024
4843e54
Merge branch 'yuya/neva_seq_pack' into yuya/neva_pp_support
yaoyu-33 Apr 4, 2024
78034ce
enable sequence packing w/ PP
yaoyu-33 Apr 4, 2024
8561e60
Fix cu_seqlens in inputs
yaoyu-33 Apr 5, 2024
2ac6b27
add assert
yaoyu-33 Apr 8, 2024
d138b1e
Merge branch 'main' into yuya/neva_pp_support
yaoyu-33 Apr 16, 2024
5e6994d
Depend on PP to decide whether do padding
yaoyu-33 Apr 17, 2024
0544758
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
6af32af
Add docstring
yaoyu-33 Apr 17, 2024
cd513b3
Merge remote-tracking branch 'origin/yuya/neva_pp_support' into yuya/…
yaoyu-33 Apr 17, 2024
3655f7d
Fix few evaluation issues
yaoyu-33 Apr 23, 2024
f54e565
Merge branch 'main' into yuya/neva_pp_support
yaoyu-33 Apr 24, 2024
4bb0313
Fix few PP evaluation issues
yaoyu-33 Apr 24, 2024
6efa4fa
Address comments
yaoyu-33 Apr 24, 2024
9c44e30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2024
37953d4
add llama3 template
yaoyu-33 Apr 25, 2024
56700e7
address comments
yaoyu-33 Apr 25, 2024
b5a4c27
Fix license
yaoyu-33 Apr 26, 2024
f2588a0
Merge branch 'main' into yuya/neva_pp_support
yaoyu-33 Apr 26, 2024
cd5f5f6
Merge branch 'yuya/neva_pp_support' into yuya/neva_llama3
yaoyu-33 Apr 26, 2024
8c73b6f
Fix llama3
yaoyu-33 Apr 26, 2024
4855365
Few fixes
yaoyu-33 Apr 26, 2024
45a6c61
Merge branch 'main' into yuya/neva_pp_support
yaoyu-33 Apr 29, 2024
2d58b72
Merge branch 'yuya/neva_pp_support' into yuya/neva_llama3
yaoyu-33 Apr 29, 2024
f43df1c
Few neva bugs
yaoyu-33 Apr 30, 2024
8d9f8c1
Few neva bugs
yaoyu-33 Apr 30, 2024
449a42f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2024
0380c3b
Few neva bugs
yaoyu-33 Apr 30, 2024
9cb154d
Merge branch 'yuya/neva_pp_support' into yuya/neva_llama3
yaoyu-33 Apr 30, 2024
3e8c0eb
llama3 inference fix
yaoyu-33 May 1, 2024
9d2160d
Force vision encoder to run in fp32
yaoyu-33 May 2, 2024
17173a3
Revert "Force vision encoder to run in fp32"
yaoyu-33 May 2, 2024
24f2421
Merge branch 'main' into yuya/neva_llama3
yaoyu-33 May 2, 2024
07b0721
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2024
48d298a
Try adding distributed format of checkpoint
yaoyu-33 May 3, 2024
f46a5f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 3, 2024
9f6d98e
Allow dist checkpoint to be non-strict
yaoyu-33 May 3, 2024
3c4b5a1
Fix
yaoyu-33 May 8, 2024
3dde1eb
Some fixes for PP + dist ckpt in Neva
yaoyu-33 May 8, 2024
bbf0832
fix peft
yaoyu-33 May 10, 2024
9f5d615
few fixes for lora
yaoyu-33 May 10, 2024
cb3ae62
Merge branch 'main' into yuya/neva_llama3
yaoyu-33 May 10, 2024
4588b6a
checkpoint updates
yaoyu-33 May 10, 2024
725e353
Apply isort and black reformatting
yaoyu-33 May 10, 2024
7c38044
Merge branch 'main' into yuya/neva_llama3
yaoyu-33 May 13, 2024
9bd4929
Merge remote-tracking branch 'origin/yuya/neva_llama3' into yuya/neva…
yaoyu-33 May 13, 2024
34bff23
bug fix
yaoyu-33 May 13, 2024
bcf6385
Add neva dist checkpoint converter
yaoyu-33 May 13, 2024
24785dd
Apply isort and black reformatting
yaoyu-33 May 13, 2024
08dbfa5
resolve comments
yaoyu-33 May 14, 2024
968a75c
update neva dist ckpt apis
yaoyu-33 May 16, 2024
0b93023
Apply isort and black reformatting
yaoyu-33 May 16, 2024
5e5ad9e
fix return
yaoyu-33 May 16, 2024
8dc1142
Merge branch 'main' into yuya/neva_llama3
yaoyu-33 May 17, 2024
5450967
Merge branch 'main' into yuya/neva_llama3
yaoyu-33 May 21, 2024
@@ -20,6 +20,7 @@
from omegaconf import OmegaConf

from nemo.collections.multimodal.parts.utils import create_neva_model_and_processor
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam

CFG_STRING = """
trainer:
@@ -169,7 +169,7 @@ def eval_model(args):
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.json")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
-parser.add_argument("--conv-mode", type=str, default="llava_v0")
+parser.add_argument("--conv-mode", type=str, default="llava_v0")  # this flag has no use!
Collaborator comment: then should we get rid of it..?

parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--num-chunks", type=int, default=1)
61 changes: 56 additions & 5 deletions nemo/collections/multimodal/data/neva/conversation.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from collections import defaultdict
from enum import Enum, auto
from typing import List

@@ -24,9 +25,14 @@
DEFAULT_SYSTEM_TOKEN = "<extra_id_0>"
DEFAULT_SEPARATOR_TOKEN = "<extra_id_1>"
DEFAULT_LABELS_TOKEN = "<extra_id_2>"
-DEFAULT_IMAGE_PATCH_TOKEN = "<extra_id_3>"
-DEFAULT_IM_START_TOKEN = "<extra_id_4>"
-DEFAULT_IM_END_TOKEN = "<extra_id_5>"
+DEFAULT_IMAGE_PATCH_TOKEN = defaultdict(lambda: "<extra_id_3>")
+DEFAULT_IM_START_TOKEN = defaultdict(lambda: "<extra_id_4>")
+DEFAULT_IM_END_TOKEN = defaultdict(lambda: "<extra_id_5>")

# Update llama3 default
DEFAULT_IMAGE_PATCH_TOKEN["llama_3"] = "<|reserved_special_token_3|>"
DEFAULT_IM_START_TOKEN["llama_3"] = "<|reserved_special_token_4|>"
DEFAULT_IM_END_TOKEN["llama_3"] = "<|reserved_special_token_5|>"
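Switching the media-token constants from plain strings to `defaultdict` lets callers index them by model type and still receive the generic `<extra_id_*>` token for any model without reserved special tokens. A minimal standalone sketch of that lookup behavior, mirroring the assignments above:

```python
from collections import defaultdict

# Fallback token for any model type not explicitly registered.
DEFAULT_IM_START_TOKEN = defaultdict(lambda: "<extra_id_4>")
# llama3 gets one of its reserved special tokens instead.
DEFAULT_IM_START_TOKEN["llama_3"] = "<|reserved_special_token_4|>"

# A registered key returns the llama3-specific token...
print(DEFAULT_IM_START_TOKEN["llama_3"])  # <|reserved_special_token_4|>
# ...while any other key silently falls back to the generic token.
print(DEFAULT_IM_START_TOKEN["llama_2"])  # <extra_id_4>
```

Because `defaultdict` inserts the fallback on first access, unknown model types never raise `KeyError`, which keeps existing call sites working unchanged.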


class SeparatorStyle(Enum):
@@ -36,6 +42,7 @@ class SeparatorStyle(Enum):
    TWO = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()
    NVGPT = auto()


@@ -109,6 +116,34 @@ def get_prompt(self):
            else:
                ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            """
            <|begin_of_text|><|start_header_id|>system<|end_header_id|>

            {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>

            {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

            {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>

            {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
            """
            wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}"
            wrap_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>\n\n{msg}"
            wrap_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>\n\n{msg}"

            ret = "<|begin_of_text|>" + wrap_sys(self.system) + self.sep
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if type(message) is tuple:
                    message, _, _ = message
                elif i % 2 == 0:
                    ret += wrap_user(message) + self.sep
                else:
                    ret += wrap_assistant(message) + (self.sep if message else "")

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
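The new `LLAMA_3` branch can be sketched as a standalone helper. This is a minimal sketch, not the NeMo API: `build_llama3_prompt` is a hypothetical name, and the image-tuple handling and first-turn assertions from the diff are omitted for brevity.

```python
def build_llama3_prompt(system, messages, sep="<|eot_id|>"):
    """Assemble a llama3-style chat prompt from (role, message) pairs."""
    wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}"
    wrap_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>\n\n{msg}"
    wrap_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>\n\n{msg}"

    # Every prompt starts with the begin-of-text token and the system turn.
    ret = "<|begin_of_text|>" + wrap_sys(system) + sep
    for i, (role, message) in enumerate(messages):
        if i % 2 == 0:  # even turns: user
            ret += wrap_user(message) + sep
        else:  # odd turns: assistant; trailing sep only when a reply exists
            ret += wrap_assistant(message) + (sep if message else "")
    return ret


prompt = build_llama3_prompt("You are helpful.", [("user", "Hi"), ("assistant", "Hello!")])
print(prompt)
```

Passing an empty string as the final assistant message yields a prompt that ends with the bare assistant header, which is the usual way to hand the model a generation prompt — this is what the `(self.sep if message else "")` guard in the diff enables.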
@@ -346,8 +381,25 @@ def dict(self):
    sep2=DEFAULT_EOS_TOKEN,
)
conv_llava_llama_3 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("user", "assistant"),
    version="llama_v3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="<|eot_id|>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
@@ -416,6 +468,5 @@ def dict(self):
    "nv_dpo": conv_nv_dpo,
}


if __name__ == "__main__":
print(default_conversation.get_prompt())