From e01df4a6af7d742a03f7ab6fd4f78952234a37f8 Mon Sep 17 00:00:00 2001
From: LS
Date: Tue, 23 Apr 2024 15:12:07 +0800
Subject: [PATCH 1/4] feat: support lazy loading the lora module to reduce the loading footprint

---
 server/lorax_server/adapters/lora.py |  5 +++--
 server/lorax_server/utils/adapter.py | 18 ++++++++++++++++--
 server/lorax_server/utils/weights.py |  6 +++---
 3 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py
index eea5301e7..8ae25067e 100644
--- a/server/lorax_server/adapters/lora.py
+++ b/server/lorax_server/adapters/lora.py
@@ -9,6 +9,7 @@
 from lorax_server.adapters.config import AdapterConfig, ModuleMap
 from lorax_server.adapters.types import LORA
 from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
+from lorax_server.utils.adapter import load_module_weight
 from lorax_server.utils.sgmv import (
     BGMV_MAX_RANK,
     MAX_RANK_CUSTOM,
@@ -166,10 +167,10 @@ def load(
             return None
 
         lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-        lora_a = lora_a.to(base_device, model.dtype)
+        lora_a = load_module_weight(lora_a, base_device, model.dtype)
 
         lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-        lora_b = lora_b.to(base_device, model.dtype)
+        lora_b = load_module_weight(lora_b, base_device, model.dtype)
 
         scale = get_scaling_factor(
             config.lora_alpha,
diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py
index 97322154d..1e207df5c 100644
--- a/server/lorax_server/utils/adapter.py
+++ b/server/lorax_server/utils/adapter.py
@@ -1,8 +1,9 @@
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Union
 
+import torch
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
@@ -75,6 +76,7 @@ def _load_and_merge(
             weight_names,
             api_token,
             trust_remote_code,
+            False,
         )
 
         adapters_to_merge.append((module_map, adapter_config))
@@ -133,6 +135,7 @@ def load_module_map(
     weight_names: Tuple[str],
     api_token: str,
     trust_remote_code: bool = False,
+    lazy_load_weights: bool = True,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     # TODO(geoffrey): refactor this and merge parts of this function with
     # lorax_server/utils/adapter.py::create_merged_weight_files
@@ -154,8 +157,19 @@
     adapter_filenames = source.weight_files()
     adapter_weights = {}
     for filename in adapter_filenames:
-        adapter_weights.update(load_file(filename))
+        if lazy_load_weights:
+            adapter_weights.update(filename)
+        else:
+            adapter_weights.update(load_file(filename))
 
     # map the model weights to the relevant adapter weights (LoRA A and B matrices)
     module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
+
+
+def load_module_weight(module: Union[torch.Tensor, str], device, dtype):
+    if isinstance(module, torch.Tensor):
+        return module.to(device, dtype)
+
+    # module would be just the filename if lazy loading happened before
+    return load_file(module, device=device).to(dtype)
diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py
index 91ce2390a..3eda9326e 100644
--- a/server/lorax_server/utils/weights.py
+++ b/server/lorax_server/utils/weights.py
@@ -11,6 +11,8 @@
 from loguru import logger
 from safetensors import safe_open
 
+from lorax_server.utils.adapter import load_module_weight
+
 
 class AbstractWeights(ABC):
     @abstractmethod
@@ -240,9 +242,7 @@ def get_slice(self, tensor_name: str) -> torch.Tensor:
 
     def get_tensor(self, tensor_name: str) -> torch.Tensor:
         tensor = self.weights[tensor_name]
-        tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
-        return tensor
+        return load_module_weight(tensor, self.device, self.dtype)
 
     def get_slice_shape(self, slice) -> torch.Size:
         return slice.shape

From 66784f553a405d36d8696c2503c34612b7c77593 Mon Sep 17 00:00:00 2001
From: LS
Date: Tue, 23 Apr 2024 21:36:10 +0800
Subject: [PATCH 2/4] fix: store the layer:filename pair in module_map for lazy loading

---
 server/lorax_server/adapters/lora.py |  4 ++--
 server/lorax_server/utils/adapter.py | 15 ++++++++++++---
 server/lorax_server/utils/weights.py |  2 +-
 3 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py
index 8ae25067e..fb0ba7adb 100644
--- a/server/lorax_server/adapters/lora.py
+++ b/server/lorax_server/adapters/lora.py
@@ -167,10 +167,10 @@ def load(
             return None
 
         lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-        lora_a = load_module_weight(lora_a, base_device, model.dtype)
+        lora_a = load_module_weight(lora_a_name, lora_a, base_device, model.dtype)
 
         lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-        lora_b = load_module_weight(lora_b, base_device, model.dtype)
+        lora_b = load_module_weight(lora_b_name, lora_b, base_device, model.dtype)
 
         scale = get_scaling_factor(
             config.lora_alpha,
diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py
index 1e207df5c..f2543b469 100644
--- a/server/lorax_server/utils/adapter.py
+++ b/server/lorax_server/utils/adapter.py
@@ -158,7 +158,12 @@ def load_module_map(
     adapter_weights = {}
     for filename in adapter_filenames:
         if lazy_load_weights:
-            adapter_weights.update(filename)
+            result = {}
+            # just fetching the layer names of the module
+            with safe_open(filename, framework="pt") as f:
+                for k in f.keys():
+                    result[k] = filename
+            adapter_weights.update(result)
         else:
             adapter_weights.update(load_file(filename))
 
@@ -167,9 +172,13 @@ def load_module_map(
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
 
 
-def load_module_weight(module: Union[torch.Tensor, str], device, dtype):
+def load_module_weight(name: str, module: Union[torch.Tensor, str], device, dtype):
     if isinstance(module, torch.Tensor):
         return module.to(device, dtype)
 
+    if isinstance(device, torch.device) and device.type == "cuda":
+        device = device.index
+
     # module would be just the filename if lazy loading happened before
-    return load_file(module, device=device).to(dtype)
+    with safe_open(module, framework="pt", device=device) as f:
+        return f.get_tensor(name).to(dtype)
\ No newline at end of file
diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py
index 3eda9326e..f7078ef74 100644
--- a/server/lorax_server/utils/weights.py
+++ b/server/lorax_server/utils/weights.py
@@ -242,7 +242,7 @@ def get_slice(self, tensor_name: str) -> torch.Tensor:
 
     def get_tensor(self, tensor_name: str) -> torch.Tensor:
         tensor = self.weights[tensor_name]
-        return load_module_weight(tensor, self.device, self.dtype)
+        return load_module_weight(tensor_name, tensor, self.device, self.dtype)
 
     def get_slice_shape(self, slice) -> torch.Size:
         return slice.shape

From 6d91a88d189df06536d2d2aca00f2feb03c35549 Mon Sep 17 00:00:00 2001
From: LS
Date: Tue, 23 Apr 2024 21:38:30 +0800
Subject: [PATCH 3/4] fix: add missing imports

---
 server/lorax_server/utils/adapter.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py
index f2543b469..a6885580b 100644
--- a/server/lorax_server/utils/adapter.py
+++ b/server/lorax_server/utils/adapter.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Set, Tuple, Union
 
 import torch
+from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
@@ -181,4 +182,4 @@ def load_module_weight(name: str, module: Union[torch.Tensor, str], device, dtyp
 
     # module would be just the filename if lazy loading happened before
     with safe_open(module, framework="pt", device=device) as f:
-        return f.get_tensor(name).to(dtype)
\ No newline at end of file
+        return f.get_tensor(name).to(dtype)

From 86085df00ae2c44203eb115fb82673d3ac7a4ae4 Mon Sep 17 00:00:00 2001
From: LS
Date: Tue, 11 Jun 2024 09:33:16 +0800
Subject: [PATCH 4/4] fix: work with cpu device object

---
 server/lorax_server/utils/adapter.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py
index a6885580b..0af379fa6 100644
--- a/server/lorax_server/utils/adapter.py
+++ b/server/lorax_server/utils/adapter.py
@@ -177,8 +177,11 @@ def load_module_weight(name: str, module: Union[torch.Tensor, str], device, dtyp
     if isinstance(module, torch.Tensor):
         return module.to(device, dtype)
 
-    if isinstance(device, torch.device) and device.type == "cuda":
-        device = device.index
+    if isinstance(device, torch.device):
+        if device.type == "cuda":
+            device = device.index
+        elif device.type == "cpu":
+            device = "cpu"
 
     # module would be just the filename if lazy loading happened before
     with safe_open(module, framework="pt", device=device) as f:
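
A minimal standalone sketch of the lazy weight map that this series builds inside load_module_map. The adapter file name and tensor keys are illustrative, and the snippet writes its own tiny demo file so it can run end to end:

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Write a tiny stand-in adapter file so the sketch runs end to end.
save_file(
    {
        "q_proj.lora_A.weight": torch.randn(8, 32),
        "q_proj.lora_B.weight": torch.randn(32, 8),
    },
    "adapter_demo.safetensors",
)

# Lazy pass: parse only the safetensors header and map each tensor name to
# the file that contains it, instead of load_file() reading every tensor
# into host memory up front.
adapter_weights = {}
for filename in ["adapter_demo.safetensors"]:
    with safe_open(filename, framework="pt") as f:
        for key in f.keys():
            adapter_weights[key] = filename

print(adapter_weights)
# {'q_proj.lora_A.weight': 'adapter_demo.safetensors', 'q_proj.lora_B.weight': ...}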
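
A second sketch mirrors the on-demand side that load_module_weight ends up with after this series: only the requested tensor is read from disk, and a torch.device is first normalized to what safe_open accepts (an integer index for CUDA, the string "cpu" for the CPU). The helper and file names here are made up for illustration:

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Stand-in adapter file so the sketch is self-contained.
save_file({"q_proj.lora_A.weight": torch.randn(8, 32)}, "adapter_demo.safetensors")


def fetch_lazy_tensor(name: str, filename: str, device, dtype):
    # safe_open() wants an integer index for CUDA devices and "cpu" for the
    # CPU, so a torch.device object is unwrapped before being passed down.
    if isinstance(device, torch.device):
        if device.type == "cuda":
            device = device.index
        elif device.type == "cpu":
            device = "cpu"

    # Only the requested tensor is read from disk and placed on the target
    # device/dtype; nothing else in the file is materialized.
    with safe_open(filename, framework="pt", device=device) as f:
        return f.get_tensor(name).to(dtype)


lora_a = fetch_lazy_tensor(
    "q_proj.lora_A.weight", "adapter_demo.safetensors", torch.device("cpu"), torch.float16
)
print(lora_a.shape, lora_a.dtype)  # torch.Size([8, 32]) torch.float16

Keeping the eager branch (a plain module.to(device, dtype)) next to the lazy branch is what lets get_tensor in weights.py stay agnostic about whether an entry in self.weights is a materialized tensor or just a filename.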