✨ Make LR scheduling more robust in the Trainer
arxyzan committed Jun 13, 2024
1 parent 22ea5c6 commit ee80b1c
Showing 4 changed files with 29 additions and 18 deletions.
4 changes: 3 additions & 1 deletion docs/tutorial/training/trainer.md
@@ -95,6 +95,7 @@ Let's explore all the available parameters:
 - **weight_decay** (float): Optimizer weight decay value.
 - **lr_scheduler** (LRSchedulerType): Optional learning rate scheduler among the `LRSchedulerType` enum.
 - **lr_scheduler_kwargs** (Dict[str, Any]): LR scheduler constructor kwargs depending on the scheduler type.
+- **lr_scheduling_steps** (int): Number of training steps between scheduler steps. If left as None, defaults to the number of steps in one full epoch.
 - **batch_size** (int): Training batch size.
 - **eval_batch_size** (int): Evaluation batch size, defaults to `batch_size` if None.
 - **gradient_accumulation_steps** (int): Number of update steps to accumulate before performing a backward/update pass,
@@ -242,7 +243,8 @@ the number of batches present in the data loader:
    does not happen here since it has its own method.
 3. Optimization step (`optimization_step`): Does the optimizer stepping and zeros gradients afterward. (Gradient
    accumulation is handled by the accelerator)
-4. Update loss tracker and the trainer states.
+4. LR scheduling: Depending on `lr_scheduling_steps`, perform one step of LR scheduling.
+5. Update loss tracker and the trainer states.
 5. Update and show the loss moving average in the progress bar.
 6. Perform saving and logging according to `save_steps` and `log_steps`.
 7. Return average loss up until now. (This value is accumulated and averaged since the beginning of the whole training
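For context, here is a minimal sketch of how the new option could be set (illustrative values; a real `TrainerConfig` likely requires more fields than this diff shows):

```python
from hezar.configs import TrainerConfig

# Hypothetical values for illustration; only the LR-related fields matter here.
config = TrainerConfig(
    batch_size=8,
    lr_scheduler="step",  # assumed string form of LRSchedulerType.STEP
    lr_scheduler_kwargs={"step_size": 1, "gamma": 0.9},  # forwarded to StepLR(...)
    lr_scheduling_steps=100,  # step the scheduler every 100 global steps
)
```

With `lr_scheduling_steps` left unset, the scheduler steps once per full epoch, which matches the previous epoch-based behavior.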
3 changes: 3 additions & 0 deletions hezar/configs.py
@@ -422,6 +422,8 @@ class TrainerConfig(Config):
             Optional learning rate scheduler among the `LRSchedulerType` enum.
         lr_scheduler_kwargs (Dict[str, Any]):
             LR scheduler constructor kwargs depending on the scheduler type.
+        lr_scheduling_steps (int):
+            Number of training steps between scheduler steps. If left as None, defaults to the number of steps in one full epoch.
         batch_size (int):
             Training batch size.
         eval_batch_size (int):
@@ -473,6 +475,7 @@ class TrainerConfig(Config):
     weight_decay: float = 0.0
     lr_scheduler: str | LRSchedulerType = None
     lr_scheduler_kwargs: Dict[str, Any] = None
+    lr_scheduling_steps: int = None
     batch_size: int = None
     eval_batch_size: int = None
     gradient_accumulation_steps: int = 1
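The `None` default above is resolved in `Trainer.__init__` (see trainer.py below) with a plain `or` fallback. A quick illustration, assuming `steps_in_epoch` holds the number of batches in one epoch:

```python
steps_in_epoch = 500        # illustrative: batches per epoch
lr_scheduling_steps = None  # the config default

# Mirrors `self.config.lr_scheduling_steps or self.steps_in_epoch`
effective = lr_scheduling_steps or steps_in_epoch
assert effective == 500     # None falls back to one full epoch

lr_scheduling_steps = 100
effective = lr_scheduling_steps or steps_in_epoch
assert effective == 100     # an explicit interval wins
```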
34 changes: 17 additions & 17 deletions hezar/trainer/trainer.py
@@ -55,6 +55,7 @@
     get_distributed_logger,
     resolve_logdir,
     write_to_tensorboard,
+    get_lr_scheduler_type,
 )


@@ -72,6 +73,7 @@
 }
 lr_schedulers = {
     LRSchedulerType.LAMBDA: torch.optim.lr_scheduler.LambdaLR,
+    LRSchedulerType.REDUCE_ON_PLATEAU: torch.optim.lr_scheduler.ReduceLROnPlateau,
     LRSchedulerType.STEP: torch.optim.lr_scheduler.StepLR,
     LRSchedulerType.MULTI_STEP: torch.optim.lr_scheduler.MultiStepLR,
     LRSchedulerType.ONE_CYCLE: torch.optim.lr_scheduler.OneCycleLR,
@@ -186,7 +188,10 @@ def __init__(

         # Setup optimizer and (optionally) lr scheduler
         self.optimizer, self.lr_scheduler = self._create_optimizers(optimizer, lr_scheduler)
+        self.lr_scheduling_steps = self.config.lr_scheduling_steps or self.steps_in_epoch
+        self.lr_scheduler_type = get_lr_scheduler_type(self.lr_scheduler, lr_schedulers)
 
+        # Move objects to the accelerator
         self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
             self.model, self.optimizer, self.lr_scheduler
         )
@@ -464,21 +469,6 @@ def optimization_step(self):
         self.optimizer.step()
         self.optimizer.zero_grad()
 
-    def lr_scheduler_step(self, metrics=None):
-        """
-        Perform the learning rate scheduling step
-
-        Args:
-            metrics: one or multiple values that the scheduler watches to either perform step function or not. Only
-                works for `ReduceLROnPlateau`.
-        """
-        if self.lr_scheduler is not None:
-            if isinstance(self.lr_scheduler, lr_schedulers[LRSchedulerType.REDUCE_ON_PLATEAU]):
-                if metrics:
-                    self.lr_scheduler.step(metrics)
-            else:
-                self.lr_scheduler.step()
-
     def training_step(self, input_batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
         """
         Train one batch of data and return loss and model outputs
@@ -582,6 +572,14 @@ def inner_training_loop(self, epoch_num: int):
                 self.state.loss_tracker_sum = self.train_loss_tracker.sum
                 accumulated_loss = 0
 
+            # Scheduler step
+            if (
+                self.lr_scheduler is not None and
+                self.state.global_step % self.lr_scheduling_steps == 0 and
+                self.lr_scheduler_type != LRSchedulerType.REDUCE_ON_PLATEAU
+            ):
+                self.lr_scheduler.step()
+
             # Save trainer outputs if `save_steps` is hit
             if self.config.save_steps and self.state.global_step % self.config.save_steps == 0:
                 ckpt_path_name = str(self.state.global_step).zfill(len(str(self.total_steps)))
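To make the modulus check above concrete, here is a toy trace of which global steps trigger a scheduler step (values are illustrative):

```python
steps_in_epoch = 4
lr_scheduling_steps = 2  # half an epoch; defaults to steps_in_epoch when unset

fired = [
    step
    for step in range(1, 2 * steps_in_epoch + 1)  # two epochs of global steps
    if step % lr_scheduling_steps == 0
]
assert fired == [2, 4, 6, 8]  # twice per epoch instead of once at each epoch end
```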
@@ -593,6 +591,7 @@ def inner_training_loop(self, epoch_num: int):
                         self.trainer_state_file,
                     )
                 )
+
             # Log loss running mean
             if self.config.log_steps and self.state.global_step % self.config.log_steps == 0:
                 loss_mean = {"train.loss": self.train_loss_tracker.avg}
@@ -733,8 +732,9 @@ def train(self, resume_from_checkpoint="deprecated"):
             }
             metrics_logs.update(evaluation_logs)
 
-            # LR scheduler step
-            self.lr_scheduler_step(metrics_logs[self.config.metric_for_best_model])
+            # LR scheduler step (only for ReduceLROnPlateau)
+            if self.lr_scheduler_type == LRSchedulerType.REDUCE_ON_PLATEAU:
+                self.lr_scheduler.step(metrics_logs[self.config.metric_for_best_model])
 
             # Update trainer state
             self.state.epoch = epoch
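The split above exists because `ReduceLROnPlateau.step()` requires the metric it watches, while the other schedulers step unconditionally. A self-contained PyTorch illustration with toy optimizers and values:

```python
import torch

param = torch.zeros(1, requires_grad=True)

# Interval-based scheduler: steps on a schedule, no metric needed.
opt1 = torch.optim.SGD([param], lr=0.1)
step_sched = torch.optim.lr_scheduler.StepLR(opt1, step_size=2, gamma=0.5)
step_sched.step()  # fine without arguments

# Metric-based scheduler: must be fed the watched value (e.g. eval loss).
opt2 = torch.optim.SGD([param], lr=0.1)
plateau_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt2, mode="min", patience=1)
plateau_sched.step(0.42)  # step() without a metric raises TypeError
```

This is why the Trainer now steps `ReduceLROnPlateau` once per epoch in `train()`, where evaluation metrics are available, and steps every other scheduler type on the `lr_scheduling_steps` interval inside the inner loop.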
6 changes: 6 additions & 0 deletions hezar/trainer/trainer_utils.py
@@ -15,6 +15,7 @@
     "write_to_tensorboard",
     "resolve_logdir",
     "get_distributed_logger",
+    "get_lr_scheduler_type",
 ]


@@ -188,3 +189,8 @@ def get_distributed_logger(name: str, level: str = None, fmt: str = None):
         logger.logger.addHandler(handler)
 
     return logger
+
+def get_lr_scheduler_type(lr_scheduler, schedulers_mapping: dict):
+    for name, scheduler_cls in schedulers_mapping.items():
+        if isinstance(lr_scheduler, scheduler_cls):
+            return name
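A quick usage sketch for the new helper (the import locations of `LRSchedulerType` and the `lr_schedulers` mapping are assumptions based on the files in this diff). Note that it implicitly returns `None` when nothing matches, which is how an unset scheduler falls through the type checks in the Trainer:

```python
import torch

from hezar.trainer.trainer import LRSchedulerType, lr_schedulers  # assumed importable here
from hezar.trainer.trainer_utils import get_lr_scheduler_type

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)

assert get_lr_scheduler_type(sched, lr_schedulers) == LRSchedulerType.REDUCE_ON_PLATEAU
assert get_lr_scheduler_type(None, lr_schedulers) is None  # no scheduler configured
```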
