Commit aa57526
Merge branch 'main' of github.com:zjysteven/OpenOOD into main
zjysteven committed Feb 5, 2024
2 parents 847b189 + e6a1b19 commit aa57526
Showing 15 changed files with 544 additions and 5 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -238,6 +238,7 @@ distance: f4d5b3 -->
> - [x] [![gen](https://img.shields.io/badge/CVPR'23-GEN-fdd7e6?style=for-the-badge)](https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.pdf)    ![postprocess]
> - [x] [![nnguide](https://img.shields.io/badge/ICCV'23-NNGuide-fdd7e6?style=for-the-badge)](https://arxiv.org/abs/2309.14888)    ![postprocess]
> - [x] [![relation](https://img.shields.io/badge/NEURIPS'23-Relation-fdd7e6?style=for-the-badge)](https://arxiv.org/abs/2301.12321)    ![postprocess]
> - [x] [![scale](https://img.shields.io/badge/ICLR'24-Scale-fdd7e6?style=for-the-badge)](https://github.com/kai422/SCALE)    ![postprocess]
> Training Methods (6):
> - [x] [![confbranch](https://img.shields.io/badge/arXiv'18-ConfBranch-fdd7e6?style=for-the-badge)](https://github.com/uoguelph-mlrg/confidence_estimation)    ![preprocess]   ![training]
@@ -251,6 +252,7 @@ distance: f4d5b3 -->
> - [x] [![cider](https://img.shields.io/badge/ICLR'23-CIDER-f4d5b3?style=for-the-badge)](https://github.com/deeplearning-wisc/cider)    ![training]   ![postprocess]
> - [x] [![npos](https://img.shields.io/badge/ICLR'23-NPOS-f4d5b3?style=for-the-badge)](https://github.com/deeplearning-wisc/npos)    ![training]   ![postprocess]
> - [x] [![t2fnorm](https://img.shields.io/badge/arXiv'23-T2FNorm-f4d5b3?style=for-the-badge)](https://arxiv.org/abs/2305.17797)    ![training]
> - [x] [![ish](https://img.shields.io/badge/ICLR'24-ish-fdd7e6?style=for-the-badge)](https://github.com/kai422/SCALE)    ![training]

> Training With Extra Data (3):
44 changes: 44 additions & 0 deletions configs/pipelines/train/train_ish.yml
@@ -0,0 +1,44 @@
exp_name: "'@{dataset.name}'_'@{network.name}'_'@{trainer.name}'_e'@{optimizer.num_epochs}'_lr'@{optimizer.lr}'_param'@{trainer.trainer_args.param}'_bs_'@{dataset.train.batch_size}'/s'@{seed}'"

output_dir: ./results/
save_output: True
merge_option: default
seed: 0

num_gpus: 1
num_workers: 8
num_machines: 1
machine_rank: 0

preprocessor:
  name: base

pipeline:
  name: train

trainer:
  name: ish
  trainer_args:
    mode: minksample_expscale
    param: 0.85
    layer: r1

evaluator:
  name: base

optimizer:
  name: sgd
  num_epochs: 100
  lr: 0.1
  momentum: 0.9
  weight_decay: 0.0005
  weight_decay_fc: 0.00005
  nesterov: True
  nesterov_fc: True

recorder:
  name: base
  save_all_models: False
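
The ish trainer named above lives under openood/trainers and is not part of this diff. Purely to illustrate the "expscale" part of the mode (amplifying an intermediate tensor by exp(s1 / s2), where s1 is its full sum and s2 the sum of only its largest activations, mirroring the post-hoc scale() added in openood/networks/scale_net.py later in this commit), here is a minimal, hypothetical sketch; the sample selection implied by "minksample" and the exact meaning of layer: r1 are not reproduced here.

import numpy as np
import torch


def expscale_sketch(feat: torch.Tensor, param: float = 0.85) -> torch.Tensor:
    """Hypothetical illustration only: exponentially scale a (B, C, H, W)
    intermediate tensor by exp(s1 / s2), where s2 sums just the largest
    (1 - param) fraction of activations per sample."""
    b = feat.size(0)
    n = feat[0].numel()
    s1 = feat.sum(dim=[1, 2, 3])                      # full per-sample sum
    k = n - int(np.round(n * param))                  # number of activations kept for s2
    topk_vals, _ = feat.view(b, -1).topk(k, dim=1)
    s2 = topk_vals.sum(dim=1)                         # sum after virtual pruning
    factor = torch.exp((s1 / s2)[:, None, None, None])
    return feat * factor                              # all activations kept, just amplified


# e.g. on ReLU features: scaled = expscale_sketch(torch.relu(torch.randn(8, 64, 8, 8)), 0.85)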
7 changes: 7 additions & 0 deletions configs/postprocessors/scale.yml
@@ -0,0 +1,7 @@
postprocessor:
  name: scale
  APS_mode: True
  postprocessor_args:
    percentile: 85
  postprocessor_sweep:
    percentile_list: [65, 70, 75, 80, 85, 90, 95]
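
With APS_mode: True the percentile is not fixed at 85 but picked automatically by sweeping percentile_list on validation data. The real search is handled by OpenOOD's hyperparameter-search machinery; the sketch below only illustrates the idea, assuming a postprocessor/net pair, a CUDA device, and ID/OOD validation loaders that yield dict batches with a 'data' field the way OpenOOD's loaders do.

import torch
from sklearn.metrics import roc_auc_score


def sweep_percentile(postprocessor, net, id_val_loader, ood_val_loader,
                     percentile_list=(65, 70, 75, 80, 85, 90, 95)):
    """Illustrative sweep (not OpenOOD's APS code): pick the percentile whose
    energy scores best separate ID from OOD validation data by AUROC."""
    best_p, best_auroc = None, -1.0
    for p in percentile_list:
        postprocessor.set_hyperparam([p])
        scores, labels = [], []
        for loader, is_id in ((id_val_loader, 1), (ood_val_loader, 0)):
            for batch in loader:
                data = batch['data'].cuda()
                _, conf = postprocessor.postprocess(net, data)
                scores.append(conf.cpu())
                labels.append(torch.full((conf.numel(),), is_id))
        auroc = roc_auc_score(torch.cat(labels).numpy(), torch.cat(scores).numpy())
        if auroc > best_auroc:
            best_p, best_auroc = p, auroc
    postprocessor.set_hyperparam([best_p])
    return best_p, best_auroc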
3 changes: 3 additions & 0 deletions openood/evaluation_api/evaluator.py
@@ -12,6 +12,7 @@
from openood.postprocessors import BasePostprocessor
from openood.networks.ash_net import ASHNet
from openood.networks.react_net import ReactNet
from openood.networks.scale_net import ScaleNet

from .datasets import DATA_INFO, data_setup, get_id_ood_dataloader
from .postprocessor import get_postprocessor
@@ -113,6 +114,8 @@ def __init__(
            net = ReactNet(net)
        elif postprocessor_name == 'ash':
            net = ASHNet(net)
        elif postprocessor_name == 'scale':
            net = ScaleNet(net)

        # postprocessor setup
        postprocessor.setup(net, dataloader_dict['id'], dataloader_dict['ood'])
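
With this wiring, asking the evaluation API for the scale postprocessor is enough to get the ScaleNet wrapping. A rough usage sketch; the checkpoint path and dataset name are placeholders, and the keyword arguments follow the Evaluator signature as of this commit:

import torch
from openood.evaluation_api import Evaluator
from openood.networks import ResNet18_32x32

net = ResNet18_32x32(num_classes=10)
net.load_state_dict(torch.load('./path/to/cifar10_checkpoint.ckpt'))  # placeholder path
net.cuda().eval()

evaluator = Evaluator(
    net,
    id_name='cifar10',            # ID dataset key in DATA_INFO
    data_root='./data',
    config_root='./configs',
    preprocessor=None,            # fall back to the default preprocessor
    postprocessor_name='scale',   # triggers the ScaleNet wrapping above
    batch_size=200,
)
metrics = evaluator.eval_ood(fsood=False)
print(metrics)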
9 changes: 5 additions & 4 deletions openood/evaluation_api/postprocessor.py
@@ -8,10 +8,10 @@
    GradNormPostprocessor, GRAMPostprocessor, KLMatchingPostprocessor, KNNPostprocessor,
    MaxLogitPostprocessor, MCDPostprocessor, MDSPostprocessor, MDSEnsemblePostprocessor,
    MOSPostprocessor, ODINPostprocessor, OpenGanPostprocessor, OpenMax, PatchcorePostprocessor,
    Rd4adPostprocessor, ReactPostprocessor, ResidualPostprocessor, SSDPostprocessor,
    TemperatureScalingPostprocessor, VIMPostprocessor, RotPredPostprocessor, RankFeatPostprocessor,
    RMDSPostprocessor, SHEPostprocessor, CIDERPostprocessor, NPOSPostprocessor, GENPostprocessor,
    NNGuidePostprocessor, RelationPostprocessor)
    Rd4adPostprocessor, ReactPostprocessor, ResidualPostprocessor, ScalePostprocessor,
    SSDPostprocessor, TemperatureScalingPostprocessor, VIMPostprocessor, RotPredPostprocessor,
    RankFeatPostprocessor, RMDSPostprocessor, SHEPostprocessor, CIDERPostprocessor, NPOSPostprocessor,
    GENPostprocessor, NNGuidePostprocessor, RelationPostprocessor)
from openood.utils.config import Config, merge_configs

postprocessors = {
@@ -48,6 +48,7 @@
    'opengan': OpenGanPostprocessor,
    'knn': KNNPostprocessor,
    'dice': DICEPostprocessor,
    'scale': ScalePostprocessor,
    'ssd': SSDPostprocessor,
    'she': SHEPostprocessor,
    'rd4ad': Rd4adPostprocessor,
49 changes: 49 additions & 0 deletions openood/networks/scale_net.py
@@ -0,0 +1,49 @@
import numpy as np
import torch
import torch.nn as nn


class ScaleNet(nn.Module):
    def __init__(self, backbone):
        super(ScaleNet, self).__init__()
        self.backbone = backbone

    def forward(self, x, return_feature=False, return_feature_list=False):
        try:
            return self.backbone(x, return_feature, return_feature_list)
        except TypeError:
            return self.backbone(x, return_feature)

    def forward_threshold(self, x, percentile):
        _, feature = self.backbone(x, return_feature=True)
        feature = scale(feature.view(feature.size(0), -1, 1, 1), percentile)
        feature = feature.view(feature.size(0), -1)
        logits_cls = self.backbone.get_fc_layer()(feature)
        return logits_cls

    def get_fc(self):
        fc = self.backbone.fc
        return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()


def scale(x, percentile=65):
    input = x.clone()
    assert x.dim() == 4
    assert 0 <= percentile <= 100
    b, c, h, w = x.shape

    # calculate the sum of the input per sample
    s1 = x.sum(dim=[1, 2, 3])
    n = x.shape[1:].numel()
    k = n - int(np.round(n * percentile / 100.0))
    t = x.view((b, c * h * w))
    v, i = torch.topk(t, k, dim=1)
    t.zero_().scatter_(dim=1, index=i, src=v)

    # calculate new sum of the input per sample after pruning
    s2 = x.sum(dim=[1, 2, 3])

    # apply sharpening
    scale = s1 / s2

    return input * torch.exp(scale[:, None, None, None])
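
Note that t = x.view((b, c * h * w)) shares storage with x, so the in-place zero_().scatter_() prunes x itself; s2 is therefore the post-pruning sum, while the cloned input keeps the original activations that are finally multiplied by exp(s1 / s2). A quick self-contained sanity check on non-negative features:

import torch
from openood.networks.scale_net import scale

feat = torch.relu(torch.randn(4, 512, 1, 1))             # shaped like pooled ResNet features
out = scale(feat.clone(), percentile=85)                 # clone, since scale() prunes its argument in place

assert out.shape == feat.shape
assert torch.all(out >= feat)                            # s2 <= s1, so exp(s1 / s2) >= e > 1
print(out.sum(dim=[1, 2, 3]) / feat.sum(dim=[1, 2, 3]))  # per-sample amplification factors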
1 change: 1 addition & 0 deletions openood/postprocessors/__init__.py
@@ -30,6 +30,7 @@
from .react_postprocessor import ReactPostprocessor
from .rmds_postprocessor import RMDSPostprocessor
from .residual_postprocessor import ResidualPostprocessor
from .scale_postprocessor import ScalePostprocessor
from .ssd_postprocessor import SSDPostprocessor
from .she_postprocessor import SHEPostprocessor
from .temp_scaling_postprocessor import TemperatureScalingPostprocessor
29 changes: 29 additions & 0 deletions openood/postprocessors/scale_postprocessor.py
@@ -0,0 +1,29 @@
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_postprocessor import BasePostprocessor


class ScalePostprocessor(BasePostprocessor):
    def __init__(self, config):
        super(ScalePostprocessor, self).__init__(config)
        self.args = self.config.postprocessor.postprocessor_args
        self.percentile = self.args.percentile
        self.args_dict = self.config.postprocessor.postprocessor_sweep

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        output = net.forward_threshold(data, self.percentile)
        _, pred = torch.max(output, dim=1)
        energyconf = torch.logsumexp(output.data.cpu(), dim=1)
        return pred, energyconf

    def set_hyperparam(self, hyperparam: list):
        self.percentile = hyperparam[0]

    def get_hyperparam(self):
        return self.percentile
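
postprocess expects the network to expose forward_threshold, which is exactly what the ScaleNet wrapper above provides, and scores each sample by the energy (logsumexp) of the scaled logits. Below is a toy end-to-end sketch with a stand-in backbone implementing the interface ScaleNet relies on (forward(..., return_feature=True) and get_fc_layer()); the backbone and the config path are placeholders, not library code:

import torch
import torch.nn as nn

from openood.networks.scale_net import ScaleNet
from openood.postprocessors import ScalePostprocessor
from openood.utils.config import Config


class ToyBackbone(nn.Module):
    def __init__(self, in_dim=32, feat_dim=64, num_classes=10):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.ReLU())
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x, return_feature=False):
        feat = self.encoder(x)
        logits = self.fc(feat)
        return (logits, feat) if return_feature else logits

    def get_fc_layer(self):
        return self.fc


config = Config('configs/postprocessors/scale.yml')   # the file added above
net = ScaleNet(ToyBackbone()).eval()
postprocessor = ScalePostprocessor(config)

x = torch.randn(8, 32)
pred, energy_conf = postprocessor.postprocess(net, x)
print(pred.shape, energy_conf.shape)                   # torch.Size([8]) torch.Size([8])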