Support jointly trained Medusa + LoRA adapters (#482)
tgaddair committed May 22, 2024
1 parent 97ede52 commit a1ff52d
Showing 5 changed files with 145 additions and 19 deletions.
server/lorax_server/adapters/__init__.py (23 changes: 17 additions & 6 deletions)
@@ -1,25 +1,36 @@
 import json
 from pathlib import Path
-from typing import Optional
+from typing import Dict, Optional
 
 from lorax_server.adapters.config import AdapterConfig
 from lorax_server.adapters.lora import LoraConfig
 from lorax_server.adapters.medusa import MedusaConfig
+from lorax_server.adapters.medusa_lora import MedusaLoraConfig
 from lorax_server.adapters.weights import AdapterBatchData, AdapterBatchMetadata
 
 
+def load_medusa_config(config_path: Optional[Path]) -> Optional[Dict]:
+    if config_path is not None and config_path.exists():
+        config = json.load(config_path.open())
+        if "medusa_num_heads" in config:
+            return config
+    return None
+
+
 def load_adapter_config(
     config_path: Optional[Path],
     adapter_config_path: Optional[Path],
     api_token: str,
 ) -> AdapterConfig:
+    medusa_config = load_medusa_config(config_path)
     if adapter_config_path is not None and adapter_config_path.exists():
-        return LoraConfig.load(str(adapter_config_path.parent), api_token)
+        if medusa_config is not None:
+            return MedusaLoraConfig.load(str(adapter_config_path.parent), medusa_config, api_token)
+        else:
+            return LoraConfig.load(str(adapter_config_path.parent), api_token)
 
-    if config_path is not None and config_path.exists():
-        config = json.load(config_path.open())
-        if "medusa_num_heads" in config:
-            return MedusaConfig.load(config)
+    if medusa_config is not None:
+        return MedusaConfig.load(medusa_config)
 
     raise ValueError(f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}")
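
For reference, a minimal sketch of how the reworked dispatch can be exercised. The directory layout and file names below are illustrative assumptions, not paths taken from the repository or its tests:

from pathlib import Path

from lorax_server.adapters import load_adapter_config

# Hypothetical layout for a jointly trained Medusa + LoRA adapter: the directory
# ships both a PEFT adapter_config.json and a config.json containing "medusa_num_heads".
adapter_dir = Path("/data/adapters/my-joint-adapter")  # assumed path
config_path = adapter_dir / "config.json"
adapter_config_path = adapter_dir / "adapter_config.json"

adapter_config = load_adapter_config(config_path, adapter_config_path, api_token="")
# Dispatch implied by the diff above:
#   adapter_config.json present and config.json has "medusa_num_heads" -> MedusaLoraConfig
#   only adapter_config.json present                                   -> LoraConfig
#   only config.json with "medusa_num_heads" present                   -> MedusaConfig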

server/lorax_server/adapters/lora.py (17 changes: 13 additions & 4 deletions)
@@ -1,6 +1,6 @@
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
 from peft import LoraConfig as _LoraConfig
@@ -138,8 +138,8 @@ def _transpose_weights(self):
         self._is_transposed = not self._is_transposed
 
     @classmethod
-    def get_batch_type(cls) -> BatchAdapterWeights:
-        return BatchLoraWeights
+    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
+        return [BatchLoraWeights]
 
     @classmethod
     def load(
@@ -238,8 +238,11 @@ def key(cls) -> str:
     @classmethod
     def load(
         self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata, prefill: bool
-    ) -> "BatchLoraWeights":
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
         adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}
+        if not adapter_weights:
+            return None
 
         first_weights = list(adapter_weights.values())[0]
         device = first_weights.weights_a.device
@@ -347,3 +350,9 @@ def get_scaling_factor(
     if uses_rslora:
         return lora_alpha / (r**0.5)
     return lora_alpha / r
+
+
+def _convert_lora(v: AdapterWeights) -> AdapterWeights:
+    if hasattr(v, "lora_weights"):
+        return v.lora_weights
+    return v
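
The unwrap-then-filter step added to BatchLoraWeights.load is easier to see in isolation. Below is a self-contained toy version; the Fake* classes are invented stand-ins (real LoraWeights need tensors and a module map), not types from the server:

class FakeLoraWeights:
    """Stand-in for LoraWeights."""


class FakeJointWeights:
    """Stand-in for a joint adapter that exposes nested .lora_weights."""

    def __init__(self):
        self.lora_weights = FakeLoraWeights()


def unwrap(v):
    # Same shape as _convert_lora above: unwrap nested LoRA weights, else pass through.
    return v.lora_weights if hasattr(v, "lora_weights") else v


adapters = {0: FakeLoraWeights(), 1: FakeJointWeights(), 2: object()}
unwrapped = {k: unwrap(v) for k, v in adapters.items()}
lora_only = {k: v for k, v in unwrapped.items() if isinstance(v, FakeLoraWeights)}
print(sorted(lora_only))  # [0, 1]; adapter 2 is dropped, and an empty result maps to "return None"
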
server/lorax_server/adapters/medusa.py (18 changes: 14 additions & 4 deletions)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type
 
 import torch
 import torch.distributed
@@ -229,8 +229,8 @@ def __init__(self, config: MedusaConfig, module_map: ModuleMap, model: "Model"):
         self.process_group = model.process_group
 
     @classmethod
-    def get_batch_type(cls) -> BatchAdapterWeights:
-        return BatchMedusaWeights
+    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
+        return [BatchMedusaWeights]
 
     @property
     def speculative_tokens(self) -> int:
@@ -272,8 +272,12 @@ def __call__(self, x, lm_head):
     @classmethod
     def load(
         cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata", prefill: bool
-    ) -> "BatchMedusaWeights":
+    ) -> Optional["BatchMedusaWeights"]:
+        adapter_weights = {k: _convert_medusa(v) for k, v in adapter_weights.items()}
         adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)}
+        if not adapter_weights:
+            return None
+
         default_medusa = adapter_weights.get(0)
 
         segments = meta.adapter_segments
@@ -313,3 +317,9 @@ def load(
                 s_end=segments[[i + 1 for i in indices]],
             ),
         )
+
+
+def _convert_medusa(v: AdapterWeights) -> AdapterWeights:
+    if hasattr(v, "medusa_weights"):
+        return v.medusa_weights
+    return v
server/lorax_server/adapters/medusa_lora.py (93 changes: 93 additions & 0 deletions)
@@ -0,0 +1,93 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type
+
+import torch
+
+from lorax_server.adapters.config import AdapterConfig, ModuleMap
+from lorax_server.adapters.lora import BatchLoraWeights, LoraConfig, LoraWeights
+from lorax_server.adapters.medusa import BatchMedusaWeights, MedusaConfig, MedusaWeights
+from lorax_server.adapters.weights import AdapterWeights, BatchAdapterWeights
+
+if TYPE_CHECKING:
+    from lorax_server.models.model import Model
+
+EMPTY_TENSOR = torch.tensor([])
+
+
+@dataclass
+class MedusaLoraModuleMap:
+    lora_module_map: ModuleMap
+    medusa_module_map: ModuleMap
+
+
+@dataclass
+class MedusaLoraConfig(AdapterConfig):
+    lora_config: LoraConfig
+    medusa_config: MedusaConfig
+
+    def map_weights_for_model(
+        self,
+        adapter_weights: Dict,
+        weight_names: Tuple[str],
+    ) -> Tuple[MedusaLoraModuleMap, Set[str]]:
+        lora_module_map, weight_names = self.lora_config.map_weights_for_model(adapter_weights, weight_names)
+        medusa_module_map, _ = self.medusa_config.map_weights_for_model(adapter_weights, weight_names)
+        return MedusaLoraModuleMap(lora_module_map, medusa_module_map), weight_names
+
+    def load_batched_adapter_weights(
+        self,
+        model: "Model",
+        module_map: MedusaLoraModuleMap,
+        layer_type: str,
+        unused_weight_names: Set[str],
+        dynamic: bool,
+    ) -> Optional[AdapterWeights]:
+        lora_weights = self.lora_config.load_batched_adapter_weights(
+            model, module_map.lora_module_map, layer_type, unused_weight_names, dynamic
+        )
+        medusa_weights = self.medusa_config.load_batched_adapter_weights(
+            model, module_map.medusa_module_map, layer_type, unused_weight_names, dynamic
+        )
+        return MedusaLoraWeights.load(
+            lora_weights,
+            medusa_weights,
+        )
+
+    @classmethod
+    def load(cls, adapter_id: str, config: dict, api_token: str) -> "MedusaLoraConfig":
+        lora_config = LoraConfig.load(adapter_id, api_token)
+        medusa_config = MedusaConfig.load(config)
+        return cls(
+            base_model_name_or_path=lora_config.base_model_name_or_path,
+            lora_config=lora_config,
+            medusa_config=medusa_config,
+        )
+
+
+class MedusaLoraWeights(AdapterWeights):
+    def __init__(
+        self,
+        lora_weights: LoraWeights,
+        medusa_weights: MedusaWeights,
+    ):
+        self.lora_weights = lora_weights
+        self.medusa_weights = medusa_weights
+
+    @classmethod
+    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
+        return [BatchLoraWeights, BatchMedusaWeights]
+
+    @property
+    def speculative_tokens(self) -> int:
+        return self.medusa_weights.speculative_tokens
+
+    @classmethod
+    def load(
+        cls,
+        lora_weights: LoraWeights,
+        medusa_weights: MedusaWeights,
+    ) -> Optional[AdapterWeights]:
+        return MedusaLoraWeights(
+            lora_weights,
+            medusa_weights,
+        )
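
A rough usage sketch of the new class, showing how it composes the two existing adapter types. The medusa_config dict and adapter path below are assumptions for illustration; a real Medusa config.json carries more fields:

from lorax_server.adapters.medusa_lora import MedusaLoraConfig

# Assumed minimal Medusa head config; real configs include additional fields.
medusa_config = {"medusa_num_heads": 4, "medusa_num_layers": 1}

config = MedusaLoraConfig.load(
    adapter_id="/data/adapters/my-joint-adapter",  # assumed local adapter path
    config=medusa_config,
    api_token="",
)
# config.lora_config drives LoRA weight mapping/loading, config.medusa_config drives
# the Medusa heads, and the resulting MedusaLoraWeights registers under both batch
# types: MedusaLoraWeights.get_batch_types() == [BatchLoraWeights, BatchMedusaWeights]
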
server/lorax_server/adapters/weights.py (13 changes: 8 additions & 5 deletions)
@@ -1,7 +1,7 @@
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, List, Set, Type
+from typing import Dict, List, Optional, Set, Type
 
 import torch

@@ -27,7 +27,7 @@ class AdapterBatchMetadata:
 
 class AdapterWeights(ABC):
     @abstractclassmethod
-    def get_batch_type(cls) -> "BatchAdapterWeights":
+    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
         pass
 
     @property
@@ -47,7 +47,7 @@ def key(cls) -> str:
     @abstractclassmethod
     def load(
         cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata", prefill: bool
-    ) -> "BatchAdapterWeights":
+    ) -> Optional["BatchAdapterWeights"]:
         pass


@@ -76,11 +76,14 @@ def get_data(self, meta: AdapterBatchMetadata, prefill: bool) -> Dict[str, Batch
         # bucket adapters by batch class
         adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict)
         for adapter_index, adapter_weights in self.adapter_weights.items():
-            adapter_batch_types[adapter_weights.get_batch_type()][adapter_index] = adapter_weights
+            for batch_type in adapter_weights.get_batch_types():
+                adapter_batch_types[batch_type][adapter_index] = adapter_weights
 
         batch_data = {}
         for batch_type, adapter_weights in adapter_batch_types.items():
-            batch_data[batch_type.key()] = batch_type.load(adapter_weights, meta, prefill)
+            batched_weights = batch_type.load(adapter_weights, meta, prefill)
+            if batched_weights is not None:
+                batch_data[batch_type.key()] = batched_weights
         return batch_data
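
To make the new bucketing concrete, here is a toy re-implementation of the loop above with stub classes; every Fake* name is invented for illustration and is not part of the server:

from collections import defaultdict


class FakeBatchLora:
    @classmethod
    def key(cls):
        return "lora"

    @classmethod
    def load(cls, weights, meta=None, prefill=False):
        # Mirrors the Optional return: nothing to batch means None, which get_data skips.
        return {"lora_adapters": sorted(weights)} if weights else None


class FakeBatchMedusa:
    @classmethod
    def key(cls):
        return "medusa"

    @classmethod
    def load(cls, weights, meta=None, prefill=False):
        return {"medusa_adapters": sorted(weights)} if weights else None


class FakeJointAdapter:
    @classmethod
    def get_batch_types(cls):
        # Like MedusaLoraWeights: a single adapter feeds both batch classes.
        return [FakeBatchLora, FakeBatchMedusa]


class FakeLoraAdapter:
    @classmethod
    def get_batch_types(cls):
        return [FakeBatchLora]


adapter_weights = {0: FakeJointAdapter(), 1: FakeLoraAdapter()}

buckets = defaultdict(dict)
for adapter_index, weights in adapter_weights.items():
    for batch_type in weights.get_batch_types():
        buckets[batch_type][adapter_index] = weights

batch_data = {}
for batch_type, weights in buckets.items():
    batched = batch_type.load(weights)
    if batched is not None:
        batch_data[batch_type.key()] = batched

print(batch_data)
# {'lora': {'lora_adapters': [0, 1]}, 'medusa': {'medusa_adapters': [0]}}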


