Merge pull request #223 from zjysteven/main
Add SyncBN and DDP sampler's set_epoch
zjysteven committed Feb 5, 2024
2 parents e6a1b19 + aa57526 commit 2010c77
Showing 3 changed files with 20 additions and 2 deletions.
8 changes: 7 additions & 1 deletion openood/pipelines/train_oe_pipeline.py
@@ -30,6 +30,8 @@ def run(self):

        # init network
        net = get_network(self.config.network)
        if self.config.num_gpus * self.config.num_machines > 1:
            net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

        # init trainer and evaluator
        trainer = get_trainer(net, [train_loader, train_oe_loader], None,
@@ -39,9 +41,13 @@ def run(self):
        if comm.is_main_process():
            # init recorder
            recorder = get_recorder(self.config)

        print('Start training...', flush=True)

        for epoch_idx in range(1, self.config.optimizer.num_epochs + 1):
            if isinstance(train_loader.sampler,
                          torch.utils.data.distributed.DistributedSampler):
                train_loader.sampler.set_epoch(epoch_idx - 1)

            # train and eval the model
            net, train_metrics = trainer.train_epoch(epoch_idx)
            val_metrics = evaluator.eval_acc(net, val_loader, None, epoch_idx)
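For context, torch.nn.SyncBatchNorm.convert_sync_batchnorm replaces every BatchNorm layer in the model with SyncBatchNorm, so batch statistics are computed across all processes instead of per GPU; it has to run before the model is wrapped in DistributedDataParallel (in OpenOOD that wrapping presumably happens downstream, e.g. inside the trainer). Below is a minimal sketch of the usual pattern, not taken from the repo: it assumes torch.distributed has already been initialized and uses a torchvision model as a stand-in for get_network.

import torch
import torch.distributed as dist
import torchvision


def build_ddp_model(local_rank: int) -> torch.nn.Module:
    # Stand-in for get_network(self.config.network).
    net = torchvision.models.resnet18(num_classes=10).cuda(local_rank)
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # Same call as in the diff: swap BatchNorm*d layers for SyncBatchNorm
        # so running statistics are synchronized across all ranks.
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
        # Wrapping in DistributedDataParallel must come after the conversion.
        net = torch.nn.parallel.DistributedDataParallel(net,
                                                        device_ids=[local_rank])
    return net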
6 changes: 6 additions & 0 deletions openood/pipelines/train_opengan_pipeline.py
@@ -32,6 +32,8 @@ def run(self):

        # init network
        net = get_network(self.config.network)
        if self.config.num_gpus * self.config.num_machines > 1:
            net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

        # init trainer
        trainer = get_trainer(net, dataloaders['id_train'],
@@ -46,6 +48,10 @@ def run(self):

        print('Start training...', flush=True)
        for epoch_idx in range(1, self.config.optimizer.num_epochs + 1):
            if isinstance(dataloaders['id_train'].sampler,
                          torch.utils.data.distributed.DistributedSampler):
                dataloaders['id_train'].sampler.set_epoch(epoch_idx - 1)

            # train the model
            net, train_metrics = trainer.train_epoch(epoch_idx)
            val_metrics = evaluator.eval_ood_val(net, id_loaders, ood_loaders,
8 changes: 7 additions & 1 deletion openood/pipelines/train_pipeline.py
@@ -29,6 +29,8 @@ def run(self):

        # init network
        net = get_network(self.config.network)
        if self.config.num_gpus * self.config.num_machines > 1:
            net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

        # init trainer and evaluator
        trainer = get_trainer(net, train_loader, val_loader, self.config)
@@ -37,9 +39,13 @@ def run(self):
        if comm.is_main_process():
            # init recorder
            recorder = get_recorder(self.config)

        print('Start training...', flush=True)

        for epoch_idx in range(1, self.config.optimizer.num_epochs + 1):
            if isinstance(train_loader.sampler,
                          torch.utils.data.distributed.DistributedSampler):
                train_loader.sampler.set_epoch(epoch_idx - 1)

            # train and eval the model
            if self.config.trainer.name == 'mos':
                net, train_metrics, num_groups, group_slices = \
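The sampler change addresses shuffling: DistributedSampler derives its per-epoch permutation from its seed plus the epoch set via set_epoch, so without this call every epoch (and every rank) reuses the same ordering. A minimal sketch of the pattern, outside the repo, assuming torch.distributed is already initialized; the toy dataset, loader, and num_epochs are made up for the example.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset; DistributedSampler queries the world size and rank from
# torch.distributed unless num_replicas/rank are passed explicitly.
dataset = TensorDataset(torch.arange(1000, dtype=torch.float32).unsqueeze(1))
sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

num_epochs = 100  # stand-in for self.config.optimizer.num_epochs
for epoch_idx in range(1, num_epochs + 1):
    # Mirrors the diff: OpenOOD counts epochs from 1, while the sampler
    # expects a 0-based epoch, hence epoch_idx - 1.
    if isinstance(loader.sampler, DistributedSampler):
        loader.sampler.set_epoch(epoch_idx - 1)
    for (batch,) in loader:
        pass  # one training step per batch would go here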
