Sinusoidal Learning rate when increasing both max_epochs and batch size #107

Closed

JJrodny opened this issue May 17, 2024 · 10 comments
JJrodny commented May 17, 2024

I'd like to start this ticket off with: This repo is amazing! We're just starting to train models with it and your work is SotA and the weights are so tiny! Thank you for all of your hard work!

The current issue @gvoysey and I are working on is training for longer and with larger batch sizes.

Please correct my understanding if any of this is wrong.

To increase the on-GPU batch size, I go into the specific configs/experiment/semantic/*.yaml file I'm using and increase sample_graph_k:

datamodule:
  sample_graph_k: 4

As I understand it, increasing this number increases the NAG batch size loaded onto the GPU.

Additionally, I can also increase the batch size theoretically without overloading the GPU, by increasing the gradient_accumulator:

callbacks:
  gradient_accumulator:
    scheduling:
      0: 10

By setting both of these, I now have an effective batch size of 40: each batch of 4 is run on the GPU together, but we don't update the weights until we've accumulated gradients from 10 of these batches, so it allows us to train on larger batches (40) than our GPU can fit (4).

Now, I also want to train for longer.

In the comments for max_epochs there's a line:
# to keep same nb of steps: 25/9x more tiles, 2-step gradient accumulation -> epochs * 2 * 9 / 25
which implies if I train with a larger batch I need to reduce max_epochs, and if I train with a smaller batch, I increase max_epochs.

But if I want to train for more epochs while also increasing gradient accumulation or batch size, I can just increase max_epochs, right?

The problem we're running into is that while increasing max_epochs does allow us to train for longer, the learning rate (src.optim.CosineAnnealingLRWithWarmup) doesn't go down to zero at max_epochs. Instead it goes down and back up sinusoidally, first reaching zero at the max_epochs value I should have set according to the formula in the comment above, which accounts for the modified batch size.

This implies we should only train a model for a set number of iterations. Is that true? Can we change that?

In other words, how do we increase the max_epochs and keep our batch size large, while having the learning rate appropriately decrease to zero at the max_epochs value?

Here's an example of the wandb output in text format of the learning rate (for max_epochs set to 2000, and effective batch size of 40 as outlined above):

wandb: Run history:
wandb:                         epoch ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:                  lr-AdamW/pg1 ▇█▇▆▅▃▂▁▁▁▂▄▅▇███▇▆▄▃▂▁▁▂▃▄▅▇███▇▆▄▃▂▁▁▁
wandb:         lr-AdamW/pg1-momentum ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                  lr-AdamW/pg2 ▇█▇▆▅▃▂▁▁▁▂▄▅▇███▇▆▄▃▂▁▁▂▃▄▅▇███▇▆▄▃▂▁▁▁
wandb:         lr-AdamW/pg2-momentum ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

And attached is a graph of the learning rate:
[Screenshot from 2024-05-17 16-33-18: learning rate over training]

Any help or advice would be appreciated! Thank you so much Damien!

drprojects (Owner) commented:

Hi @JJrodny and @gvoysey, thanks for the well-detailed issue, as always 😉

> By setting both of these, I now have an effective batch size of 40: each batch of 4 is run on the GPU together, but we don't update the weights until we've accumulated gradients from 10 of these batches, so it allows us to train on larger batches (40) than our GPU can fit (4).

Yes, you are correct: your effective batch size will be sample_graph_k * gradient_accumulator.scheduling[0]. Effective in the sense that your model's weights will be updated every gradient_accumulator.scheduling[0] forward steps, based on the gradient accumulated across these steps. This trick allows you to train with larger batches even though your GPU can only hold a fraction of one at a time.

To be 100% clear, because someone just posted a related question in #108: your effective batch size (i.e. the number of subgraphs over which your weights are updated) will be batch_size * sample_graph_k * gradient_accumulator.scheduling[0], where the misleadingly-named batch_size actually rules how many tiles you load into memory, from each of which sample_graph_k subgraphs are sampled.
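
To make the arithmetic concrete, here is a small sketch with made-up values (the variable names simply mirror the config keys, the numbers are only an example):

# Illustrative values only, mirroring the config keys discussed above
batch_size = 1          # datamodule.batch_size: tiles loaded into memory per step
sample_graph_k = 4      # subgraphs sampled from each loaded tile
grad_accumulation = 10  # gradient_accumulator.scheduling[0]

# Subgraphs processed in a single forward pass (these must fit in GPU memory)
subgraphs_per_forward = batch_size * sample_graph_k               # 4

# Subgraphs contributing to each weight update: the "effective" batch size
effective_batch_size = subgraphs_per_forward * grad_accumulation  # 40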

For your information, another way of increasing the size of your batch is to increase the radius of the sampled subgraphs by playing with sample_graph_r. This may or may not make sense, depending on how big your GPU is (if you can't fit more in memory, don't do it) and how big your individual preprocessed tiles are (if the tiles are very large and diverse in content, you may use a larger sample_graph_k or a larger sample_graph_r without risking too much redundancy between the subgraphs). I do not have a fool-proof recipe for this.

OK, back to your actual issue now.

A consequence of gradient accumulation is that, if you want to train for as many iterations (i.e. model weight updates) as before, you will also need to adjust the number of epochs you train for. Put differently, if you x2 gradient_accumulator, then you need to x2 max_epochs to keep the number of training steps constant. There is no formal obligation to do so, it is just a constraint I had to respect to allow fair comparison between some of my experiments.
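
For instance, with made-up numbers (the steps per epoch below are purely illustrative, they depend on your dataset and sampling):

# Illustrative only: scaling max_epochs with gradient accumulation so that the
# total number of optimizer updates (training steps) stays constant
forward_steps_per_epoch = 100   # assumption: depends on your dataset
baseline_epochs = 200           # epochs without gradient accumulation
grad_accumulation = 10          # gradient_accumulator.scheduling[0]

baseline_updates = forward_steps_per_epoch * baseline_epochs     # 20000 updates
updates_per_epoch = forward_steps_per_epoch / grad_accumulation  # 10 updates/epoch
max_epochs = baseline_updates / updates_per_epoch                # 2000 = 200 * 10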

In doing so, I also remember coming across the same issue you have with src.optim.CosineAnnealingLRWithWarmup: setting max_epochs did not produce the expected results.

  • Based on my comment above: if you increased gradient_accumulator from 1 (the default behavior) to 10, can you please also try multiplying max_epochs by 10?

Please let me know if this helps! If not, I will try to reproduce this behavior. I know this part of the code is a bit shady 😅


JJrodny commented May 22, 2024

Thank you for the detailed reply!

We're still trying to figure it out, but after trying several combinations of values for sample_graph_k and gradient_accumulator.scheduling[0], it seems they have no effect on the shape of the learning rate curve.

At first I was looking at wandb graphs of the learning rate against the default x-axis (trainer/global_step), and those seemed to show that x10 gradient accumulation with x10 max_epochs creates more sinusoids; that seemed to suggest that if we x10 gradient accumulation we should x0.1 max_epochs. But interestingly, when looking at the learning rate over epochs (step in wandb), all of the models we trained with different variations of sample_graph_k and gradient_accumulator.scheduling[0] show the same learning rate curves.

That suggests to me that the lr_scheduler has some hard-coded maximum training length (maybe due to the nature of the cosine used in the function), so I'm exploring making some modifications here (although I may be wrong):

elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
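
To convince myself of the periodic behavior, here's a minimal standalone sketch using plain torch.optim.lr_scheduler.CosineAnnealingLR (not the repo's CosineAnnealingLRWithWarmup, and with toy values for T_max, the base lr, and the epoch count), which reproduces the same down-and-back-up pattern once training steps past T_max:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Toy parameter and optimizer, only there to drive the scheduler
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=0.001)

# T_max is where the cosine first reaches eta_min; stepping past it makes the
# learning rate climb back up, producing the sinusoid seen in the wandb plot
scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=0.0)

for epoch in range(200):  # 4 * T_max epochs -> two full oscillations
    optimizer.step()
    scheduler.step()
    print(epoch, scheduler.get_last_lr()[0])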

I have two questions that have come up while doing this:

  1. It looks like when I set model.scheduler.num_warmup to anything, the value I set is printed in the config dumped to the terminal, but during training the learning rate still just ramps up to 0.001 by epoch 20.

  2. Another question, ignoring the sinusoidal problem: we can set model.optimizer.lr and that sets lr-AdamW/pg2, but lr-AdamW/pg1 remains at 0.01. How do we globally change the learning rate across all model parameters?

My ultimate goal in this is to have the learning rate drop down to 0 at the end of training, no matter what we set the max number of epochs to.

Thanks for all of your help!


drprojects commented May 22, 2024

Hi @JJrodny I do not have access to any SPT-ready machine at the moment, will try to look into this by Friday.

drprojects (Owner) commented:

Hi @JJrodny, I have not forgotten about your issue, but I still haven't found time to look into it yet, sorry!

drprojects (Owner) commented:

Hi @JJrodny, I have not been able to reproduce your error. Are you training from scratch, or are you starting from a pre-trained checkpoint file? If the latter, then this is likely the source of the problem.


gvoysey commented May 28, 2024 via email

drprojects (Owner) commented:

OK, so that's where the problem comes from. I do not support fine-tuning yet; I will need to work on this!

In the meantime, if your dataset is large enough, I would advise just training from scratch. If not, then you will need to tweak the optimizer and scheduler to be well behaved.

I know this is a feature you need, I will try to make time to work on this soon, sorry for the delay 😖

gardiens commented:

yep! we’re starting with dales pretrained and fine tuning on top.

I was right from the start: I don't know exactly how you are loading your checkpoint, but you are probably doing something like this:

# Instantiate a fresh model from the config, then load weights from the checkpoint
model = hydra.utils.instantiate(cfg.model)
model = type(model).load_from_checkpoint(path_ckpt, net=model.net, criterion=model.criterion)

To boil it down, you have to load the checkpoint AND override its parameters with your own. By default, when you load a model with PyTorch Lightning, you are loading the model AND its optimizer, so if you don't override those args, your optimizer, and especially the cosine annealing schedule, will be tuned for the wrong number of epochs.

What solved the problem for me was to call load_from_checkpoint with all the args provided in the model config, so that it overrides the optimizer parameters.
I can't share my code because it is too dirty :(

@JJrodny
Copy link
Author

JJrodny commented May 29, 2024

Thank you @drprojects, @gardiens (and @gvoysey)! That was exactly the problem.

We're training from a pretrained model using load_from_checkpoint, and that was the cause. We were passing the criterion parameter, but that doesn't update the optimizer or scheduler we set in the yaml: if we don't pass new ones in when loading from the checkpoint, the pretrained checkpoint's optimizer and scheduler are used instead.

It took quite some fiddling to figure out which optimizer and scheduler to pass in, but after learning some Hydra it turned out I just needed to instantiate the new scheduler and optimizer from the config with Hydra and pass them in:

# The number of classes may differ from the pretrained checkpoint's dataset
new_num_classes = datamodule.hparams.num_classes

# strict=False tolerates the mismatched classification head, and the freshly
# instantiated optimizer/scheduler override the ones stored in the checkpoint
model = type(model).load_from_checkpoint(cfg.pretrained_weights,
                                         net=model.net,
                                         criterion=model.criterion,
                                         num_classes=new_num_classes,
                                         strict=False,
                                         optimizer=hydra.utils.instantiate(cfg.model.optimizer),
                                         scheduler=hydra.utils.instantiate(cfg.model.scheduler))

@drprojects I'm so sorry to have wasted your time and made you boot up the code only for the problem not to be reproducible on your machine; I've been there and I completely understand how that feels!

Thank you all for your help! Problem solved!

JJrodny closed this as completed May 29, 2024
drprojects (Owner) commented:

Happy that you found a workaround, and thanks for the detailed feedback!
