Peft update #557

Open · wants to merge 8 commits into main

Changes from 6 commits
2 changes: 1 addition & 1 deletion setup.cfg
@@ -47,7 +47,7 @@ install_requires =
cpufeature>=0.2.0; platform_machine == "x86_64"
packaging>=20.9
sentencepiece>=0.1.99
peft==0.5.0
peft==0.8.2
safetensors>=0.3.1
Dijkstar>=2.6.0

2 changes: 1 addition & 1 deletion src/petals/utils/convert_block.py
@@ -61,7 +61,7 @@ def convert_block(
if adapters:
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft

create_lora_adapter(block, quant_type=quant_type)
create_lora_adapter(block)
for adapter_name in adapters:
adapter_config, adapter_state_dict = load_peft(
adapter_name,
92 changes: 44 additions & 48 deletions src/petals/utils/peft.py
@@ -1,7 +1,7 @@
import contextlib
import re
import time
from typing import Optional, Sequence, Union
from typing import List, Optional, Sequence, Union

import bitsandbytes as bnb
import torch
@@ -12,19 +12,21 @@
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.config import PeftConfig
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes

logger = get_logger(__name__)


COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]


def check_peft_repository(repo_id: str) -> bool:
return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")

@@ -151,67 +153,60 @@ def active_adapter(self):
def active_adapter(self, value: Optional[str]):
assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""

@property
def active_adapters(self):
return [self._context_active_adapter]

def set_adapter(self, adapter_names) -> None:
"""
In PEFT, this function makes an adapter trainable. However, that is not possible in the Petals environment right now,
so this override removes that functionality.
Link to peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463
"""
pass


using_adapter = AdapterContextMixin.using_adapter


class LoraLinear(AdapterContextMixin, lora.Linear):
"""LoRA linear layer that uses adapter selected via using_adapter"""

def __init__(self, base_layer, adapter_name: str):
nn.Module.__init__(self)
lora.LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.is_target_conv_1d_layer = False


class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
# TODO: Check if lora.Linear can be mixed with lora.Linear8bitLt
Member: Is this checked?

Collaborator Author: Yes, I checked that it worked, and the outputs were the same.
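A minimal sketch of the kind of output check described above (illustrative only; the module and argument names, input shape, and tolerance are assumptions, not code from this PR). With the default LoRA init, lora_B is zero, so a freshly wrapped layer should reproduce the plain layer's outputs:

```python
# Illustrative only: not part of this diff.
import torch

@torch.no_grad()
def outputs_match(plain_layer: torch.nn.Module, lora_wrapped_layer: torch.nn.Module,
                  in_features: int, atol: float = 1e-4) -> bool:
    # The zero-initialized adapter path contributes nothing, so both modules
    # should produce (numerically) the same activations on the same input.
    x = torch.randn(2, in_features)
    return torch.allclose(plain_layer(x), lora_wrapped_layer(x), atol=atol)
```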

class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):
"""LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""


class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
# TODO: Check if lora.Linear can be mixed with lora.Linear4bit
class LoraLinear4bit(LoraLinear, lora.Linear4bit):
"""LoRA linear 4-bit that uses adapter selected via using_adapter"""


def create_lora_adapter(block, quant_type: QuantType):
for _, module in block.named_modules():
def create_lora_adapter(block):
for module_name, module in block.named_modules():
if isinstance(module, LoraLinear):
continue
for child_name, child in module.named_children():
lora_wrapped_child = None
if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
continue
if quant_type == QuantType.INT8:
kwargs = {
"has_fp16_weights": False,
"threshold": 6.0,
"bias": hasattr(child, "bias") and child.bias is not None,
}
lora_wrapped_child = LoraLinear8bitLt(
AdapterContextMixin.ADAPTER_NOT_SET,
child.in_features,
child.out_features,
**kwargs,
)
elif quant_type == QuantType.NF4:
kwargs = {
"compress_statistics": True,
"quant_type": "nf4",
"blocksize": 64,
"bias": hasattr(child, "bias") and child.bias is not None,
}
lora_wrapped_child = LoraLinear4bit(
AdapterContextMixin.ADAPTER_NOT_SET,
child.in_features,
child.out_features,
**kwargs,
)
lora_wrapped_child.compute_dtype = child.compute_dtype
else:
bias = hasattr(child, "bias") and child.bias is not None
lora_wrapped_child = LoraLinear(
lora_class = None
if isinstance(child, nn.Linear):
lora_class = LoraLinear
elif isinstance(child, bnb.nn.Linear8bitLt):
lora_class = LoraLinear8bitLt
elif isinstance(child, bnb.nn.Linear4bit):
lora_class = LoraLinear4bit
if lora_class:
lora_wrapped_child = lora_class(
child,
AdapterContextMixin.ADAPTER_NOT_SET,
child.in_features,
child.out_features,
bias=bias,
)
if lora_wrapped_child:
lora_wrapped_child.weight = child.weight
lora_wrapped_child.bias = child.bias
for p in lora_wrapped_child.parameters():
p.requires_grad = False
setattr(module, child_name, lora_wrapped_child)


@@ -240,6 +235,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
adapter_name,
peft_config["r"],
peft_config["lora_alpha"],
use_rslora=peft_config.get("use_rslora", False),
lora_dropout=peft_config["lora_dropout"],
init_lora_weights=peft_config["init_lora_weights"],
)
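For reviewer context: `use_rslora` (rank-stabilized LoRA) only changes the constant used to scale the adapter update before it is added to the base output. A rough sketch of the rule, not taken from this diff:

```python
# Rough illustration of the rsLoRA scaling rule; purely for context.
def lora_scaling(lora_alpha: float, r: int, use_rslora: bool) -> float:
    return lora_alpha / r**0.5 if use_rslora else lora_alpha / r
```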
@@ -275,7 +271,7 @@ def estimate_adapter_memory_per_block(
with init_empty_weights(include_buffers=True):
block = block_config.block_class(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block, quant_type=QuantType.NONE)
create_lora_adapter(block)

for adapter in adapters:
peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)