From 772a57540d33c98f3738f9d80788bead683ad9c9 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Thu, 2 May 2024 11:09:45 -0400 Subject: [PATCH] Integrate mistral.rs LLM (#13105) * Integrate * Changes based on comments * Run pants tailor * Properly extract and pass logprobs * Add a simple test * Add a usage section * Fix silly mistake * Add mistralrs as a dependancy * Fix extract logprobs and update api * Update for new version * Update version * Prettier * Fix typing * Remove unnecessary in docs BUILD --- .../llama-index-llms-mistral-rs/.gitignore | 153 ++++++++ .../llms/llama-index-llms-mistral-rs/BUILD | 3 + .../llms/llama-index-llms-mistral-rs/Makefile | 17 + .../llama-index-llms-mistral-rs/README.md | 68 ++++ .../examples/plain.ipynb | 69 ++++ .../examples/streaming.ipynb | 72 ++++ .../examples/xlora_gguf.ipynb | 70 ++++ .../llama_index/llms/mistral_rs/BUILD | 1 + .../llama_index/llms/mistral_rs/__init__.py | 3 + .../llama_index/llms/mistral_rs/base.py | 342 ++++++++++++++++++ .../pyproject.toml | 59 +++ .../llama-index-llms-mistral-rs/tests/BUILD | 1 + .../tests/__init__.py | 0 .../tests/test_llms_mistral-rs.py | 7 + 14 files changed, 865 insertions(+) create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/.gitignore create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/Makefile create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/README.md create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/plain.ipynb create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/streaming.ipynb create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/xlora_gguf.ipynb create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/__init__.py create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/base.py create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/pyproject.toml create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/__init__.py create mode 100644 llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/test_llms_mistral-rs.py diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/.gitignore b/llama-index-integrations/llms/llama-index-llms-mistral-rs/.gitignore new file mode 100644 index 0000000000000..990c18de22908 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/BUILD b/llama-index-integrations/llms/llama-index-llms-mistral-rs/BUILD new file mode 100644 index 0000000000000..0896ca890d8bf --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/Makefile b/llama-index-integrations/llms/llama-index-llms-mistral-rs/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. + sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/README.md b/llama-index-integrations/llms/llama-index-llms-mistral-rs/README.md new file mode 100644 index 0000000000000..8cc03184eb9a9 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/README.md @@ -0,0 +1,68 @@ +# LlamaIndex Llms Integration: `mistral.rs` + +To use this integration, please install the Python `mistralrs` package: + +## Installation of `mistralrs` from PyPi + +0. 
Install Rust: https://rustup.rs/ + + ```bash + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + source $HOME/.cargo/env + ``` + +1. `mistralrs` depends on the `openssl` library. + +To install it on Ubuntu: + +``` +sudo apt install libssl-dev +sudo apt install pkg-config +``` + +2. Install it! + +- CUDA + + `pip install mistralrs-cuda` + +- Metal + + `pip install mistralrs-metal` + +- Apple Accelerate + + `pip install mistralrs-accelerate` + +- Intel MKL + + `pip install mistralrs-mkl` + +- Without accelerators + + `pip install mistralrs` + +All installations will install the `mistralrs` package. The suffix on the package installed by `pip` only controls the feature activation. + +## Installation from source + +Please follow the instructions [here](https://github.com/EricLBuehler/mistral.rs/blob/master/mistralrs-pyo3/README.md). + +## Usage + +```python +from llama_index.llms.mistral_rs import MistralRS +from mistralrs import Which + +llm = MistralRS( + which=Which.GGUF( + tok_model_id="mistralai/Mistral-7B-Instruct-v0.1", + quantized_model_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF", + quantized_filename="mistral-7b-instruct-v0.1.Q4_K_M.gguf", + tokenizer_json=None, + repeat_last_n=64, + ), + max_new_tokens=4096, + context_window=1024 * 5, +) +``` diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/plain.ipynb b/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/plain.ipynb new file mode 100644 index 0000000000000..0cf859364ae43 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/plain.ipynb @@ -0,0 +1,69 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings\n", + "from llama_index.core.embeddings import resolve_embed_model\n", + "from llama_index.llms.mistral_rs import MistralRS\n", + "from mistralrs import Which, Architecture\n", + "import sys\n", + "\n", + "documents = SimpleDirectoryReader(\"data\").load_data()\n", + "\n", + "# bge embedding model\n", + "Settings.embed_model = resolve_embed_model(\"local:BAAI/bge-small-en-v1.5\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Settings.llm = MistralRS(\n", + " which=Which.Plain(\n", + " model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n", + " arch=Architecture.Mistral,\n", + " tokenizer_json=None,\n", + " repeat_last_n=64,\n", + " ),\n", + " max_new_tokens=4096,\n", + " context_window=1024 * 5,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "index = VectorStoreIndex.from_documents(\n", + " documents,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine = index.as_query_engine()\n", + "response = query_engine.query(\"How do I pronounce graphene?\")\n", + "print(response)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/streaming.ipynb b/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/streaming.ipynb new file mode 100644 index 0000000000000..6bcfe155db049 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/streaming.ipynb @@ -0,0 +1,72 @@ +{ + "cells": [ + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings\n", + "from llama_index.core.embeddings import resolve_embed_model\n", + "from llama_index.llms.mistral_rs import MistralRS\n", + "from mistralrs import Which\n", + "import sys\n", + "\n", + "documents = SimpleDirectoryReader(\"data\").load_data()\n", + "\n", + "# bge embedding model\n", + "Settings.embed_model = resolve_embed_model(\"local:BAAI/bge-small-en-v1.5\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Settings.llm = MistralRS(\n", + " which=Which.GGUF(\n", + " tok_model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n", + " quantized_model_id=\"TheBloke/Mistral-7B-Instruct-v0.1-GGUF\",\n", + " quantized_filename=\"mistral-7b-instruct-v0.1.Q4_K_M.gguf\",\n", + " tokenizer_json=None,\n", + " repeat_last_n=64,\n", + " ),\n", + " max_new_tokens=4096,\n", + " context_window=1024 * 5,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "index = VectorStoreIndex.from_documents(\n", + " documents,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine = index.as_query_engine(streaming=True)\n", + "response = query_engine.query(\"What are non-granular scalings?\")\n", + "for text in response.response_gen:\n", + " print(text, end=\"\")\n", + " sys.stdout.flush()" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/xlora_gguf.ipynb b/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/xlora_gguf.ipynb new file mode 100644 index 0000000000000..98dc78a41e652 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/examples/xlora_gguf.ipynb @@ -0,0 +1,70 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings\n", + "from llama_index.core.embeddings import resolve_embed_model\n", + "from llama_index.llms.mistral_rs import MistralRS\n", + "from mistralrs import Which\n", + "import sys\n", + "\n", + "documents = SimpleDirectoryReader(\"data\").load_data()\n", + "\n", + "# bge embedding model\n", + "Settings.embed_model = resolve_embed_model(\"local:BAAI/bge-small-en-v1.5\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Settings.llm = MistralRS(\n", + " which=Which.GGUF(\n", + " tok_model_id=\"mistralai/Mistral-7B-Instruct-v0.1\",\n", + " quantized_model_id=\"TheBloke/Mistral-7B-Instruct-v0.1-GGUF\",\n", + " quantized_filename=\"mistral-7b-instruct-v0.1.Q4_K_M.gguf\",\n", + " tokenizer_json=None,\n", + " repeat_last_n=64,\n", + " ),\n", + " max_new_tokens=4096,\n", + " context_window=1024 * 5,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "index = VectorStoreIndex.from_documents(\n", + " documents,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine = index.as_query_engine()\n", + "response = query_engine.query(\"How do I pronounce graphene?\")\n", + 
"print(response)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/BUILD b/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/__init__.py b/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/__init__.py new file mode 100644 index 0000000000000..13b51add65804 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/__init__.py @@ -0,0 +1,3 @@ +from llama_index.llms.mistral_rs.base import MistralRS + +__all__ = ["MistralRS"] diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/base.py b/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/base.py new file mode 100644 index 0000000000000..e42d41a127869 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/llama_index/llms/mistral_rs/base.py @@ -0,0 +1,342 @@ +from typing import Any, Callable, Dict, Optional, Sequence, List + +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseGen, + CompletionResponse, + CompletionResponseGen, + LLMMetadata, + MessageRole, + LogProb, +) +from llama_index.core.bridge.pydantic import Field, PrivateAttr +from llama_index.core.callbacks import CallbackManager +from llama_index.core.constants import ( + DEFAULT_CONTEXT_WINDOW, + DEFAULT_NUM_OUTPUTS, + DEFAULT_TEMPERATURE, +) +from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback +from llama_index.core.llms.custom import CustomLLM +from llama_index.core.types import BaseOutputParser, PydanticProgramMode + +from mistralrs import ( + ChatCompletionRequest, + Runner, + Which, +) + +DEFAULT_TOPK = 32 +DEFAULT_TOPP = 0.1 +DEFAULT_REPEAT_LAST_N = 64 +DEFAULT_MAX_SEQS = 16 +DEFAULT_PREFIX_CACHE_N = 16 + + +def llama_index_to_mistralrs_messages( + messages: Sequence[ChatMessage], +) -> List[Dict[str, str]]: + """ + Convert llamaindex to mistralrs messages. Raises an exception if the role is not user or assistant. + """ + messages_new = [] + for message in messages: + if message.role == "user": + messages_new.append({"role": "user", "content": message.content}) + elif message.role == "assistant": + messages_new.append({"role": "assistant", "content": message.content}) + elif message.role == "system": + messages_new.append({"role": "system", "content": message.content}) + else: + raise ValueError( + f"Unsupported chat role `{message.role}` for `mistralrs` automatic chat templating: supported are `user`, `assistant`, `system`. Please specify `messages_to_prompt`." 
+            )
+    return messages_new
+
+
+def extract_logprobs_choice(choice) -> Optional[List[LogProb]]:
+    if choice.logprobs is not None:
+        logprobs = []
+        for logprob in choice.logprobs.content:
+            logprobs.append(
+                LogProb(
+                    logprob=logprob.logprob,
+                    bytes=logprob.bytes,
+                    token=logprob.token,
+                )
+            )
+    else:
+        logprobs = None
+    return logprobs
+
+
+def extract_logprobs(response) -> Optional[List[List[LogProb]]]:
+    if response.choices[0].logprobs is not None:
+        choice_logprobs = []
+        for choice in response.choices:
+            choice_logprobs.append(extract_logprobs_choice(choice))
+    else:
+        choice_logprobs = None
+    return choice_logprobs
+
+
+def extract_logprobs_stream(response) -> Optional[List[List[LogProb]]]:
+    if response.choices[0].logprobs is not None:
+        logprobs = [extract_logprobs_choice(response.choices[0])]
+    else:
+        logprobs = None
+    return logprobs
+
+
+class MistralRS(CustomLLM):
+    r"""MistralRS LLM.
+
+    Examples:
+        Install `mistralrs` following the instructions here:
+        https://github.com/EricLBuehler/mistral.rs/blob/master/mistralrs-pyo3/README.md#installation-from-pypi
+
+        Then `pip install llama-index-llms-mistral-rs`.
+
+        This LLM provides automatic chat templating as an option. If you do not provide `messages_to_prompt`,
+        mistral.rs will automatically determine one. You can specify a Jinja chat template by passing it in
+        `model_kwargs` under the `chat_template` key.
+
+        ```python
+        from llama_index.llms.mistral_rs import MistralRS
+        from mistralrs import Which, Architecture
+
+        llm = MistralRS(
+            which=Which.XLora(
+                model_id=None,  # Automatically determine from ordering file
+                tokenizer_json=None,
+                repeat_last_n=64,
+                xlora_model_id="lamm-mit/x-lora",
+                order="xlora-paper-ordering.json",  # Make sure you copy the ordering file from `mistral.rs/orderings`
+                tgt_non_granular_index=None,
+                arch=Architecture.Mistral,
+            ),
+            temperature=0.1,
+            max_new_tokens=256,
+            context_window=3900,
+            generate_kwargs={},
+        )
+
+        response = llm.complete("Hello, how are you?")
+        print(str(response))
+        ```
+    """
+
+    model_url: Optional[str] = Field(description="The model URL; always 'local' for mistral.rs.")
+    model_path: Optional[str] = Field(description="The model path; always 'local' for mistral.rs.")
+    temperature: float = Field(
+        default=DEFAULT_TEMPERATURE,
+        description="The temperature to use for sampling.",
+        gte=0.0,
+        lte=1.0,
+    )
+    max_new_tokens: int = Field(
+        default=DEFAULT_NUM_OUTPUTS,
+        description="The maximum number of tokens to generate.",
+        gt=0,
+    )
+    context_window: int = Field(
+        default=DEFAULT_CONTEXT_WINDOW,
+        description="The maximum number of context tokens for the model.",
+        gt=0,
+    )
+    generate_kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="Kwargs used for generation."
+    )
+    model_kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="Kwargs used for model initialization."
+    )
+    _runner: Runner = PrivateAttr()  # The underlying mistral.rs model runner.
+    _has_messages_to_prompt: bool = PrivateAttr()  # Whether `messages_to_prompt` was provided.
+
+    def __init__(
+        self,
+        which: Which,
+        temperature: float = DEFAULT_TEMPERATURE,
+        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
+        context_window: int = DEFAULT_CONTEXT_WINDOW,
+        top_k: int = DEFAULT_TOPK,
+        top_p: float = DEFAULT_TOPP,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        in_situ_quant: Optional[str] = None,
+        max_seqs: int = DEFAULT_MAX_SEQS,
+        token_source: str = "cache",
+        prefix_cache_n: int = DEFAULT_PREFIX_CACHE_N,
+        no_kv_cache: bool = False,
+        chat_template: Optional[str] = None,
+        top_logprobs: Optional[int] = None,
+        callback_manager: Optional[CallbackManager] = None,
+        generate_kwargs: Optional[Dict[str, Any]] = None,
+        system_prompt: Optional[str] = None,
+        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
+        completion_to_prompt: Optional[Callable[[str], str]] = None,
+        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
+        output_parser: Optional[BaseOutputParser] = None,
+    ) -> None:
+        generate_kwargs = generate_kwargs or {}
+        generate_kwargs.update(
+            {
+                "temperature": temperature,
+                "max_tokens": max_new_tokens,
+                "top_k": top_k,
+                "top_p": top_p,
+                "top_logprobs": top_logprobs,
+                "logprobs": top_logprobs is not None,
+                "frequency_penalty": frequency_penalty,
+                "presence_penalty": presence_penalty,
+            }
+        )
+
+        super().__init__(
+            model_path="local",
+            model_url="local",
+            temperature=temperature,
+            context_window=context_window,
+            max_new_tokens=max_new_tokens,
+            callback_manager=callback_manager,
+            generate_kwargs=generate_kwargs,
+            model_kwargs={},
+            verbose=True,
+            system_prompt=system_prompt,
+            messages_to_prompt=messages_to_prompt,
+            completion_to_prompt=completion_to_prompt,
+            pydantic_program_mode=pydantic_program_mode,
+            output_parser=output_parser,
+        )
+
+        self._runner = Runner(
+            which=which,
+            token_source=token_source,
+            max_seqs=max_seqs,
+            prefix_cache_n=prefix_cache_n,
+            no_kv_cache=no_kv_cache,
+            chat_template=chat_template,
+            in_situ_quant=in_situ_quant,
+        )
+        self._has_messages_to_prompt = messages_to_prompt is not None
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "MistralRS"
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """LLM metadata."""
+        return LLMMetadata(
+            context_window=self.context_window,
+            num_output=self.max_new_tokens,
+            model_name=self.model_path,
+        )
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        if self._has_messages_to_prompt:
+            messages = self.messages_to_prompt(messages)
+        else:
+            messages = llama_index_to_mistralrs_messages(messages)
+        self.generate_kwargs.update({"stream": False})
+
+        request = ChatCompletionRequest(
+            messages=messages,
+            model="",
+            logit_bias=None,
+            **self.generate_kwargs,
+        )
+
+        response = self._runner.send_chat_completion_request(request)
+        return ChatResponse(
+            message=ChatMessage(
+                role=MessageRole.ASSISTANT,
+                content=response.choices[0].message.content,
+            ),
+            logprobs=extract_logprobs(response),
+        )
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        if self._has_messages_to_prompt:
+            messages = self.messages_to_prompt(messages)
+        else:
+            messages = llama_index_to_mistralrs_messages(messages)
+        self.generate_kwargs.update({"stream": True})
+
+        request = ChatCompletionRequest(
+            messages=messages,
+            model="",
+            logit_bias=None,
+            **self.generate_kwargs,
+        )
+
+        streamer = self._runner.send_chat_completion_request(request)
+
+        def gen() -> ChatResponseGen:
+            text = ""
+            for response in streamer:
+                delta = response.choices[0].delta.content
+                text += delta
+                yield ChatResponse(
+                    message=ChatMessage(
+                        role=MessageRole.ASSISTANT,
+                        content=text,
+                    ),
+                    delta=delta,
+                    logprobs=extract_logprobs_stream(response),
+                )
+
+        return gen()
+
+    @llm_completion_callback()
+    def complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        self.generate_kwargs.update({"stream": False})
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        request = ChatCompletionRequest(
+            messages=prompt,
+            model="",
+            logit_bias=None,
+            **self.generate_kwargs,
+        )
+        completion_response = self._runner.send_chat_completion_request(request)
+        return CompletionResponse(
+            text=completion_response.choices[0].message.content,
+            logprobs=extract_logprobs(completion_response),
+        )
+
+    @llm_completion_callback()
+    def stream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseGen:
+        self.generate_kwargs.update({"stream": True})
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        request = ChatCompletionRequest(
+            messages=prompt,
+            model="",
+            logit_bias=None,
+            **self.generate_kwargs,
+        )
+
+        streamer = self._runner.send_chat_completion_request(request)
+
+        def gen() -> CompletionResponseGen:
+            text = ""
+            for response in streamer:
+                delta = response.choices[0].delta.content
+                text += delta
+                yield CompletionResponse(
+                    delta=delta,
+                    text=text,
+                    logprobs=extract_logprobs_stream(response),
+                )
+
+        return gen()
diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mistral-rs/pyproject.toml
new file mode 100644
index 0000000000000..bb829cf8c4ef9
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/pyproject.toml
@@ -0,0 +1,59 @@
+[build-system]
+build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core"]
+
+[tool.codespell]
+check-filenames = true
+check-hidden = true
+# Feel free to un-skip examples, and experimental, you will just need to
+# work through many typos (--write-changes and --interactive will help)
+skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
+
+[tool.llamahub]
+contains_example = false
+import_path = "llama_index.llms.mistral_rs"
+
+[tool.llamahub.class_authors]
+MistralRS = "EricLBuehler"
+
+[tool.mypy]
+disallow_untyped_defs = true
+# Remove venv skip when integrated with pre-commit
+exclude = ["_static", "build", "examples", "notebooks", "venv"]
+ignore_missing_imports = true
+python_version = "3.8"
+
+[tool.poetry]
+authors = ["EricLBuehler"]
+description = "llama-index llms mistral-rs integration"
+exclude = ["**/BUILD"]
+license = "MIT"
+maintainers = ["jerryjliu"]
+name = "llama-index-llms-mistral-rs"
+packages = [{include = "llama_index/"}]
+readme = "README.md"
+version = "0.1.0"
+
+[tool.poetry.dependencies]
+python = ">=3.8.1,<4.0"
+llama-index-core = "^0.10.0"
+mistralrs = "^0.1.3"
+
+[tool.poetry.group.dev.dependencies]
+black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
+codespell = {extras = ["toml"], version = ">=v2.2.6"}
+ipython = "8.10.0"
+jupyter = "^1.0.0"
+mypy = "0.991"
+pre-commit = "3.2.0"
+pylint = "2.15.10"
+pytest = "7.2.1"
+pytest-mock = "3.11.1"
+ruff = "0.0.292"
+tree-sitter-languages = "^1.8.0"
+types-Deprecated = ">=0.1.0"
+types-PyYAML = "^6.0.12.12"
+types-protobuf = "^4.24.0.4"
+types-redis = "4.5.5.0"
+types-requests = "2.28.11.8"  # TODO: unpin when mypy>0.991
+types-setuptools = "67.1.0.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/BUILD
new file mode 100644
index 0000000000000..dabf212d7e716
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/BUILD
@@ -0,0 +1 @@
+python_tests()
diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/__init__.py b/llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/test_llms_mistral-rs.py b/llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/test_llms_mistral-rs.py
new file mode 100644
index 0000000000000..b9905ff2ddcf4
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-mistral-rs/tests/test_llms_mistral-rs.py
@@ -0,0 +1,7 @@
+from llama_index.core.base.llms.base import BaseLLM
+from llama_index.llms.mistral_rs import MistralRS
+
+
+def test_llm_class():
+    names_of_base_classes = [b.__name__ for b in MistralRS.__mro__]
+    assert BaseLLM.__name__ in names_of_base_classes
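+
+
+# A minimal extra check (sketch): the pure helper `llama_index_to_mistralrs_messages`
+# in `base.py` converts LlamaIndex chat messages into mistral.rs role/content dicts
+# and needs no model download, so it can be unit-tested directly. The test name and
+# the message contents below are illustrative; imports are kept local to the sketch.
+def test_message_conversion_roles():
+    import pytest
+    from llama_index.core.base.llms.types import ChatMessage, MessageRole
+    from llama_index.llms.mistral_rs.base import llama_index_to_mistralrs_messages
+
+    converted = llama_index_to_mistralrs_messages(
+        [
+            ChatMessage(role=MessageRole.SYSTEM, content="You are helpful."),
+            ChatMessage(role=MessageRole.USER, content="Hello"),
+        ]
+    )
+    assert converted == [
+        {"role": "system", "content": "You are helpful."},
+        {"role": "user", "content": "Hello"},
+    ]
+
+    # Roles outside user/assistant/system (e.g. tool output) should be rejected.
+    with pytest.raises(ValueError):
+        llama_index_to_mistralrs_messages(
+            [ChatMessage(role=MessageRole.TOOL, content="tool output")]
+        )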