Xander updates #185

Open · wants to merge 13 commits into base: develop
168 changes: 129 additions & 39 deletions lora_diffusion/cli_lora_pti.py
@@ -46,6 +46,8 @@
prepare_clip_model_sets,
evaluate_pipe,
UNET_EXTENDED_TARGET_REPLACE,
parse_safeloras_embeds,
apply_learned_embed_in_clip,
)

def preview_training_batch(train_dataloader, mode, n_imgs = 40):
@@ -67,6 +69,52 @@ def preview_training_batch(train_dataloader, mode, n_imgs = 40):
print(f"\nSaved {imgs_saved} preview training imgs to {outdir}")
return

def sim_matrix(a, b, eps=1e-8):
"""
Row-wise cosine-similarity matrix between a and b; eps added for numerical stability.
"""
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
return sim_mt


def compute_pairwise_distances(x, y):
# compute the squared L2 distance of each row in x to each row in y (both are torch tensors)
# x is a torch tensor of shape (m, d)
# y is a torch tensor of shape (n, d)
# returns a torch tensor of shape (m, n)

n = y.shape[0]
m = x.shape[0]
d = x.shape[1]

x = x.unsqueeze(1).expand(m, n, d)
y = y.unsqueeze(0).expand(m, n, d)

return torch.pow(x - y, 2).sum(2)
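
Quick sanity check (reviewer sketch, not part of the diff): both helpers should agree with the corresponding torch built-ins, and note that compute_pairwise_distances returns squared L2 distances:

```python
import torch
import torch.nn.functional as F

a = torch.randn(3, 8)
b = torch.randn(5, 8)

# Cosine similarities match the broadcasted built-in.
assert torch.allclose(
    sim_matrix(a, b),
    F.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0), dim=-1),
    atol=1e-5,
)

# These "distances" are squared L2, i.e. torch.cdist(...) ** 2.
assert torch.allclose(
    compute_pairwise_distances(a, b),
    torch.cdist(a, b, p=2) ** 2,
    atol=1e-4,
)
```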


def print_most_similar_tokens(tokenizer, optimized_token, text_encoder, n=10):
with torch.no_grad():
# get all the token embeddings:
token_embeds = text_encoder.get_input_embeddings().weight.data

# Compute the cosine-similarity between the optimized tokens and all the other tokens
similarity = sim_matrix(optimized_token.unsqueeze(0), token_embeds).squeeze()
similarity = similarity.detach().cpu().numpy()

distances = compute_pairwise_distances(optimized_token.unsqueeze(0), token_embeds).squeeze()
distances = distances.detach().cpu().numpy()

# print similarity for the most similar tokens:
most_similar_tokens = np.argsort(similarity)[::-1]

print(f"{tokenizer.decode(most_similar_tokens[0])} --> mean: {optimized_token.mean().item():.3f}, std: {optimized_token.std().item():.3f}, norm: {optimized_token.norm():.4f}")
for token_id in most_similar_tokens[1:n+1]:
print(f"sim of {similarity[token_id]:.3f} & L2 of {distances[token_id]:.3f} with \"{tokenizer.decode(token_id)}\"")


def get_models(
pretrained_model_name_or_path,
@@ -139,19 +187,21 @@ def get_models(
pretrained_vae_name_or_path or pretrained_model_name_or_path,
subfolder=None if pretrained_vae_name_or_path else "vae",
revision=None if pretrained_vae_name_or_path else revision,
local_files_only = True,
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="unet",
revision=revision,
local_files_only = True,
)

return (
text_encoder.to(device),
vae.to(device),
unet.to(device),
tokenizer,
placeholder_token_ids,
placeholder_token_ids
)


@@ -477,12 +527,13 @@ def train_inversion(

if global_step % accum_iter == 0:
# print gradient of text encoder embedding
print(
text_encoder.get_input_embeddings()
.weight.grad[index_updates, :]
.norm(dim=-1)
.mean()
)
if 0:
print(
text_encoder.get_input_embeddings()
.weight.grad[index_updates, :]
.norm(dim=-1)
.mean()
)
optimizer.step()
optimizer.zero_grad()

@@ -517,8 +568,10 @@ def train_inversion(
index_no_updates
] = orig_embeds_params[index_no_updates]

for i, t in enumerate(optimizing_embeds):
print(f"token {i} --> mean: {t.mean().item():.3f}, std: {t.std().item():.3f}, norm: {t.norm():.4f}")
if global_step % 50 == 0:
print("------------------------------")
for i, t in enumerate(optimizing_embeds):
print_most_similar_tokens(tokenizer, t, text_encoder)

global_step += 1
progress_bar.update(1)
@@ -537,7 +590,7 @@ def train_inversion(
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens=placeholder_tokens,
save_path=os.path.join(
save_path, f"step_inv_{global_step}.safetensors"
save_path, f"step_inv_{global_step:04d}.safetensors"
),
save_lora=False,
)
@@ -583,7 +636,7 @@ def train_inversion(
return

import matplotlib.pyplot as plt
def plot_loss_curve(losses, name, moving_avg=20):
def plot_loss_curve(losses, name, moving_avg=5):
losses = np.array(losses)
losses = np.convolve(losses, np.ones(moving_avg)/moving_avg, mode='valid')
plt.plot(losses)
@@ -654,7 +707,7 @@ def perform_tuning(
vae,
text_encoder,
scheduler,
optimized_embeddings = text_encoder.get_input_embeddings().weight[:, :],
optimized_embeddings = text_encoder.get_input_embeddings().weight[~index_no_updates, :],
train_inpainting=train_inpainting,
t_mutliplier=0.8,
mixed_precision=True,
@@ -683,6 +736,12 @@ def perform_tuning(
index_no_updates
] = orig_embeds_params[index_no_updates]

if global_step % 100 == 0:
optimizing_embeds = text_encoder.get_input_embeddings().weight[~index_no_updates]
print("------------------------------")
for i, t in enumerate(optimizing_embeds):
print_most_similar_tokens(tokenizer, t, text_encoder)


global_step += 1

@@ -696,7 +755,7 @@ def perform_tuning(
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens=placeholder_tokens,
save_path=os.path.join(
save_path, f"step_{global_step}.safetensors"
save_path, f"step_{global_step:04d}.safetensors"
),
target_replace_module_text=lora_clip_target_modules,
target_replace_module_unet=lora_unet_target_modules,
@@ -706,16 +765,15 @@
.mean()
.item()
)

print("LORA Unet Moved", moved)

moved = (
torch.tensor(
list(itertools.chain(*inspect_lora(text_encoder).values()))
)
.mean()
.item()
)

print("LORA CLIP Moved", moved)

if log_wandb:
@@ -778,6 +836,7 @@ def train(
placeholder_tokens: str = "",
placeholder_token_at_data: Optional[str] = None,
initializer_tokens: Optional[str] = None,
load_pretrained_inversion_embeddings_path: Optional[str] = None,
seed: int = 42,
resolution: int = 512,
color_jitter: bool = True,
@@ -788,7 +847,8 @@
save_steps: int = 100,
gradient_accumulation_steps: int = 4,
gradient_checkpointing: bool = False,
lora_rank: int = 4,
lora_rank_unet: int = 4,
lora_rank_text_encoder: int = 4,
lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
lora_clip_target_modules={"CLIPAttention"},
lora_dropout_p: float = 0.0,
@@ -825,6 +885,10 @@ def train(
script_start_time = time.time()
torch.manual_seed(seed)

if use_template == "person" and not use_face_segmentation_condition:
print("### WARNING ### : Using person template without face segmentation condition")
print("When training people, it is highly recommended to use face segmentation condition!!")

# Get a dict with all the arguments:
args_dict = locals()

@@ -841,7 +905,7 @@ def train(

if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
# print(placeholder_tokens, initializer_tokens)

if len(placeholder_tokens) == 0:
placeholder_tokens = []
print("PTI : Placeholder Tokens not given, using null token")
@@ -874,6 +938,7 @@ def train(

print("PTI : Placeholder Tokens", placeholder_tokens)
print("PTI : Initializer Tokens", initializer_tokens)
print("PTI : Token Map: ", token_map)

# get the models
text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
@@ -886,7 +951,8 @@ def train(
)

noise_scheduler = DDPMScheduler.from_config(
pretrained_model_name_or_path, subfolder="scheduler"
pretrained_model_name_or_path, subfolder="scheduler",
local_files_only = True,
)

if gradient_checkpointing:
Expand Down Expand Up @@ -925,8 +991,6 @@ def train(
train_inpainting=train_inpainting,
)

train_dataset.blur_amount = 200

if train_inpainting:
assert not cached_latents, "Cached latents not supported for inpainting"

@@ -963,7 +1027,7 @@ def train(
vae = None

# STEP 1 : Perform Inversion
if perform_inversion and not cached_latents:
if perform_inversion and not cached_latents and (load_pretrained_inversion_embeddings_path is None):
preview_training_batch(train_dataloader, "inversion")

print("PTI : Performing Inversion")
@@ -1014,34 +1078,44 @@ def train(
del ti_optimizer
print("############### Inversion Done ###############")

elif load_pretrained_inversion_embeddings_path is not None:
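# Skip the inversion stage entirely and seed the text encoder with previously saved inversion embeddings (e.g. a step_inv_*.safetensors file written above).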

print("PTI : Loading pretrained inversion embeddings..")
from safetensors.torch import safe_open
# Load the pretrained embeddings from the lora file:
safeloras = safe_open(load_pretrained_inversion_embeddings_path, framework="pt", device="cpu")
#monkeypatch_or_replace_safeloras(pipe, safeloras)
tok_dict = parse_safeloras_embeds(safeloras)
apply_learned_embed_in_clip(
tok_dict,
text_encoder,
tokenizer,
idempotent=True,
)

# Next perform Tuning with LoRA:
if not use_extended_lora:
unet_lora_params, _ = inject_trainable_lora(
unet,
r=lora_rank,
r=lora_rank_unet,
target_replace_module=lora_unet_target_modules,
dropout_p=lora_dropout_p,
scale=lora_scale,
)
print("PTI : not use_extended_lora...")
print("PTI : Will replace modules: ", lora_unet_target_modules)
else:
print("PTI : USING EXTENDED UNET!!!")
lora_unet_target_modules = (
lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
)
print("PTI : Will replace modules: ", lora_unet_target_modules)
unet_lora_params, _ = inject_trainable_lora_extended(
unet, r=lora_rank, target_replace_module=lora_unet_target_modules
unet, r=lora_rank_unet, target_replace_module=lora_unet_target_modules
)

n_optimizable_unet_params = sum(
[el.numel() for el in itertools.chain(*unet_lora_params)]
)
print("PTI : n_optimizable_unet_params: ", n_optimizable_unet_params)

print(f"PTI : has {len(unet_lora_params)} lora")
print("PTI : Before training:")
inspect_lora(unet)
#n_optimizable_unet_params = sum([el.numel() for el in itertools.chain(*unet_lora_params)])
#print("PTI : Number of optimizable UNET parameters: ", n_optimizable_unet_params)

params_to_optimize = [
{"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
@@ -1073,15 +1147,15 @@ def train(
text_encoder_lora_params, _ = inject_trainable_lora(
text_encoder,
target_replace_module=lora_clip_target_modules,
r=lora_rank,
r=lora_rank_text_encoder,
)
params_to_optimize += [
{
"params": itertools.chain(*text_encoder_lora_params),
"lr": text_encoder_lr,
}
{"params": itertools.chain(*text_encoder_lora_params),
"lr": text_encoder_lr}
]
inspect_lora(text_encoder)

#n_optimizable_text_Encoder_params = sum( [el.numel() for el in itertools.chain(*text_encoder_lora_params)])
#print("PTI : Number of optimizable text-encoder parameters: ", n_optimizable_text_Encoder_params)

lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

@@ -1090,8 +1164,6 @@ def train(
print("Training text encoder!")
text_encoder.train()

train_dataset.blur_amount = 70

lr_scheduler_lora = get_scheduler(
lr_scheduler_lora,
optimizer=lora_optimizers,
@@ -1101,6 +1173,22 @@ def train(
if not cached_latents:
preview_training_batch(train_dataloader, "tuning")

#print("PTI : n_optimizable_unet_params: ", n_optimizable_unet_params)
print(f"PTI : has {len(unet_lora_params)} lora")
print("PTI : Before training:")

moved = (
torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
.mean().item())
print(f"LORA Unet Moved {moved:.6f}")


moved = (
torch.tensor(
list(itertools.chain(*inspect_lora(text_encoder).values()))
).mean().item())
print(f"LORA CLIP Moved {moved:.6f}")

perform_tuning(
unet,
vae,
@@ -1132,6 +1220,8 @@ def train(
training_time = time.time() - script_start_time
print(f"Training time: {training_time/60:.1f} minutes")
args_dict["training_time_s"] = int(training_time)
args_dict["n_epochs"] = math.ceil(max_train_steps_tuning / len(train_dataloader.dataset))
args_dict["n_training_imgs"] = len(train_dataloader.dataset)

# Save the args_dict to the output directory as a json file:
with open(os.path.join(output_dir, "lora_training_args.json"), "w") as f: