Medusa docs (#459)
tgaddair committed May 6, 2024
1 parent 35791ba commit c917ccd
Showing 11 changed files with 518 additions and 93 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -79,7 +79,7 @@ model=mistralai/Mistral-7B-Instruct-v0.1
volume=$PWD/data

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/predibase/lorax:latest --model-id $model
ghcr.io/predibase/lorax:main --model-id $model
```

For a full tutorial including token streaming and the Python client, see [Getting Started - Docker](https://predibase.github.io/lorax/getting_started/docker).
8 changes: 7 additions & 1 deletion docs/getting_started/docker.md
@@ -11,8 +11,14 @@ model=mistralai/Mistral-7B-Instruct-v0.1
volume=$PWD/data # share a volume with the container as a weight cache

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/predibase/lorax:latest --model-id $model
ghcr.io/predibase/lorax:main --model-id $model
```

!!! note

The `main` tag will use the image built from the HEAD of the main branch of the repo. For the latest stable image (built from a
tagged version) use the `latest` tag.

!!! note

The LoRAX server in the pre-built Docker image is configured to listen on port 80 (instead of on the default port number, which is 3000).
2 changes: 1 addition & 1 deletion docs/getting_started/skypilot.md
@@ -28,7 +28,7 @@ envs:

run: |
docker run --gpus all --shm-size 1g -p 8080:80 -v ~/data:/data \
ghcr.io/predibase/lorax:latest \
ghcr.io/predibase/lorax:main \
--model-id $MODEL_ID
```

4 changes: 2 additions & 2 deletions docs/guides/contributing/development_env.md
@@ -16,12 +16,12 @@ Pull and run the latest LoRAX docker image, mounting the directory containing yo
# we will assume the lorax repo is found at ~/data/lorax
volume=~/data

docker pull ghcr.io/predibase/lorax:latest
docker pull ghcr.io/predibase/lorax:main
docker run \
--cap-add=SYS_PTRACE \
--gpus all --shm-size 1g \
-v $volume:/data \
-itd --entrypoint /bin/bash ghcr.io/predibase/lorax:latest
-itd --entrypoint /bin/bash ghcr.io/predibase/lorax:main
```

!!! note
66 changes: 66 additions & 0 deletions docs/guides/speculative_decoding.md
@@ -0,0 +1,66 @@
# Speculative Decoding

Speculative decoding describes a set of methods for speeding up next-token generation in autoregressive language models
by attempting to "guess" the next N tokens of the base model. These guesses can be generated in a number of different
ways, including:

- An additional smaller "draft" model (e.g., Llama-70b and Llama-7b)
- An adapter that extends the sequence dimension of the logits (e.g., Medusa)
- A heuristic (e.g., looking for recurring sequences in the prompt)

LoRAX implements some of these approaches, with a particular emphasis on supporting adapter-based methods like Medusa
that can be applied per request for task-level speedups.

## Process

Almost all of the above speculative decoding methods consist of the same two phases: a "draft" phase that generates
candidate tokens and a "verification" phase that accepts some subset of the candidates to add to the response.

### Draft

For methods other than assisted generation via a draft model, the *draft step* happens at the end of the normal next token
selection phase, after generating the logits. Given the logits for the next token and all the tokens that have been
processed previously (input or output), a number of speculative tokens are generated and added to the batch state
for verification in the next inference step.

### Verification

Once the speculative logits have been generated, a separate *verification step* is performed whereby the most likely next `S` tokens
are passed through the model again (as part of the normal decoding process) to check for correctness. If any prefix of the `S` tokens
is deemed *correct*, it can be appended to the response directly. The remaining incorrect speculative tokens are discarded.
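
At a high level, one decoding step with greedy acceptance looks roughly like the sketch below (illustrative pseudocode only, not the LoRAX internals; `draft_tokens` and `model_forward` are hypothetical stand-ins for the draft source and the base model):

```python
def speculative_step(model_forward, draft_tokens, context, num_speculative):
    """One decoding step: draft S candidate tokens, then verify them with a
    single forward pass through the base model."""
    # Draft phase: propose S candidate continuations of the current context.
    candidates = draft_tokens(context, num_speculative)

    # Verification phase: run the base model over context + candidates once.
    # logits[i] holds the model's prediction for the token at position i + 1.
    logits = model_forward(context + candidates)

    accepted = []
    for i, candidate in enumerate(candidates):
        # A candidate is correct if the base model would have produced the
        # same token at that position.
        predicted = logits[len(context) + i - 1].argmax()
        if predicted != candidate:
            break
        accepted.append(candidate)

    # The model's own prediction after the last accepted token is always
    # valid, so every step emits at least one new token.
    bonus = logits[len(context) + len(accepted) - 1].argmax()
    return accepted + [bonus]
```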

Note that this process adds some compute overhead to the normal decoding step. As such, it will only confer benefits when:

1. The decoding step is *memory bound* (generally true for most LLMs on modern GPUs).
2. The speculation process is able to consistently predict future tokens correctly.

## Options

### Medusa

See the [Medusa](../models/adapters/medusa.md) guide for details on how this method works and how to use it.

### Prompt Lookup Decoding

[Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file) is a simple
heuristic method that uses string matching on the input + previously generated tokens to find candidate n-grams. This
method is particularly useful if your generation task will reuse many similar phrases from the input (e.g., in
retrieval augmented generation where citing the input is important). If there is no need to repeat anything from the
input, there will be no speedup and performance may decrease.
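
As a rough illustration of the heuristic (a simplified sketch, not the actual LoRAX or prompt-lookup-decoding implementation):

```python
def prompt_lookup_candidates(tokens, ngram_size=3, num_speculative=3):
    """Find the most recent earlier occurrence of the trailing n-gram and
    propose the tokens that followed it as speculative candidates."""
    if len(tokens) <= ngram_size:
        return []
    tail = tokens[-ngram_size:]
    # Scan backwards so the most recent earlier match wins.
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start:start + ngram_size] == tail:
            follow = tokens[start + ngram_size:start + ngram_size + num_speculative]
            if follow:
                return follow
    return []

# The phrase "the quick brown" recurs, so the tokens that followed its earlier
# occurrence are proposed as candidates.
seq = "the quick brown fox jumps over the quick brown".split()
print(prompt_lookup_candidates(seq))  # ['fox', 'jumps', 'over']
```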

#### Usage

Initialize LoRAX with the `--speculative-tokens` param. This controls the length of the sequence LoRAX will attempt
to match in the input and suggest as the continuation of the current token:

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD:/data \
ghcr.io/predibase/lorax:main \
--model-id mistralai/Mistral-7B-Instruct-v0.2 \
--speculative-tokens 3
```

Increasing this value will yield greater speedups when there are long common sequences, but slow things down if there
is little overlap.

Note that this method is not compatible with per-request Medusa adapters.
2 changes: 1 addition & 1 deletion docs/index.md
@@ -73,7 +73,7 @@ model=mistralai/Mistral-7B-Instruct-v0.1
volume=$PWD/data

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/predibase/lorax:latest --model-id $model
ghcr.io/predibase/lorax:main --model-id $model
```

For a full tutorial including token streaming and the Python client, see [Getting Started - Docker](./getting_started/docker.md).
112 changes: 35 additions & 77 deletions docs/models/adapters.md → docs/models/adapters/index.md
@@ -1,82 +1,40 @@
# Adapters

LoRAX currently supports LoRA adapters, which can be trained using frameworks like [PEFT](https://github.com/huggingface/peft) and [Ludwig](https://ludwig.ai/).

## Target Modules

Any combination of linear layers can be targeted in the adapters, which corresponds to the following target modules for each base model.

### Llama

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

### Mistral

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

### Mixtral

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `lm_head`

### Gemma

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`

### Phi

- `q_proj`
- `k_proj`
- `v_proj`
- `dense`
- `fc1`
- `fc2`
- `lm_head`

### Qwen

- `c_attn`
- `c_proj`
- `w1`
- `w2`
- `lm_head`

### GPT2

- `c_attn`
- `c_proj`
- `c_fc`

### Bloom

- `query_key_value`
- `dense`
- `dense_h_to_4h`
- `dense_4h_to_h`
- `lm_head`
Adapters are small model fragments that can be loaded on top of base models in LoRAX -- either during server initialization
or at runtime as part of the request parameters.

## Types

### LoRA

[LoRA](./lora.md) is a popular parameter-efficient fine-tuning method for improving response quality.

LoRAX can load any LoRA adapter dynamically at runtime per request, and batch many different LoRAs together at once
for high throughput.

Usage:

```json
"parameters": {
"adapter_id": "predibase/conllpp"
}
```
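
For example, assuming a LoRAX server running locally on port 8080, a complete request against the REST `/generate` endpoint might look like this (the prompt and host are illustrative):

```python
import requests

# The adapter named in `adapter_id` is downloaded and loaded on demand, and
# requests for different adapters are batched together by the server.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "[INST] Tag the named entities in: George Washington lived in Virginia. [/INST]",
        "parameters": {
            "max_new_tokens": 64,
            "adapter_id": "predibase/conllpp",
        },
    },
)
print(resp.json()["generated_text"])
```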

### Medusa

[Medusa](./medusa.md) is a speculative decoding method that speeds up next-token generation by attempting to generate
more than one token at a time.

LoRAX can load Medusa adapters dynamically at runtime per request, provided that the LoRAX server was initialized with a
default Medusa adapter.

Usage:

```json
"parameters": {
"adapter_id": "predibase/Mistral-7B-Instruct-v0.2-magicoder-medusa"
}
```

## Source

154 changes: 154 additions & 0 deletions docs/models/adapters/lora.md
@@ -0,0 +1,154 @@
# LoRA

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a popular adapter-based method for fine-tuning models to improve response quality.

LoRAX supports LoRA adapters trained using frameworks like [PEFT](https://github.com/huggingface/peft) and [Ludwig](https://ludwig.ai/).

## How it works

``` mermaid
graph BT
I{{X}} --> W;
I --> A[/LoRA A\];
A --> B[\LoRA B/];
W --> P((+));
B--> P;
P --> O{{Y}}
```

LoRA works by targeting specific layers of the base model and inserting a new low-rank pair of weights `LoRA A` and `LoRA B` alongside each base model
param `W`. The input `X` is passed through both the original weights and the LoRA weights, and then the activations are summed together
to produce the final layer output `Y`.
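
In code, the adapted layer computes something like the following (a minimal NumPy sketch of the idea rather than LoRAX's actual implementation; `r` is the LoRA rank and `alpha` the scaling factor):

```python
import numpy as np

d_in, d_out, r, alpha = 4096, 4096, 8, 16

W = np.random.randn(d_out, d_in) * 0.02  # frozen base model weight
A = np.random.randn(r, d_in) * 0.02      # LoRA A: projects down to rank r
B = np.zeros((d_out, r))                 # LoRA B: projects back up (zero-initialized)

def lora_linear(x):
    # Pass the input through both paths and sum the activations:
    # Y = W X + (alpha / r) * B A X
    return W @ x + (alpha / r) * (B @ (A @ x))

y = lora_linear(np.random.randn(d_in))
```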

## Usage

### Supported Target Modules

When training a LoRA adapter, you can specify which layers (or "modules") of the base model you wish to target for adaptation. Typically
these are the projection layers in the attention blocks (`q` and `v`, sometimes `k` and `o` as well for LLaMA-like models), but they can
usually be any linear layer.

Below is a list of the supported target modules for each architecture in LoRAX. If your adapter contains target
modules that LoRAX does not support, LoRAX will ignore those layers and emit a warning on the backend.

#### Llama

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

#### Mistral

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

#### Mixtral

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `lm_head`

#### Gemma

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`

#### Phi-3

- `qkv_proj`
- `o_proj`
- `gate_up_proj`
- `down_proj`
- `lm_head`

#### Phi-2

- `q_proj`
- `k_proj`
- `v_proj`
- `dense`
- `fc1`
- `fc2`
- `lm_head`

#### Qwen2

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

#### Qwen

- `c_attn`
- `c_proj`
- `w1`
- `w2`
- `lm_head`

#### Command-R

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

#### DBRX

- `Wqkv`
- `out_proj`
- `lm_head`

#### GPT2

- `c_attn`
- `c_proj`
- `c_fc`

#### Bloom

- `query_key_value`
- `dense`
- `dense_h_to_4h`
- `dense_4h_to_h`
- `lm_head`

## How to train

LoRA is a very popular fine-tuning method for LLMs, and as such there are many ways to create LoRA adapters
from your data, including the following (non-exhaustive) options. A brief example using PEFT is shown below.

### Open Source

- [PEFT](https://github.com/huggingface/peft)
- [Ludwig](https://ludwig.ai/)
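
For example, with PEFT a LoRA adapter targeting the attention projections can be configured roughly as follows (a brief sketch; the base model, rank, and target modules are illustrative):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

config = LoraConfig(
    r=8,                                  # LoRA rank
    lora_alpha=16,                        # scaling factor
    target_modules=["q_proj", "v_proj"],  # any supported modules from the lists above
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(base, config)
model.print_trainable_parameters()
# ...train as usual, then save the adapter so LoRAX can load it:
# model.save_pretrained("my-lora-adapter")
```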

### Commercial

- [Predibase](https://predibase.com/)