diff --git a/README.md b/README.md
index c0d15015c..25ce24621 100644
--- a/README.md
+++ b/README.md
@@ -79,7 +79,7 @@ model=mistralai/Mistral-7B-Instruct-v0.1
 volume=$PWD/data
 
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/predibase/lorax:latest --model-id $model
+    ghcr.io/predibase/lorax:main --model-id $model
 ```
 
 For a full tutorial including token streaming and the Python client, see [Getting Started - Docker](https://predibase.github.io/lorax/getting_started/docker).
diff --git a/docs/getting_started/docker.md b/docs/getting_started/docker.md
index 394238f46..29932ae96 100644
--- a/docs/getting_started/docker.md
+++ b/docs/getting_started/docker.md
@@ -11,8 +11,14 @@ model=mistralai/Mistral-7B-Instruct-v0.1
 volume=$PWD/data # share a volume with the container as a weight cache
 
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/predibase/lorax:latest --model-id $model
+    ghcr.io/predibase/lorax:main --model-id $model
 ```
+
+!!! note
+
+    The `main` tag will use the image built from the HEAD of the main branch of the repo. For the latest stable image (built from a
+    tagged version), use the `latest` tag.
+
 !!! note
 
     The LoRAX server in the pre-built Docker image is configured to listen on port 80 (instead of on the default port number, which is 3000).
diff --git a/docs/getting_started/skypilot.md b/docs/getting_started/skypilot.md
index a476d0df2..c495e326e 100644
--- a/docs/getting_started/skypilot.md
+++ b/docs/getting_started/skypilot.md
@@ -28,7 +28,7 @@ envs:
 
 run: |
   docker run --gpus all --shm-size 1g -p 8080:80 -v ~/data:/data \
-    ghcr.io/predibase/lorax:latest \
+    ghcr.io/predibase/lorax:main \
     --model-id $MODEL_ID
 ```
diff --git a/docs/guides/contributing/development_env.md b/docs/guides/contributing/development_env.md
index 1341b1b59..97955b44a 100644
--- a/docs/guides/contributing/development_env.md
+++ b/docs/guides/contributing/development_env.md
@@ -16,12 +16,12 @@ Pull and run the latest LoRAX docker image, mounting the directory containing yo
 # we will assume the lorax repo is found at ~/data/lorax
 volume=~/data
-docker pull ghcr.io/predibase/lorax:latest
+docker pull ghcr.io/predibase/lorax:main
 docker run \
     --cap-add=SYS_PTRACE \
     --gpus all --shm-size 1g \
     -v $volume:/data \
-    -itd --entrypoint /bin/bash ghcr.io/predibase/lorax:latest
+    -itd --entrypoint /bin/bash ghcr.io/predibase/lorax:main
 ```
 
 !!! note
diff --git a/docs/guides/speculative_decoding.md b/docs/guides/speculative_decoding.md
new file mode 100644
index 000000000..2aed74614
--- /dev/null
+++ b/docs/guides/speculative_decoding.md
@@ -0,0 +1,66 @@
+# Speculative Decoding
+
+Speculative decoding describes a set of methods for speeding up next token generation for autoregressive language models
+by attempting to "guess" the next N tokens of the base model. These guesses can be generated in a number of different ways,
+including:
+
+- An additional smaller "draft" model (e.g., Llama-70b and Llama-7b)
+- An adapter that extends the sequence dimension of the logits (e.g., Medusa)
+- A heuristic (e.g., looking for recurring sequences in the prompt)
+
+LoRAX implements some of these approaches, with a particular emphasis on supporting adapter-based methods like Medusa
+that can be applied per request for task-level speedups.
+
+## Process
+
+Almost all of the above speculative decoding methods consist of the same two phases: a "draft" phase that generates
+candidate tokens and a "verification" phase that accepts some subset of the candidates to add to the response.
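+
+To make the two phases concrete, below is a minimal, illustrative sketch of a greedy draft-then-verify loop. It is not
+the LoRAX implementation: `draft_candidates` and `base_next_tokens` are hypothetical stand-ins for the draft method and
+the base model forward pass described in the following sections.
+
+```python
+from typing import Callable, List
+
+def speculative_decode(
+    base_next_tokens: Callable[[List[int], List[int]], List[int]],
+    draft_candidates: Callable[[List[int]], List[int]],
+    prompt: List[int],
+    max_new_tokens: int,
+) -> List[int]:
+    """Greedy draft-then-verify loop.
+
+    base_next_tokens(context, candidates) returns the base model's own next-token choice
+    at each candidate position (computed in a single forward pass), plus one extra token.
+    draft_candidates(context) returns a short list of guessed future tokens.
+    """
+    tokens = list(prompt)
+    generated = 0
+    while generated < max_new_tokens:
+        guesses = draft_candidates(tokens)         # draft phase
+        truth = base_next_tokens(tokens, guesses)  # verification phase
+
+        # Accept the longest prefix of the guesses that the base model agrees with.
+        accepted = []
+        for guess, actual in zip(guesses, truth):
+            if guess != actual:
+                break
+            accepted.append(guess)
+
+        # The base model always contributes one token of its own (the token at the
+        # first mismatch position), so the loop makes progress even with bad guesses.
+        accepted.append(truth[len(accepted)])
+
+        tokens.extend(accepted)
+        generated += len(accepted)
+    return tokens
+```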
+
+### Draft
+
+For methods other than assisted generation via a draft model, the *draft step* happens at the end of the normal next token
+selection phase after generating the logits. Given the logits for the next token and all the tokens that have been
+processed previously (input or output), a number of speculative tokens are generated and added to the batch state
+for verification in the next inference step.
+
+### Verification
+
+Once the speculative logits have been generated, a separate *verification step* is performed whereby the most likely next `S` tokens
+are passed through the model again (as part of the normal decoding process) to check for correctness. If any prefix of the `S` tokens
+is deemed *correct*, then those tokens can be appended to the response directly. The remaining incorrect speculative tokens are discarded.
+
+Note that this process adds some compute overhead to the normal decoding step. As such, it will only confer benefits when:
+
+1. The decoding step is *memory bound* (generally true for most LLMs on modern GPUs).
+2. The speculation process is able to consistently predict future tokens correctly.
+
+## Options
+
+### Medusa
+
+See the [Medusa](../models/adapters/medusa.md) guide for details on how this method works and how to use it.
+
+### Prompt Lookup Decoding
+
+[Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file) is a simple
+heuristic method that uses string matching on the input + previously generated tokens to find candidate n-grams. This
+method is particularly useful if your generation task will reuse many similar phrases from the input (e.g., in
+retrieval augmented generation where citing the input is important). If there is no need to repeat anything from the
+input, there will be no speedup and performance may decrease.
+
+#### Usage
+
+Initialize LoRAX with the `--speculative-tokens` param. This controls the length of the sequence LoRAX will attempt
+to match against in the input and suggest as the continuation of the current token:
+
+```bash
+docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD:/data \
+    ghcr.io/predibase/lorax:main \
+    --model-id mistralai/Mistral-7B-Instruct-v0.2 \
+    --speculative-tokens 3
+```
+
+Increasing this value will yield greater speedups when there are long common sequences, but will slow things down if there
+is little overlap.
+
+Note that this method is not compatible with Medusa adapters per request.
diff --git a/docs/index.md b/docs/index.md
index 305556374..ac877a4a7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -73,7 +73,7 @@ model=mistralai/Mistral-7B-Instruct-v0.1
 volume=$PWD/data
 
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/predibase/lorax:latest --model-id $model
+    ghcr.io/predibase/lorax:main --model-id $model
 ```
 
 For a full tutorial including token streaming and the Python client, see [Getting Started - Docker](./getting_started/docker.md).
diff --git a/docs/models/adapters.md b/docs/models/adapters/index.md
similarity index 77%
rename from docs/models/adapters.md
rename to docs/models/adapters/index.md
index 70008d589..e40ae005b 100644
--- a/docs/models/adapters.md
+++ b/docs/models/adapters/index.md
@@ -1,82 +1,40 @@
 # Adapters
 
-LoRAX currently supports LoRA adapters, which can be trained using frameworks like [PEFT](https://github.com/huggingface/peft) and [Ludwig](https://ludwig.ai/).
- -## Target Modules - -Any combination of linear layers can be targeted in the adapters, which corresponds to the following target modules for each base model. - -### Llama - -- `q_proj` -- `k_proj` -- `v_proj` -- `o_proj` -- `gate_proj` -- `up_proj` -- `down_proj` -- `lm_head` - -### Mistral - -- `q_proj` -- `k_proj` -- `v_proj` -- `o_proj` -- `gate_proj` -- `up_proj` -- `down_proj` -- `lm_head` - -### Mixtral - -- `q_proj` -- `k_proj` -- `v_proj` -- `o_proj` -- `lm_head` - -### Gemma - -- `q_proj` -- `k_proj` -- `v_proj` -- `o_proj` -- `gate_proj` -- `up_proj` -- `down_proj` - -### Phi - -- `q_proj` -- `k_proj` -- `v_proj` -- `dense` -- `fc1` -- `fc2` -- `lm_head` - -### Qwen - -- `c_attn` -- `c_proj` -- `w1` -- `w2` -- `lm_head` - -### GPT2 - -- `c_attn` -- `c_proj` -- `c_fc` - -### Bloom - -- `query_key_value` -- `dense` -- `dense_h_to_4h` -- `dense_4h_to_h` -- `lm_head` +Adapters are small model fragments that can be loaded on top of base models in LoRAX -- either during server initialization +or at runtime as part of the request parameters. + +## Types + +### LoRA + +[LoRA](./lora.md) is a popular parameter efficient fine-tuning method to improve response quality. + +LoRAX can load any LoRA adapter dynamically at runtime per request, and batch many different LoRAs together at once +for high throughput. + +Usage: + +```json +"parameters": { + "adapter_id": "predibase/conllpp" +} +``` + +### Medusa + +[Medusa](./medusa.md) is a speculative decoding method that speeds up next-token generation by attempting to generate +more than one token at a time. + +LoRAX can load Medusa adapters dynamically at runtime per request provided that the LoRAX server was initialized with a +default Medusa adapter. + +Usage: + +```json +"parameters": { + "adapter_id": "predibase/Mistral-7B-Instruct-v0.2-magicoder-medusa" +} +``` ## Source diff --git a/docs/models/adapters/lora.md b/docs/models/adapters/lora.md new file mode 100644 index 000000000..1b4d0a662 --- /dev/null +++ b/docs/models/adapters/lora.md @@ -0,0 +1,154 @@ +# LoRA + +[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a popular adapter method for fine-tuning response quality. + +LoRAX supports LoRA adapters trained using frameworks like [PEFT](https://github.com/huggingface/peft) and [Ludwig](https://ludwig.ai/). + +## How it works + +``` mermaid +graph BT + I{{X}} --> W; + I --> A[/LoRA A\]; + A --> B[\LoRA B/]; + W --> P((+)); + B--> P; + P --> O{{Y}} +``` + +LoRA works by targeting specific layers of the base model and inserting a new low-rank pair of weights `LoRA A` and `LoRA B` alongside each base model +param `W`. The input `X` is passed through both the original weights and the LoRA weights, and then the activations are summed together +to produce the final layer output `Y`. + +## Usage + +### Supported Target Modules + +When training a LoRA adapter, you can specify which of these layers (or "modules") you wish to target for adaptation. Typically +these are the projection layers in the attention blocks (`q` and `v`, sometimes `k` and `o` as well for LLaMA like models), but can +usually be any linear layer. + +Here is a list of supported target modules for each architecture in LoRAX. Note that in cases where your adapter contains target +modules that LoRAX does not support, LoRAX will ignore those layers and emit a warning on the backend. 
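+
+For example, a PEFT `LoraConfig` for a Llama-style base model might target only the attention projections. Any
+combination of the modules listed below is valid; the hyperparameters here are purely illustrative:
+
+```python
+from peft import LoraConfig
+
+# Adapt only the attention projection layers of a Llama/Mistral-style model.
+config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+    task_type="CAUSAL_LM",
+)
+```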
+
+#### Llama
+
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `o_proj`
+- `gate_proj`
+- `up_proj`
+- `down_proj`
+- `lm_head`
+
+#### Mistral
+
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `o_proj`
+- `gate_proj`
+- `up_proj`
+- `down_proj`
+- `lm_head`
+
+#### Mixtral
+
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `o_proj`
+- `lm_head`
+
+#### Gemma
+
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `o_proj`
+- `gate_proj`
+- `up_proj`
+- `down_proj`
+
+#### Phi-3
+
+- `qkv_proj`
+- `o_proj`
+- `gate_up_proj`
+- `down_proj`
+- `lm_head`
+
+#### Phi-2
+
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `dense`
+- `fc1`
+- `fc2`
+- `lm_head`
+
+#### Qwen2
+
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `o_proj`
+- `gate_proj`
+- `up_proj`
+- `down_proj`
+- `lm_head`
+
+#### Qwen
+
+- `c_attn`
+- `c_proj`
+- `w1`
+- `w2`
+- `lm_head`
+
+#### Command-R
+
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `o_proj`
+- `gate_proj`
+- `up_proj`
+- `down_proj`
+- `lm_head`
+
+#### DBRX
+
+- `Wqkv`
+- `out_proj`
+- `lm_head`
+
+#### GPT2
+
+- `c_attn`
+- `c_proj`
+- `c_fc`
+
+#### Bloom
+
+- `query_key_value`
+- `dense`
+- `dense_h_to_4h`
+- `dense_4h_to_h`
+- `lm_head`
+
+## How to train
+
+LoRA is a very popular fine-tuning method for LLMs, and as such there are a number of options for creating LoRA adapters
+from your data, including the following (non-exhaustive) list.
+
+### Open Source
+
+- [PEFT](https://github.com/huggingface/peft)
+- [Ludwig](https://ludwig.ai/)
+
+### Commercial
+
+- [Predibase](https://predibase.com/)
diff --git a/docs/models/adapters/medusa.md b/docs/models/adapters/medusa.md
new file mode 100644
index 000000000..380005024
--- /dev/null
+++ b/docs/models/adapters/medusa.md
@@ -0,0 +1,230 @@
+# Medusa
+
+[Medusa](https://arxiv.org/abs/2401.10774) is a [speculative decoding](../../guides/speculative_decoding.md) method
+that trains new projection layers (similar to LoRA layers) for the purpose of predicting future tokens and speeding up
+the text generation process.
+
+## How it works
+
+``` mermaid
+graph BT
+    X{{H}} --> S((Stack));
+    X --> M1[Medusa 1];
+    X --> M2[Medusa 2];
+    X --> M3[Medusa 3];
+    M1 --> S;
+    M2 --> S;
+    M3 --> S;
+    S --> LM[LM Head]
+    LM --> L{{Logits}}
+```
+
+The goal of Medusa is to speed up text generation. Unlike LoRA, Medusa does not aim to improve response quality, and in
+fact enabling Medusa will have no effect at all on the model output itself. Instead, Medusa works by adding additional
+projections (called "medusa heads") that the last hidden state `H` of the LLM is passed through in order to predict
+the next N tokens (rather than just the next token).
+
+The result is that the output logit shape of the model at each decoding step is no longer `[B, 1, V]` for batch size `B` and vocabulary
+size `V`, but instead `[B, S, V]`, where `S` is the number of Medusa speculative heads `N` plus `1` for the original model
+head.
+
+See the [Speculative Decoding](../../guides/speculative_decoding.md#verification) guide for more information on the verification
+step that follows.
+
+### Change in v2
+
+The original implementation of Medusa trained separate LM head layers for each Medusa head. This introduced significant
+memory overhead that made dynamic loading of these adapters prohibitive. In v2, Medusa heads now reuse the base model
+LM head, reducing memory overhead by an order of magnitude.
+
+LoRAX supports both v1 and v2 Medusa adapters, but only allows dynamic loading for v2 adapters. To see which version
+your Medusa adapter is, check the `config.json` file for the `version` property. If not specified, the adapter is
+assumed to be v1.
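+
+For example, one quick way to check the version of a Medusa adapter hosted on the Hugging Face Hub (a sketch using
+`huggingface_hub`; substitute your own adapter's repo ID):
+
+```python
+import json
+
+from huggingface_hub import hf_hub_download
+
+config_path = hf_hub_download("predibase/Mistral-7B-Instruct-v0.2-medusa", "config.json")
+with open(config_path) as f:
+    config = json.load(f)
+
+# Adapters that omit the field are treated as v1.
+print(config.get("version", 1))
+```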
+
+## Usage
+
+### Initializing LoRAX with Medusa
+
+In order to use Medusa speculative decoding in LoRAX, you must initialize the LoRAX server with a valid Medusa adapter
+as the "default" adapter. This means that by default every request will use the default Medusa adapter unless overridden
+by the request parameters.
+
+Example:
+
+```bash
+docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD:/data \
+    ghcr.io/predibase/lorax:main \
+    --model-id mistralai/Mistral-7B-Instruct-v0.2 \
+    --adapter-id predibase/Mistral-7B-Instruct-v0.2-medusa
+```
+
+### Dynamic Medusa per Request
+
+When using a v2 Medusa adapter as default, you can also apply per-request Medusa adapters (that must also be v2) to
+specialize the speculative decoding to the particular task.
+
+For example, you might have a general-purpose Medusa adapter as the default that improves throughput for most prompts
+by ~50%. But if you know your incoming request is for code generation, you might want to apply a task-specific Medusa
+adapter trained on only code generation examples for a ~100% speedup:
+
+=== "Python"
+
+    ```python
+    from lorax import Client
+
+    client = Client("http://127.0.0.1:8080")
+    prompt = "[INST] Write a Python function that takes a list of strings as input and returns a new list containing only the strings that are palindromes. [/INST]"
+
+    resp = client.generate(
+        prompt,
+        adapter_id="predibase/Mistral-7B-Instruct-v0.2-magicoder-medusa",
+    )
+    print(resp.generated_text)
+    ```
+
+=== "REST"
+
+    ```bash
+    curl 127.0.0.1:8080/generate \
+        -X POST \
+        -d '{
+            "inputs": "[INST] Write a Python function that takes a list of strings as input and returns a new list containing only the strings that are palindromes. [/INST]",
+            "parameters": {
+                "adapter_id": "predibase/Mistral-7B-Instruct-v0.2-magicoder-medusa"
+            }
+        }' \
+        -H 'Content-Type: application/json'
+    ```
+
+The one caveat to using per-request Medusa adapters is that **adapters loaded per request must have the same number of
+Medusa heads as the default Medusa adapter**. This is because, for now, the number of speculative tokens generated per
+step is a constant defined during initialization.
+
+### Combining with LoRA
+
+When LoRAX has been initialized with a default Medusa, you may continue to use it with dynamic LoRA loading as usual:
+
+
+=== "Python"
+
+    ```python
+    from lorax import Client
+
+    client = Client("http://127.0.0.1:8080")
+    prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
+
+    resp = client.generate(
+        prompt,
+        adapter_id="vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
+    )
+    print(resp.generated_text)
+    ```
+
+=== "REST"
+
+    ```bash
+    curl 127.0.0.1:8080/generate \
+        -X POST \
+        -d '{
+            "inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]",
+            "parameters": {
+                "adapter_id": "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k"
+            }
+        }' \
+        -H 'Content-Type: application/json'
+    ```
+
+The default Medusa adapter will be applied to every LoRA in the batch. In the future, we also plan to support LoRAs that
+come with their own Medusa heads (Medusa 2).
+ +## How to train + +The official [Medusa GitHub repo](https://github.com/FasterDecoding/Medusa) contains recipes for training a Medusa v2 +adapter, including the self-distillation process. Broadly, the steps needed to create a Medusa adapter are: + +1. Prepare a dataset of example prompts in the ShareGPT or OpenAI conversation JSON format. +2. Generate responses from the base model you wish to adapt (Medusa 1). +3. Fine-tune Medusa heads using the prompt + response dataset. + +### Example + +Clone the repo (note: using a fork here that includes some fixes for more recent versions of `transformers`): + +```bash +git clone https://github.com/tgaddair/Medusa.git +cd Medusa +``` + +Install dependencies: + +```bash +pip install -e ".[train]" +pip install -U accelerate +``` + +Download the dataset: + +```bash +git lfs install +git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered +``` + +Launch a LoRAX server: + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD:/data \ + ghcr.io/predibase/lorax:main \ + --model-id mistralai/Mistral-7B-Instruct-v0.2 \ + --adapter-id predibase/Mistral-7B-Instruct-v0.2-medusa +``` + +Create the self-distillation dataset: + +```bash +python create_data.py \ + ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \ + /data/sharegpt-mistral-7b-instruct-02.json +``` + +Train: + +```bash +python medusa/train/train_legacy.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \ + --data_path /data/sharegpt-mistral-7b-instruct-02.json \ + --bf16 True \ + --output_dir sharegpt_mistral_7b_it_v02 \ + --num_train_epochs 3 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 32 \ + --evaluation_strategy "no" \ + --save_strategy "no" \ + --learning_rate 1e-3 \ + --weight_decay 0.0 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --lazy_preprocess True \ + --medusa_num_heads 3 \ + --medusa_num_layers 1 +``` + +Prompt with LoRAX: + +```bash +curl 127.0.0.1:8080/generate \ + -X POST \ + -d '{ + "inputs": "[INST] What is the photograph filter called where the only part of the image is greyscale? [/INST]", + "parameters": { + "max_new_tokens": 64, + "adapter_id": "/data/sharegpt_mistral_7b_it_v02", + "adapter_source": "local" + } + }' \ + -H 'Content-Type: application/json' +``` + +Next you can upload to HF and use as a base Medusa adapter or runtime Medusa adapter. 
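+
+For example, one way to push the trained adapter directory to the Hub with `huggingface_hub` (a sketch; the repo ID is
+a placeholder, and the folder is the `--output_dir` from the training run above):
+
+```python
+from huggingface_hub import HfApi
+
+api = HfApi()
+api.create_repo("your-username/mistral-7b-instruct-v02-medusa", exist_ok=True)
+api.upload_folder(
+    folder_path="sharegpt_mistral_7b_it_v02",
+    repo_id="your-username/mistral-7b-instruct-v02-medusa",
+)
+```
+
+Once uploaded, the repo ID can be passed as `--adapter-id` when launching the server or as `adapter_id` in the request
+parameters, as shown earlier in this guide.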
diff --git a/docs/models/base_models.md b/docs/models/base_models.md index eed8be1bf..c37645980 100644 --- a/docs/models/base_models.md +++ b/docs/models/base_models.md @@ -8,8 +8,10 @@ - [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) - 🔄 [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) - 💎 [Gemma](https://blog.google/technology/developers/gemma-open-models/) -- 🏛️ [Phi](https://huggingface.co/microsoft/phi-2) -- 🔮 [Qwen](https://huggingface.co/Qwen) +- 🏛️ [Phi-3](https://azure.microsoft.com/en-us/blog/introducing-phi-3-redefining-whats-possible-with-slms/) / [Phi-2](https://huggingface.co/microsoft/phi-2) +- 🔮 [Qwen2 / Qwen](https://huggingface.co/Qwen) +- 🗣️ [Command-R](https://docs.cohere.com/docs/command-r) +- 🧱 [DBRX](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) - 🤖 [GPT2](https://huggingface.co/gpt2) - 🌸 [Bloom](https://huggingface.co/bigscience/bloom) diff --git a/mkdocs.yml b/mkdocs.yml index 65821cf25..7f3b55c84 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -16,7 +16,7 @@ theme: # Palette toggle for light mode - scheme: default toggle: - icon: material/brightness-7 + icon: material/brightness-7 name: Switch to dark mode primary: deep orange accent: deep orange @@ -38,23 +38,27 @@ nav: - Local: getting_started/local.md - 🧮 Models: - Base Models: models/base_models.md - - Adapters: models/adapters.md + - Adapters: + - Adapters: models/adapters/index.md + - LoRA: models/adapters/lora.md + - Medusa: models/adapters/medusa.md - 📚 Reference: - Launcher: reference/launcher.md - REST API: reference/rest_api.md - Python Client: - - Python Client: reference/python_client/index.md - - lorax.client: reference/python_client/client.md + - Python Client: reference/python_client/index.md + - lorax.client: reference/python_client/client.md # - lorax.types: reference/python_client/types.md - OpenAI Compatible API: reference/openai_api.md - 🔬 Guides: - Quantization: guides/quantization.md - Structured Output (JSON): guides/structured_output.md + - Speculative Decoding: guides/speculative_decoding.md - CUDA Graph Compilation: guides/cuda_graphs.md - Merging Adapters: guides/merging_adapters.md - Contributing to LoRAX: - - Contributing to LoRAX: guides/contributing/index.md - - Development Environment: guides/contributing/development_env.md + - Contributing to LoRAX: guides/contributing/index.md + - Development Environment: guides/contributing/development_env.md # - GPUs: guides/gpus.md # - Fine-Tuning: guides/fine_tuning.md # - Memory Management: guides/memory_management.md @@ -67,7 +71,7 @@ copyright: Copyright © 2023 Predibase, Inc. extra: generator: false social: - - icon: fontawesome/brands/github + - icon: fontawesome/brands/github link: https://github.com/predibase/lorax - icon: fontawesome/brands/discord link: https://discord.gg/CBgdrGnZjy @@ -78,7 +82,12 @@ markdown_extensions: - pymdownx.superfences - pymdownx.tabbed: alternate_style: true + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format plugins: - render_swagger - - search \ No newline at end of file + - search