Fix for the LM_HEAD issue (#475)
ajtejankar committed May 23, 2024
1 parent 2481e70 commit da90421
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 13 deletions.
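In brief: during prefill, the flash models compute logits only at `prefill_head_indices` (a pruned subset of token positions), but LM_HEAD LoRA adapters were still batched with segment boundaries computed over the full token range. This change threads `prefill_head_indices` from the models through `AdapterBatchData.from_meta` and `LayerAdapterWeights.get_data` into `BatchLoraWeights.load`, which remaps each adapter's segment boundaries into the pruned index space, and applies the remapping only to the LM_HEAD layer.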
21 changes: 20 additions & 1 deletion server/lorax_server/adapters/lora.py
@@ -237,7 +237,11 @@ def key(cls) -> str:

     @classmethod
     def load(
-        self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata, prefill: bool
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
     ) -> Optional["BatchLoraWeights"]:
         adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
         adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}
@@ -302,6 +306,17 @@ def load(
                 continue
             rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
 
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
         rank_data = {}
         for rank, indices in rank_indices.items():
             tmp_shrink = None
@@ -315,6 +330,10 @@ def load(
                 tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device)
                 segment_starts = meta.adapter_segments[indices]
                 segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
+                if prefill_head_indices is not None:
+                    for i, segment_index in enumerate(indices):
+                        segment_starts[i] = prefill_head_segment_starts[segment_index]
+                        segment_ends[i] = prefill_head_segment_ends[segment_index]
             else:
                 rank_indices = set(indices)
                 batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
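The heart of the fix is the boundary recomputation above: `meta.adapter_segments` marks per-adapter segment boundaries over all prefill tokens, while the kept head tokens live in a smaller index space. Below is a standalone sketch of the same loop with made-up inputs (two sequences of 4 and 3 tokens, one adapter segment each, keeping only each sequence's last token); the values are hypothetical, not taken from the repo:

```python
# Hypothetical inputs: adapter_segments gives cumulative token boundaries per
# segment ([0, 4, 7] = segment 0 covers tokens 0-3, segment 1 covers 4-6);
# prefill_head_indices lists the token positions kept for the LM head.
adapter_segments = [0, 4, 7]
prefill_head_indices = [3, 6]

j, starts, ends = 1, [0], [0]
for head_index in prefill_head_indices:
    if head_index < adapter_segments[j]:
        # Head token falls inside the current segment: extend it by one.
        ends[-1] += 1
    else:
        # Head token belongs to the next segment: open a new boundary pair.
        starts.append(ends[-1])
        ends.append(ends[-1] + 1)
        j += 1

print(starts, ends)  # [0, 1] [1, 2]: each segment keeps exactly one head token
```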
20 changes: 15 additions & 5 deletions server/lorax_server/adapters/weights.py
@@ -6,6 +6,7 @@
 import torch
 
 from lorax_server.adapters.types import LORA
+from lorax_server.utils.lora import LM_HEAD
 
 
 @dataclass
@@ -46,7 +47,11 @@ def key(cls) -> str:

     @abstractclassmethod
     def load(
-        cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata", prefill: bool
+        cls,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: "AdapterBatchMetadata",
+        prefill: bool,
+        prefill_head_indices: torch.Tensor,
     ) -> Optional["BatchAdapterWeights"]:
         pass

@@ -72,7 +77,9 @@ def max_speculative_tokens(self) -> int:
     def is_empty(self) -> bool:
         return len(self.adapter_weights) == 0
 
-    def get_data(self, meta: AdapterBatchMetadata, prefill: bool) -> Dict[str, BatchAdapterWeights]:
+    def get_data(
+        self, meta: AdapterBatchMetadata, prefill: bool, prefill_head_indices: Optional[torch.Tensor],
+    ) -> Dict[str, BatchAdapterWeights]:
         # bucket adapters by batch class
         adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict)
         for adapter_index, adapter_weights in self.adapter_weights.items():
@@ -81,7 +88,7 @@ def get_data(self, meta: AdapterBatchMetadata, prefill: bool) -> Dict[str, Batch

         batch_data = {}
         for batch_type, adapter_weights in adapter_batch_types.items():
-            batched_weights = batch_type.load(adapter_weights, meta, prefill)
+            batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices)
             if batched_weights is not None:
                 batch_data[batch_type.key()] = batched_weights
         return batch_data
@@ -98,13 +105,16 @@ class AdapterBatchData:

     @staticmethod
     def from_meta(
-        meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights], prefill: bool
+        meta: AdapterBatchMetadata,
+        weights: Dict[str, LayerAdapterWeights],
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
     ) -> "AdapterBatchData":
         data = {}
         for k, v in weights.items():
             if v.is_empty():
                 continue
-            data[k] = v.get_data(meta, prefill)
+            data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None)
         return AdapterBatchData(meta=meta, data=data, prefill=prefill)
 
     def ranks(self) -> Set[int]:
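Note the `k == LM_HEAD` guard in `from_meta` above: only the LM-head layer receives the pruned indices, since attention and MLP layers still run over every prefill token. A toy sketch of that dispatch (the helper name is invented for illustration):

```python
from typing import List, Optional

LM_HEAD = "lm_head"  # mirrors lorax_server.utils.lora.LM_HEAD

def route_head_indices(layer_name: str, head_indices: Optional[List[int]]) -> Optional[List[int]]:
    # Only the LM head sees pruned hidden states during prefill, so only it
    # needs segment boundaries expressed in the pruned index space.
    return head_indices if layer_name == LM_HEAD else None

assert route_head_indices("lm_head", [3, 6]) == [3, 6]
assert route_head_indices("q_proj", [3, 6]) is None
```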
7 changes: 6 additions & 1 deletion server/lorax_server/models/causal_lm.py
@@ -589,7 +589,12 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option
         # Assign pointers to LoRA weights
         # TODO(travis): don't update this if indices haven't changed
         # Use prefill=True in all cases to force use of SGMV, as the batch is heterogenous
-        adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.layer_to_adapter_weights, prefill=True)
+        adapter_data = AdapterBatchData.from_meta(
+            batch.adapter_meta,
+            self.layer_to_adapter_weights,
+            prefill=True,
+            prefill_head_indices=None,
+        )
 
         logits, past = self.forward(
             batch.input_ids,
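Passing `prefill_head_indices=None` here is deliberate: the non-flash `CausalLM` path does not prune hidden states before the LM head, so there is no smaller index space to remap segments into.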
7 changes: 6 additions & 1 deletion server/lorax_server/models/flash_causal_lm.py
@@ -952,7 +952,12 @@ def generate_token(

         # Assign pointers to adapter weights
         # TODO(travis): don't update this if indices haven't changed
-        adapter_data = AdapterBatchData.from_meta(adapter_meta, self.layer_to_adapter_weights, prefill)
+        adapter_data = AdapterBatchData.from_meta(
+            adapter_meta,
+            self.layer_to_adapter_weights,
+            prefill,
+            batch.prefill_head_indices
+        )
 
         out, speculative_logits = self._try_generate_token(batch, adapter_data)
 
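In the flash path, by contrast, `batch.prefill_head_indices` is populated during prefill (typically the last token position of each sequence), so LM_HEAD adapter segments get remapped to match the pruned hidden states.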
7 changes: 4 additions & 3 deletions server/lorax_server/utils/sources/source.py
@@ -2,9 +2,10 @@
 import os
 from abc import abstractmethod
 from pathlib import Path
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
-from lorax_server.adapters.config import AdapterConfig
+if TYPE_CHECKING:
+    from lorax_server.adapters.config import AdapterConfig
 
 
 def try_to_load_from_cache(repo_cache: Path, revision: Optional[str], filename: str) -> Optional[Path]:
@@ -128,7 +129,7 @@ def get_weight_bytes(self) -> int:

         return total_size
 
-    def load_config(self) -> AdapterConfig:
+    def load_config(self) -> "AdapterConfig":
         from lorax_server.adapters import load_adapter_config
 
         config_path = self.download_file("config.json", ignore_errors=True)
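The `TYPE_CHECKING` change above is the standard idiom for breaking a runtime import cycle while keeping the annotation: the guarded import only runs under static analysis, so the annotation becomes a string and the real import is deferred to call time. A minimal runnable sketch of the pattern, using `argparse.Namespace` as a stand-in for `AdapterConfig`:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers; no runtime import (and no import cycle).
    from argparse import Namespace

def load_config() -> "Namespace":
    # Deferred runtime import at call time, mirroring load_config above.
    from argparse import Namespace
    return Namespace(base_model="placeholder")

print(load_config())
```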
2 changes: 1 addition & 1 deletion server/tests/utils/test_lora.py
@@ -48,7 +48,7 @@ def test_batched_lora_weights(lora_ranks: List[int]):
     )
 
     with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
-        data = batched_weights.get_data(meta, prefill=True).get(LORA)
+        data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)
 
     assert len(data.lora_a) == 2
     assert data.lora_a.keys() == meta.adapter_set
2 changes: 1 addition & 1 deletion server/tests/utils/test_tokens.py
@@ -80,7 +80,7 @@ def test_deterministic_tokens_temperature_zero(default_causal_lm, default_causal
     attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
     adapter_data = AdapterBatchData.from_meta(
-        batch.adapter_meta, default_causal_lm.layer_to_adapter_weights, prefill=True
+        batch.adapter_meta, default_causal_lm.layer_to_adapter_weights, prefill=True, prefill_head_indices=None
     )
 
     logits, _ = default_causal_lm.forward(
