Allow automatic chat templating
EricLBuehler committed Apr 25, 2024
1 parent 6298d66 commit 3cd1ec3
Showing 1 changed file with 48 additions and 29 deletions.
@@ -27,6 +27,8 @@
ChatCompletionRequest,
Runner,
Which,
Message,
Role,
)

DEFAULT_MISTRAL_RS_GGML_MODEL = (
@@ -45,6 +47,25 @@
DEFAULT_PREFIX_CACHE_N = 16


def llama_index_to_mistralrs_messages(messages: Sequence[ChatMessage]) -> list[Message]:
"""
Convert LlamaIndex messages to mistralrs messages. Raises a ValueError if the role is not user, assistant, or system.
"""
messages_new = []
for message in messages:
if message.role == "user":
messages_new.append(Message(Role.User, message.content))
elif message.role == "assistant":
messages_new.append(Message(Role.Assistant, message.content))
elif message.role == "system":
messages_new.append(Message(Role.System, message.content))
else:
raise ValueError(
f"Unsupported chat role `{message.role}` for `mistralrs` automatic chat templating: supported are `user`, `assistant`, `system`. Please specify `messages_to_prompt`."
)
return messages_new
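
A quick sketch of how this helper is meant to be used (the `ChatMessage` import path is an assumption; the role strings mirror the branches above):

```python
from llama_index.core.llms import ChatMessage  # assumed import path

msgs = [
    ChatMessage(role="system", content="You are a terse assistant."),
    ChatMessage(role="user", content="Hello!"),
]
converted = llama_index_to_mistralrs_messages(msgs)
# converted == [Message(Role.System, "You are a terse assistant."),
#               Message(Role.User, "Hello!")]

# Any other role (for example a tool message) raises ValueError, pointing the
# caller at `messages_to_prompt` instead.
```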


class MistralRS(CustomLLM):
r"""MistralRS LLM.
@@ -54,32 +75,14 @@ class MistralRS(CustomLLM):
Then `pip install llama-index-llms-mistral-rs`
This LLM supports automatic chat templating. If you do not provide `messages_to_prompt`,
mistral.rs will determine the chat template automatically. You can also supply a custom Jinja chat
template by passing it under the `chat_template` key of `model_kwargs`; a short example of the
automatic path follows the usage example below.
```python
from llama_index.llms.mistral_rs import MistralRS
from mistralrs import Which
def messages_to_prompt(messages):
prompt = ""
for message in messages:
if message.role == 'system':
prompt += f"<|system|>\n{message.content}</s>\n"
elif message.role == 'user':
prompt += f"<|user|>\n{message.content}</s>\n"
elif message.role == 'assistant':
prompt += f"<|assistant|>\n{message.content}</s>\n"
# ensure we start with a system prompt, insert blank if needed
if not prompt.startswith("<|system|>\n"):
prompt = "<|system|>\n</s>\n" + prompt
# add final assistant prompt
prompt = prompt + "<|assistant|>\n"
return prompt
def completion_to_prompt(completion):
return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"
llm = MistralRS(
which = Which.XLora(
model_id=None, # Automatically determine from ordering file
@@ -94,8 +97,6 @@ def completion_to_prompt(completion):
max_new_tokens=256,
context_window=3900,
generate_kwargs={},
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
verbose=True,
)
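
# With this change the messages_to_prompt / completion_to_prompt hooks are
# optional. A minimal, illustrative continuation of the example above
# (assumes `from llama_index.core.llms import ChatMessage`): with no
# messages_to_prompt, mistral.rs applies the model's own chat template, and a
# custom Jinja template can be passed via model_kwargs={"chat_template": ...}.
response = llm.chat(
    [
        ChatMessage(role="system", content="You are a helpful assistant."),
        ChatMessage(role="user", content="Hello!"),
    ]
)
print(response)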
@@ -133,6 +134,7 @@ def completion_to_prompt(completion):
default_factory=dict, description="Kwargs used for model initialization."
)
_runner: Runner = PrivateAttr("Mistral.rs model runner.")
_has_messages_to_prompt: bool = PrivateAttr("If `messages_to_prompt` is provided.")

def __init__(
self,
Expand Down Expand Up @@ -188,6 +190,7 @@ def __init__(
no_kv_cache=model_kwargs.get("no_kv_cache", False),
chat_template=model_kwargs.get("chat_template", None),
)
self._has_messages_to_prompt = messages_to_prompt is not None

@classmethod
def class_name(cls) -> str:
@@ -204,19 +207,35 @@ def metadata(self) -> LLMMetadata:

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
prompt = self.messages_to_prompt(messages)
completion_response = self.complete(prompt, formatted=True, **kwargs)
return completion_response_to_chat_response(completion_response)
if self._has_messages_to_prompt:
messages = self.messages_to_prompt(messages)
else:
messages = llama_index_to_mistralrs_messages(messages)
self.generate_kwargs.update({"stream": False})

request = ChatCompletionRequest(
messages=messages,
model="",
logit_bias=None,
logprobs=False,
**self.generate_kwargs,
)

response = self._runner.send_chat_completion_request(request)
return completion_response_to_chat_response(
CompletionResponse(text=response.choices[0].message.content)
)

@llm_chat_callback()
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
prompt = self.messages_to_prompt(messages)
if self._has_messages_to_prompt:
messages = self.messages_to_prompt(messages)
else:
messages = llama_index_to_mistralrs_messages(messages)
self.generate_kwargs.update({"stream": True})

request = ChatCompletionRequest(
messages=prompt,
messages=messages,
model="",
logit_bias=None,
logprobs=False,
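
For orientation, a rough standalone sketch of what the new chat path boils down to at the mistralrs level, outside the LlamaIndex wrapper (assumes `runner` is an already-constructed `mistralrs.Runner` as set up in `__init__` above; the request fields mirror those built in `chat()`):

```python
from mistralrs import ChatCompletionRequest, Message, Role

# `runner` is assumed to be a configured mistralrs.Runner.
request = ChatCompletionRequest(
    messages=[Message(Role.User, "Hello!")],
    model="",
    logit_bias=None,
    logprobs=False,
    stream=False,  # chat() forces stream=False via generate_kwargs
)
response = runner.send_chat_completion_request(request)
print(response.choices[0].message.content)
```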
