diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 3f90ac8..82e3b6a 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -275,6 +275,7 @@ def train( test_loader, is_test=True, test_result_save_path=None ) self.training_history[key]["test"] = [test_mae[key] for key in self.targets] + self.save(filename=os.path.join(save_dir, file)) def _train(self, train_loader: DataLoader, current_epoch: int) -> dict: """Train all data for one epoch.