feat: support lazy loading the lora module for reducing the loading p… #434

Open · wants to merge 4 commits into base: main
5 changes: 3 additions & 2 deletions server/lorax_server/adapters/lora.py
@@ -9,6 +9,7 @@
 from lorax_server.adapters.config import AdapterConfig, ModuleMap
 from lorax_server.adapters.types import LORA
 from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
+from lorax_server.utils.adapter import load_module_weight
 from lorax_server.utils.sgmv import (
     BGMV_MAX_RANK,
     MAX_RANK_CUSTOM,
@@ -166,10 +167,10 @@ def load(
             return None

         lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-        lora_a = lora_a.to(base_device, model.dtype)
+        lora_a = load_module_weight(lora_a_name, lora_a, base_device, model.dtype)

         lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-        lora_b = lora_b.to(base_device, model.dtype)
+        lora_b = load_module_weight(lora_b_name, lora_b, base_device, model.dtype)

         scale = get_scaling_factor(
             config.lora_alpha,
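
For context, a minimal sketch of the two call paths this change creates. Everything below is illustrative: the tensor name, shape, and file path are hypothetical, not taken from this PR.

import torch

from lorax_server.utils.adapter import load_module_weight

# Eager path: the module map entry is already a materialized tensor, so
# load_module_weight reduces to the old `.to(device, dtype)` call.
eager_lora_a = torch.randn(8, 4096)  # hypothetical LoRA A matrix
weight = load_module_weight(
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",  # hypothetical key
    eager_lora_a,
    torch.device("cpu"),
    torch.float16,
)

# Lazy path: the module map entry is just the safetensors filename; the
# tensor is read from disk and cast on first use.
weight = load_module_weight(
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "/data/adapters/my-adapter/adapter_model.safetensors",  # hypothetical path
    torch.device("cpu"),
    torch.float16,
)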
31 changes: 29 additions & 2 deletions server/lorax_server/utils/adapter.py
@@ -1,8 +1,10 @@
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Union

+import torch
+from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

@@ -75,6 +77,7 @@ def _load_and_merge(
             weight_names,
             api_token,
             trust_remote_code,
+            False,
         )

         adapters_to_merge.append((module_map, adapter_config))
@@ -133,6 +136,7 @@ def load_module_map(
     weight_names: Tuple[str],
     api_token: str,
     trust_remote_code: bool = False,
+    lazy_load_weights: bool = True,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     # TODO(geoffrey): refactor this and merge parts of this function with
     # lorax_server/utils/adapter.py::create_merged_weight_files
@@ -154,8 +158,31 @@ def load_module_map(
     adapter_filenames = source.weight_files()
     adapter_weights = {}
     for filename in adapter_filenames:
-        adapter_weights.update(load_file(filename))
+        if lazy_load_weights:
+            result = {}
+            # just fetching the layer names of the module
+            with safe_open(filename, framework="pt") as f:
+                for k in f.keys():
+                    result[k] = filename
+            adapter_weights.update(result)
+        else:
+            adapter_weights.update(load_file(filename))

     # map the model weights to the relevant adapter weights (LoRA A and B matrices)
     module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
+
+
+def load_module_weight(name: str, module: Union[torch.Tensor, str], device, dtype):
+    if isinstance(module, torch.Tensor):
+        return module.to(device, dtype)
+
+    if isinstance(device, torch.device):
+        if device.type == "cuda":
+            device = device.index
+        elif device.type == "cpu":
+            device = "cpu"
+
+    # module would be just the filename if lazy loading happened before
+    with safe_open(module, framework="pt", device=device) as f:
+        return f.get_tensor(name).to(dtype)
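
A standalone sketch of the lazy index built by the branch above. Opening a safetensors file parses only its header, so the loop records which file holds each tensor without reading any weight bytes; the shard filenames below are hypothetical.

from safetensors import safe_open

adapter_filenames = [
    "/data/my-adapter/adapter_model-00001.safetensors",  # hypothetical shards
    "/data/my-adapter/adapter_model-00002.safetensors",
]

adapter_weights = {}
for filename in adapter_filenames:
    # Header-only read: no tensor data is fetched until get_tensor is
    # called later via load_module_weight.
    with safe_open(filename, framework="pt") as f:
        for key in f.keys():
            adapter_weights[key] = filename

# adapter_weights now maps tensor names to filenames, e.g.
# {"...lora_A.weight": "/data/my-adapter/adapter_model-00001.safetensors", ...}

Note that _load_and_merge passes False for lazy_load_weights, so merged-adapter loading keeps the eager load_file behavior.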
6 changes: 3 additions & 3 deletions server/lorax_server/utils/weights.py
@@ -11,6 +11,8 @@
 from loguru import logger
 from safetensors import safe_open

+from lorax_server.utils.adapter import load_module_weight
+

 class AbstractWeights(ABC):
     @abstractmethod
@@ -240,9 +242,7 @@ def get_slice(self, tensor_name: str) -> torch.Tensor:

     def get_tensor(self, tensor_name: str) -> torch.Tensor:
         tensor = self.weights[tensor_name]
-        tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
-        return tensor
+        return load_module_weight(tensor_name, tensor, self.device, self.dtype)

     def get_slice_shape(self, slice) -> torch.Size:
         return slice.shape
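
With this change, a weights container can hold either materialized tensors or safetensors filenames and resolve both through the same accessor. A minimal sketch; the TinyWeights class is hypothetical and only illustrates the pattern:

import torch

from lorax_server.utils.adapter import load_module_weight


class TinyWeights:
    # Hypothetical stand-in for the container patched above; values in
    # `weights` may be torch.Tensor (eager) or str filenames (lazy).
    def __init__(self, weights, device, dtype):
        self.weights = weights
        self.device = device
        self.dtype = dtype

    def get_tensor(self, tensor_name: str) -> torch.Tensor:
        tensor = self.weights[tensor_name]
        # A tensor is just moved/cast; a filename is opened with
        # safetensors and the named tensor is read on demand.
        return load_module_weight(tensor_name, tensor, self.device, self.dtype)


weights = TinyWeights(
    {"linear.weight": "/data/model.safetensors"},  # hypothetical lazy entry
    torch.device("cpu"),
    torch.float16,
)

As in the PR's load_module_weight, a torch.device is normalized to a CUDA index or the string "cpu" before it reaches safe_open.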