[WIP] Add nonlinearity between Down and Up weights, expose alpha hyperparameter #111

Draft · wants to merge 15 commits into base: develop
8 changes: 4 additions & 4 deletions lora_diffusion/cli_lora_add.py
@@ -75,13 +75,13 @@ def add(
path_1,
).to("cpu")

-weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
+weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), scale=alpha)
if with_text_lora:

weight_apply_lora(
loaded_pipeline.text_encoder,
torch.load(_text_lora_path(path_2)),
-alpha=alpha,
+scale=alpha,
target_replace_module=["CLIPAttention"],
)

@@ -93,12 +93,12 @@ def add(
path_1,
).to("cpu")

-weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
+weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), scale=alpha)
if with_text_lora:
weight_apply_lora(
loaded_pipeline.text_encoder,
torch.load(_text_lora_path(path_2)),
-alpha=alpha,
+scale=alpha,
target_replace_module=["CLIPAttention"],
)
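For reference, the merge the CLI performs now looks like this when written out directly; `loaded_pipeline` is the pipeline loaded above, and the checkpoint name and alpha value below are placeholders (the `alpha` argument of `add` is simply forwarded to the renamed `scale=` keyword):

    import torch
    from lora_diffusion.lora import weight_apply_lora, _text_lora_path

    alpha = 0.8  # illustrative merge strength

    weight_apply_lora(loaded_pipeline.unet, torch.load("lora_weight.pt"), scale=alpha)
    weight_apply_lora(
        loaded_pipeline.text_encoder,
        torch.load(_text_lora_path("lora_weight.pt")),
        scale=alpha,
        target_replace_module=["CLIPAttention"],
    )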

95 changes: 76 additions & 19 deletions lora_diffusion/lora.py
@@ -10,24 +10,39 @@


class LoraInjectedLinear(nn.Module):
-def __init__(self, in_features, out_features, bias=False, r=4):
+def __init__(self, in_features, out_features, bias=False, r=4, scale=1.0, init=None, nonlin: nn.Module = None):
super().__init__()

if r > min(in_features, out_features):
raise ValueError(
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
)


+if scale <= 0:
+    raise ValueError(
+        f"LoRA scale {scale} must be greater than 0"
+    )

self.r = r
+self.scale = scale
self.linear = nn.Linear(in_features, out_features, bias)
self.lora_down = nn.Linear(in_features, r, bias=False)
+self.nonlin = nonlin if nonlin else None
self.lora_up = nn.Linear(r, out_features, bias=False)
-self.scale = 1.0

-nn.init.normal_(self.lora_down.weight, std=1 / r)
+if init == "kaiming":
+    # Kaiming with a=math.sqrt(5) is already the default init for nn.Linear
+    pass
+else:
+    nn.init.normal_(self.lora_down.weight, std=1 / r)

nn.init.zeros_(self.lora_up.weight)

def forward(self, input):
-return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
+if self.nonlin:
+    return self.linear(input) + self.lora_up(self.nonlin(self.lora_down(input))) * self.scale
+else:
+    return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale


UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
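A small sketch of how the extended module behaves; the feature sizes, rank, scale, and the choice of nn.SiLU are illustrative only, not defaults of this PR:

    import torch
    import torch.nn as nn
    from lora_diffusion.lora import LoraInjectedLinear

    # Rank-4 LoRA branch with a SiLU between the down- and up-projection,
    # added to the base linear output at half strength.
    layer = LoraInjectedLinear(320, 320, bias=False, r=4, scale=0.5, nonlin=nn.SiLU())

    x = torch.randn(2, 77, 320)
    y = layer(x)  # linear(x) + 0.5 * lora_up(silu(lora_down(x)))
    print(y.shape)  # torch.Size([2, 77, 320])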
@@ -116,6 +131,9 @@ def inject_trainable_lora(
model: nn.Module,
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
r: int = 4,
+scale: float = 1.0,
+init=None,
+nonlin=None,
loras=None, # path to lora .pt
):
"""
@@ -137,7 +155,10 @@ def inject_trainable_lora(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
-r,
+r=r,
+scale=scale,
+init=init,
+nonlin=nonlin,
)
_tmp.linear.weight = weight
if bias is not None:
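A training-side sketch, assuming `unet` is an already loaded diffusers UNet and that the function still returns the injected LoRA parameters and their names as in the current implementation; the optimizer settings and the nn.GELU nonlinearity are placeholders:

    import itertools
    import torch
    import torch.nn as nn
    from lora_diffusion.lora import inject_trainable_lora

    unet.requires_grad_(False)  # train only the injected LoRA weights

    unet_lora_params, _ = inject_trainable_lora(
        unet, r=4, scale=1.0, init="kaiming", nonlin=nn.GELU()
    )
    optimizer = torch.optim.AdamW(itertools.chain(*unet_lora_params), lr=1e-4)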
@@ -333,9 +354,13 @@ def load_safeloras(path, device="cpu"):


def weight_apply_lora(
-model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, alpha=1.0
-):
+model,
+loras,
+target_replace_module=DEFAULT_TARGET_REPLACE,
+r: int = 4,
+scale: float = 1.0,
+nonlin: nn.Module = None,
+):
for _m, _n, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear]
):
@@ -344,13 +369,22 @@ def weight_apply_lora(
up_weight = loras.pop(0).detach().to(weight.device)
down_weight = loras.pop(0).detach().to(weight.device)

-# W <- W + U * D
-weight = weight + alpha * (up_weight @ down_weight).type(weight.dtype)
+if nonlin is None:
+    # W <- W + U * D
+    weight = weight + scale * (up_weight @ down_weight).type(weight.dtype)
+else:
+    # W <- W + U * nonlin(D)
+    weight = weight + scale * (up_weight @ nonlin(down_weight)).type(weight.dtype)

_child_module.weight = nn.Parameter(weight)
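Usage sketch for merging into the base weights; the checkpoint names, strength, and the pipeline object `pipe` are placeholders, and a LoRA trained with a nonlinearity can pass the same module so the merged update uses U @ nonlin(D):

    import torch
    import torch.nn as nn
    from lora_diffusion.lora import weight_apply_lora

    # Plain LoRA: W <- W + 0.5 * (U @ D)
    weight_apply_lora(pipe.unet, torch.load("lora_weight.pt"), scale=0.5)

    # LoRA trained with a ReLU between down and up: W <- W + 0.5 * (U @ relu(D))
    weight_apply_lora(pipe.unet, torch.load("lora_relu.pt"), scale=0.5, nonlin=nn.ReLU())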


def monkeypatch_lora(
-model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
+model,
+loras,
+target_replace_module=DEFAULT_TARGET_REPLACE,
+r: int = 4,
+scale: float = 1.0,
+nonlin: nn.Module = None,
):
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear]
@@ -362,6 +396,8 @@ def monkeypatch_lora(
_child_module.out_features,
_child_module.bias is not None,
r=r,
+scale=scale,
+nonlin=nonlin,
)
_tmp.linear.weight = weight
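Inference-side sketch; the pipeline and checkpoint name are placeholders, and the rank, scale, and nonlinearity passed here are assumed to match what the LoRA was trained with:

    import torch
    import torch.nn as nn
    from lora_diffusion.lora import monkeypatch_lora

    # Swap nn.Linear modules for LoraInjectedLinear and load the trained weights.
    monkeypatch_lora(pipe.unet, torch.load("lora_weight.pt"), r=4, scale=1.0, nonlin=nn.SiLU())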

@@ -385,7 +421,12 @@ def monkeypatch_lora(


def monkeypatch_replace_lora(
-model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
+model,
+loras,
+target_replace_module=DEFAULT_TARGET_REPLACE,
+r: int = 4,
+scale: float = 1.0,
+nonlin: nn.Module = None,
):
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[LoraInjectedLinear]
@@ -397,7 +438,9 @@ def monkeypatch_replace_lora(
_child_module.linear.out_features,
_child_module.linear.bias is not None,
r=r,
-)
+scale=scale,
+nonlin=nonlin,
+)
_tmp.linear.weight = weight

if bias is not None:
@@ -424,6 +467,8 @@ def monkeypatch_or_replace_lora(
loras,
target_replace_module=DEFAULT_TARGET_REPLACE,
r: Union[int, List[int]] = 4,
+scale: Union[float, List[float]] = 1.0,
+nonlin: Union[nn.Module, List[nn.Module]] = None,
):
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
@@ -441,6 +486,8 @@ def monkeypatch_or_replace_lora(
_source.out_features,
_source.bias is not None,
r=r.pop(0) if isinstance(r, list) else r,
+scale=scale.pop(0) if isinstance(scale, list) else scale,
+nonlin=nonlin.pop(0) if isinstance(nonlin, list) else nonlin,
)
_tmp.linear.weight = weight
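Since r, scale, and nonlin may also be lists here, per-module settings can be supplied in the order the target modules are found; a hypothetical sketch for a checkpoint covering two modules (pipeline and checkpoint name are placeholders):

    import torch
    import torch.nn as nn
    from lora_diffusion.lora import monkeypatch_or_replace_lora

    monkeypatch_or_replace_lora(
        pipe.unet,
        torch.load("lora_weight.pt"),
        r=[4, 8],
        scale=[1.0, 0.5],
        nonlin=[nn.ReLU(), None],
    )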

@@ -496,7 +543,7 @@ def monkeypatch_add_lora(
model,
loras,
target_replace_module=DEFAULT_TARGET_REPLACE,
-alpha: float = 1.0,
+scale: float = 1.0,
beta: float = 1.0,
):
for _module, name, _child_module in _find_modules(
@@ -519,12 +566,16 @@ def monkeypatch_add_lora(
_module._modules[name].to(weight.device)


-def tune_lora_scale(model, alpha: float = 1.0):
+def tune_lora_scale(model, alpha: float = 1.0, scale: float = None):
+    # Keep the original parameter name `alpha` (it is really the scale) for
+    # backward compatibility; an explicitly passed `scale` takes precedence.
+    if scale is None:
+        scale = alpha

for _module in model.modules():
if _module.__class__.__name__ == "LoraInjectedLinear":
-_module.scale = alpha
+_module.scale = scale


def _text_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
@@ -576,6 +627,8 @@ def patch_pipe(
unet_path,
token: str,
r: int = 4,
+scale: float = 1.0,
+nonlin: nn.Module = None,
patch_unet=True,
patch_text=False,
patch_ti=False,
@@ -596,6 +649,8 @@ def patch_pipe(
pipe.unet,
torch.load(unet_path),
r=r,
+scale=scale,
+nonlin=nonlin,
target_replace_module=unet_target_replace_module,
)

@@ -606,6 +661,8 @@ def patch_pipe(
torch.load(text_path),
target_replace_module=text_target_replace_module,
r=r,
+scale=scale,
+nonlin=nonlin,
)
if patch_ti:
print("LoRA : Patching token input")