### System Info

AdaLoRA (PEFT); the model is RoBERTa.

The current `update_ipt` implementation in PEFT's AdaLoRA is:
```python
def update_ipt(self, model):
    # Update the sensitivity and uncertainty for every weight
    for n, p in model.named_parameters():
        if "lora_" in n and self.adapter_name in n:
            if n not in self.ipt:
                self.ipt[n] = torch.zeros_like(p)
                self.exp_avg_ipt[n] = torch.zeros_like(p)
                self.exp_avg_unc[n] = torch.zeros_like(p)
            with torch.no_grad():
                self.ipt[n] = (p * p.grad).abs().detach()
                # Sensitivity smoothing
                self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
                # Uncertainty quantification
                self.exp_avg_unc[n] = (
                    self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
                )
```
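For reference, the two running statistics are exponential moving averages of the per-weight importance score; this just restates what the code above computes:

$$
I^{(t)} = \bigl|\, p \odot \nabla_p \mathcal{L} \,\bigr|, \qquad
\bar I^{(t)} = \beta_1 \bar I^{(t-1)} + (1-\beta_1)\, I^{(t)}, \qquad
\bar U^{(t)} = \beta_2 \bar U^{(t-1)} + (1-\beta_2)\, \bigl|\, I^{(t)} - \bar I^{(t)} \,\bigr|
$$

where $\bar I$ is `exp_avg_ipt` (sensitivity smoothing) and $\bar U$ is `exp_avg_unc` (uncertainty quantification).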
When using AdaLoRA with PEFT, the classification head layer contains parameters whose names include the string "lora_" but which, after checking, have no gradient: their `requires_grad` attribute is `False`, so `p.grad` is `None`. The error occurs when calling `model.update_and_allocate(global_step)`.
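The failure can be reproduced in isolation; a minimal sketch, independent of PEFT, using an illustrative standalone parameter:

```python
# A frozen parameter never receives a gradient, so p.grad stays None
# and the importance computation from update_ipt raises a TypeError.
import torch

p = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)
assert p.grad is None
ipt = (p * p.grad).abs()  # raises TypeError, since p.grad is None
```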
### Who can help?

No response
### Information

- The official example scripts
- My own modified scripts
### Tasks

- An officially supported task in the `examples` folder
- My own task or dataset (give details below)
### Reproduction

This error occurs when calling `model.update_and_allocate(global_step)`.
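A minimal sketch of a reproduction, assuming a RoBERTa sequence-classification model; the model name, label count, and AdaLoRA hyperparameter values below are illustrative, not the exact config from my run:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import AdaLoraConfig, TaskType, get_peft_model

base = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
peft_config = AdaLoraConfig(
    task_type=TaskType.SEQ_CLS,
    init_r=12,
    target_r=8,
    tinit=200,
    tfinal=1000,
    deltaT=10,
    total_step=10000,
)
model = get_peft_model(base, peft_config)

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
batch = tokenizer(["a short example sentence"], return_tensors="pt")
outputs = model(**batch, labels=torch.tensor([0]))
outputs.loss.backward()

# The error surfaces here, inside update_ipt(): some parameters whose
# names contain "lora_" are frozen and therefore have p.grad == None.
model.update_and_allocate(global_step=1)
```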
### Expected behavior

I think a gradient check should be added to the `update_ipt` function:
```python
def update_ipt(self, model):
    # Update the sensitivity and uncertainty for every weight
    for n, p in model.named_parameters():
        # New: skip frozen parameters; some of them (e.g. in the
        # classification head) contain "lora_" in their names but have
        # requires_grad=False, so p.grad is None.
        if not p.requires_grad:
            continue
        if "lora_" in n and self.adapter_name in n:
            if n not in self.ipt:
                self.ipt[n] = torch.zeros_like(p)
                self.exp_avg_ipt[n] = torch.zeros_like(p)
                self.exp_avg_unc[n] = torch.zeros_like(p)
            with torch.no_grad():
                self.ipt[n] = (p * p.grad).abs().detach()
                # Sensitivity smoothing
                self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
                # Uncertainty quantification
                self.exp_avg_unc[n] = (
                    self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
                )
```
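An equivalent, slightly more defensive variant would be to skip any parameter whose gradient is absent (`if p.grad is None: continue`), which also covers trainable parameters that did not take part in the current backward pass.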