-
Notifications
You must be signed in to change notification settings - Fork 0
/
learning_rate.py
96 lines (82 loc) · 3.67 KB
/
learning_rate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import math
from bisect import bisect_right
from functools import partial
from torch.optim import Optimizer
class _LRScheduler(object):
    """Base class for schedulers that rewrite an optimizer's learning rates.

    Subclasses implement :meth:`get_lr`, returning one learning rate per
    parameter group; :meth:`step` writes those values into the optimizer.
    """

    def __init__(self, optimizer, last_epoch=-1):
        """Attach to ``optimizer`` and apply the schedule once.

        Arguments:
            optimizer (Optimizer): optimizer whose ``param_groups`` are updated.
            last_epoch (int): index of the last finished epoch; ``-1`` means
                a fresh start, in which case each group's current ``lr`` is
                recorded as its ``initial_lr`` (unless already present).

        Raises:
            TypeError: if ``optimizer`` is not an ``Optimizer`` instance.
            KeyError: if resuming (``last_epoch != -1``) and some parameter
                group carries no ``initial_lr`` entry.
        """
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        groups = optimizer.param_groups
        if last_epoch == -1:
            # Fresh run: remember the starting rate of every group.
            for group in groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # Resuming: every group must already know its starting rate.
            missing = next(
                (index for index, group in enumerate(groups)
                 if 'initial_lr' not in group),
                None)
            if missing is not None:
                raise KeyError("param 'initial_lr' is not specified "
                               "in param_groups[{}] when resuming an optimizer".format(missing))

        self.base_lrs = [group['initial_lr'] for group in groups]
        # Apply the schedule for the upcoming epoch, then restore the counter
        # so that the next ``step()`` call lands on ``last_epoch + 1``.
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def __getstate__(self):
        """Pickle support: serialise the same data as :meth:`state_dict`."""
        return self.state_dict()

    def __setstate__(self, state):
        """Pickle support: restore through :meth:`load_state_dict`."""
        self.load_state_dict(state)

    def state_dict(self):
        """Returns the scheduler state as a :class:`dict`.

        The result holds every instance attribute except the wrapped
        optimizer.
        """
        state = dict(self.__dict__)
        state.pop('optimizer', None)
        return state

    def load_state_dict(self, state_dict):
        """Restores scheduler state.

        Arguments:
            state_dict (dict): scheduler state, as produced by
                :meth:`state_dict`; its entries overwrite the matching
                instance attributes.
        """
        vars(self).update(state_dict)

    def get_lr(self):
        """Return the learning rate for each parameter group (abstract)."""
        raise NotImplementedError

    def step(self, epoch=None):
        """Move to ``epoch`` (default: the next one) and update the optimizer.

        Arguments:
            epoch (int, optional): epoch index to record as ``last_epoch``.
                When omitted, ``last_epoch`` is incremented by one.
        """
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        groups = self.optimizer.param_groups
        for group, new_lr in zip(groups, self.get_lr()):
            print("New lr", new_lr)
            group['lr'] = new_lr
class gammaLR(_LRScheduler):
    """Decays the learning rate of each parameter group by a fixed amount.

    Every scheduler step subtracts ``gamma`` from the group's *current*
    learning rate, giving a linear decay. The base-class constructor
    already performs one step, so the first subtraction happens when the
    scheduler is created. The rate is floored at zero so that a long run
    can never push it negative.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Amount subtracted from every group's learning rate
            on each step.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> scheduler = gammaLR(optimizer, gamma=1e-4)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        # ``gamma`` must exist before the base constructor runs, because the
        # base constructor calls ``step``, which in turn calls ``get_lr``.
        self.gamma = gamma
        super(gammaLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the next learning rate for every parameter group.

        The value is computed from each group's current ``lr`` (not from
        ``base_lrs``), so successive steps accumulate the decay. Results
        are clamped to be non-negative.
        """
        return [max(group['lr'] - self.gamma, 0.0)
                for group in self.optimizer.param_groups]