
Fix DDP prediction and checkpoint Issues #884

Merged: 11 commits merged into chemprop:main from fix_ddp on Jun 5, 2024

Conversation

@shihchengli (Contributor) commented on May 24, 2024

Description

Addresses issue #874; please refer to the issue for details.

Relevant issues

#853 #874

Checklist

  • linted with flake8?
  • (if appropriate) unit tests added?

@shihchengli added this to the v2.0.1 milestone on May 24, 2024
@shihchengli linked an issue on May 24, 2024 that may be closed by this pull request
@shihchengli marked this pull request as ready for review on May 24, 2024 at 22:18
@JacksonBurns (Member) left a comment:


Some minor notes to be addressed.

chemprop/models/model.py (outdated; conversation resolved)
Comment on lines 870 to 881
torch.distributed.destroy_process_group()

best_ckpt_path = trainer.checkpoint_callback.best_model_path
trainer = pl.Trainer(
    logger=trainer_logger,
    enable_progress_bar=True,
    accelerator=args.accelerator,
    devices=1,
)
model = build_model(args, train_loader.dataset, output_transform, input_transforms)
model = model.load_from_checkpoint(best_ckpt_path)
predss = trainer.predict(model, dataloaders=test_loader)
Member:

Please check my understanding - we can train and validate in DDP, but we will always test in single-GPU mode? This seems dubious for large datasets/models. Why not have each process reload the best model and continue in DDP?

Contributor Author (@shihchengli):

Your understanding is correct. DDP uses a distributed sampler, which drops samples so that the batches divide evenly across the GPUs. That is acceptable during validation, where we only need an estimate of model performance, but for testing we should evaluate every data point. Even if the data points were split evenly across the processes, we would still need to save the predictions from each process, merge them, and recover the indices from the sampler so the results can be written to a file in the correct SMILES order. IMO, inference with the D-MPNN model is not expensive, so I think using a single GPU here is fine.
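
For context, a rough sketch (not part of this PR) of the bookkeeping that multi-GPU prediction would require, assuming each rank ends up with a tensor of predictions plus the dataset indices its DistributedSampler assigned to it; all names here are illustrative:

import torch
import torch.distributed as dist

def gather_ordered_predictions(rank_preds: torch.Tensor, rank_indices: list[int]) -> list[torch.Tensor]:
    """Collect per-rank predictions and restore the original dataset order.

    Assumes a process group has already been initialized (i.e. we are inside DDP).
    """
    world_size = dist.get_world_size()
    gathered: list = [None] * world_size
    # all_gather_object handles arbitrary picklable payloads, e.g. (indices, predictions) tuples
    dist.all_gather_object(gathered, (rank_indices, rank_preds.cpu()))
    pairs = []
    for indices, preds in gathered:
        pairs.extend(zip(indices, preds))
    # Sort by dataset index so the output lines up with the input SMILES order
    pairs.sort(key=lambda pair: pair[0])
    return [pred for _, pred in pairs]

Running prediction on a single device sidesteps all of this, which is the trade-off accepted above.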

Member:

I think Lightning does the opposite of what you are describing; see the note at the bottom of this section of their docs: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#testing

It says that samples will actually be duplicated when batches don't divide evenly across GPUs, and it even suggests running validate (and, I think, test as well) on only a single device to avoid this.

Member:

Regardless, I think what we have here is sound. My comment above might be worth thinking about in the future, but I don't think it's a big deal.

Member:

👍

@JacksonBurns (Member):

@shihchengli please update the branch; @KnathanM please take a quick look at this, and then I think it should be ready to merge.

@shihchengli force-pushed the fix_ddp branch 2 times, most recently from fc5f7eb to 43ee8e2 on May 31, 2024 at 00:42
Comment on lines 879 to 880
model = build_model(args, train_loader.dataset, output_transform, input_transforms)
model = model.load_from_checkpoint(best_ckpt_path)
Contributor (@KnathanM):

Suggested change:
-model = build_model(args, train_loader.dataset, output_transform, input_transforms)
-model = model.load_from_checkpoint(best_ckpt_path)
+model = model.load_from_checkpoint(best_ckpt_path)

I don't think you need to call build_model again. load_from_checkpoint is a class method, so you could even do MPNN.load_from_checkpoint() without an existing model (though then you'd have to check whether it is multicomponent). If you think model = model.load_from_checkpoint(best_ckpt_path) is unclear, you could consider model = model.__class__.load_from_checkpoint(best_ckpt_path).
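
To illustrate the class-method point, a minimal sketch, assuming chemprop v2 exposes MPNN and MulticomponentMPNN from chemprop.models; the checkpoint path and the is_multicomponent flag are placeholders, not from this PR:

from chemprop.models import MPNN, MulticomponentMPNN

best_ckpt_path = "checkpoints/best.ckpt"  # placeholder path for illustration
is_multicomponent = False  # placeholder: whether the run used the multicomponent model

# Option 1: call the Lightning classmethod on the appropriate class directly;
# this requires knowing whether the run used MulticomponentMPNN.
model_cls = MulticomponentMPNN if is_multicomponent else MPNN
model = model_cls.load_from_checkpoint(best_ckpt_path)

# Option 2: reuse the class of an existing instance, which avoids that check.
model = model.__class__.load_from_checkpoint(best_ckpt_path)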

Contributor Author (@shihchengli):

Good catch! I noticed that we don't actually need to reload the model weights. The models in the different DDP processes should all have the same weights, so we can just use the existing one.

Contributor (@KnathanM):

I think you do need to reload the model weights, because we want to test with the best model, not the last one. When we call predss = trainer.predict(dataloaders=test_loader), the trainer remembers which checkpoint is the best and uses it, but if we create a new trainer I don't think it has that information, and it will use the most recent model weights.
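
A brief sketch of that distinction (variable names follow the snippet quoted earlier in the thread; assumes Lightning ≥ 2.0, where Trainer.predict also accepts a ckpt_path argument):

# Without an explicit checkpoint, the fresh single-device trainer would run
# prediction with whatever weights the model currently holds (the last epoch).
# Either reload the best weights explicitly ...
model = model.__class__.load_from_checkpoint(best_ckpt_path)
predss = trainer.predict(model, dataloaders=test_loader)

# ... or point the new trainer at the best checkpoint directly.
predss = trainer.predict(model, dataloaders=test_loader, ckpt_path=best_ckpt_path)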

Contributor Author (@shihchengli):

@KnathanM You are right! I've just changed it back.

@KnathanM merged commit 9f755b0 into chemprop:main on Jun 5, 2024
13 checks passed
@shihchengli deleted the fix_ddp branch on June 5, 2024 at 20:23

Successfully merging this pull request may close these issues.

[v2 BUG]: LightningModule's DDP doesn't work