Skip to content

Commit

Permalink
Merge pull request #225 from kai422/main
Browse files Browse the repository at this point in the history
correct ish_reshaper
  • Loading branch information
zjysteven committed Feb 6, 2024
2 parents 2010c77 + 5ddf27f commit 18c6f51
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions openood/trainers/ish_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,32 @@ def forward(
x: torch.Tensor,
weight: nn.Parameter,
bias: nn.Parameter,
dropiter
ish_reshaper
):
ctx.dropiter = dropiter
ctx.ish_reshaper = ish_reshaper
ctx.x_shape = x.shape
ctx.has_bias = bias is not None
ctx.save_for_backward(dropiter.select(x, ctx), weight)
ctx.save_for_backward(ish_reshaper.select(x, ctx), weight)
return F.linear(x, weight, bias)

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
x, weight = ctx.saved_tensors
grad_bias = torch.sum(grad_output, list(range(grad_output.dim()-1))) if ctx.has_bias else None
ic, oc = weight.shape
x = ctx.dropiter.pad(x, ctx)
x = ctx.ish_reshaper.pad(x, ctx)
grad_weight = grad_output.view(-1,ic).T.mm(x.view(-1,oc))
grad_input = torch.matmul(grad_output, weight, out=x.view(ctx.x_shape))
return grad_input, grad_weight, grad_bias, None

linear_forward = _ISHTLinear.apply
_linear_forward = _ISHTLinear.apply

def linear_forward(self, x):
if self.training:
x = _linear_forward(x, self.weight, self.bias, self.ish_reshaper)
else:
x = F.linear(x, self.weight, self.bias)
return x

supports = {
nn.Linear: linear_forward,
Expand Down Expand Up @@ -152,7 +158,6 @@ def cache_minksample_expscale(self, x: torch.Tensor, ctx=None):
return x

def load_minksample_expscale(self, x, ctx=None):
print(ctx.idxs.shape, x.shape, )
return torch.zeros(
ctx.shape, device=x.device, dtype=x.dtype
).scatter_(1, ctx.idxs, x)
Expand Down Expand Up @@ -217,7 +222,6 @@ def load_minksample_lnscale(self, x, ctx=None):
def transfer(model, strategy, gamma, autocast):
_type = type(model)
ish_reshaper = ISHReshaper(strategy, gamma)
ish_reshaper.autocast = autocast # just for recording
model.forward = partial(supports[_type], model)
model.ish_reshaper = ish_reshaper
print(f"{_type}.forward => ish.{strategy}.{_type}.forward")
Expand Down

0 comments on commit 18c6f51

Please sign in to comment.