Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add llama3's prompt template to conversation.py #1443

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 45 additions & 0 deletions llava/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class SeparatorStyle(Enum):
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()


@dataclasses.dataclass
Expand Down Expand Up @@ -91,6 +92,36 @@ def get_prompt(self):
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.LLAMA_3:
wrap_sys = lambda msg: f"""<|start_header_id|>system<|end_header_id|>

{msg}<|eot_id|>"""
wrap_inst = lambda msg: f"""<|start_header_id|>user<|end_header_id|>

{msg}<|eot_id|>"""
wrap_resp = lambda msg: f"""<|start_header_id|>assistant<|end_header_id|>

{msg}<|eot_id|>"""
ret = "<|begin_of_text|>"

for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0:
ret += wrap_sys(self.system)
if i % 2 == 0:
message = wrap_inst(message)
ret += message
else:
message = wrap_resp(message)
ret += message
else:
ret += ""
ret += "<|start_header_id|>assistant<|end_header_id|>\n\n"
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
Expand Down Expand Up @@ -264,6 +295,19 @@ def dict(self):
sep2="</s>",
)

conv_llama_3 = Conversation(
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
roles=("USER", "ASSISTANT"),
version="llama_v3",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_3,
sep="",
sep2="",
)

conv_llava_llama_2 = Conversation(
system="You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
Expand Down Expand Up @@ -376,6 +420,7 @@ def dict(self):
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"llama_2": conv_llama_2,
"llama_3": conv_llama_3,
"mistral_instruct": conv_mistral_instruct,
"chatml_direct": conv_chatml_direct,
"mistral_direct": conv_chatml_direct,
Expand Down