Skip to content

Commit aff194f

Browse files
authored
Merge pull request #32 from rwightman/opt
More optimizer work
2 parents 5c6da1c + 64966f6 commit aff194f

File tree

7 files changed

+229
-63
lines changed

7 files changed

+229
-63
lines changed

timm/models/helpers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
2929

3030

3131
def resume_checkpoint(model, checkpoint_path):
32-
optimizer_state = None
32+
other_state = {}
3333
resume_epoch = None
3434
if os.path.isfile(checkpoint_path):
3535
checkpoint = torch.load(checkpoint_path, map_location='cpu')
@@ -40,7 +40,9 @@ def resume_checkpoint(model, checkpoint_path):
4040
new_state_dict[name] = v
4141
model.load_state_dict(new_state_dict)
4242
if 'optimizer' in checkpoint:
43-
optimizer_state = checkpoint['optimizer']
43+
other_state['optimizer'] = checkpoint['optimizer']
44+
if 'amp' in checkpoint:
45+
other_state['amp'] = checkpoint['amp']
4446
if 'epoch' in checkpoint:
4547
resume_epoch = checkpoint['epoch']
4648
if 'version' in checkpoint and checkpoint['version'] > 1:
@@ -49,7 +51,7 @@ def resume_checkpoint(model, checkpoint_path):
4951
else:
5052
model.load_state_dict(checkpoint)
5153
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
52-
return optimizer_state, resume_epoch
54+
return other_state, resume_epoch
5355
else:
5456
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
5557
raise FileNotFoundError()

timm/optim/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
from .adamw import AdamW
44
from .radam import RAdam
55
from .novograd import NovoGrad
6+
from .nvnovograd import NvNovoGrad
67
from .lookahead import Lookahead
78
from .optim_factory import create_optimizer

timm/optim/lookahead.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,40 @@ def __init__(self, base_optimizer, alpha=0.5, k=6):
1313
raise ValueError(f'Invalid slow update rate: {alpha}')
1414
if not 1 <= k:
1515
raise ValueError(f'Invalid lookahead steps: {k}')
16-
self.alpha = alpha
17-
self.k = k
16+
defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
1817
self.base_optimizer = base_optimizer
1918
self.param_groups = self.base_optimizer.param_groups
2019
self.defaults = base_optimizer.defaults
20+
self.defaults.update(defaults)
2121
self.state = defaultdict(dict)
22-
for group in self.param_groups:
23-
group["step_counter"] = 0
22+
# manually add our defaults to the param groups
23+
for name, default in defaults.items():
24+
for group in self.param_groups:
25+
group.setdefault(name, default)
2426

25-
def update_slow_weights(self, group):
27+
def update_slow(self, group):
2628
for fast_p in group["params"]:
2729
if fast_p.grad is None:
2830
continue
2931
param_state = self.state[fast_p]
30-
if "slow_buffer" not in param_state:
31-
param_state["slow_buffer"] = torch.empty_like(fast_p.data)
32-
param_state["slow_buffer"].copy_(fast_p.data)
33-
slow = param_state["slow_buffer"]
34-
slow.add_(self.alpha, fast_p.data - slow)
32+
if 'slow_buffer' not in param_state:
33+
param_state['slow_buffer'] = torch.empty_like(fast_p.data)
34+
param_state['slow_buffer'].copy_(fast_p.data)
35+
slow = param_state['slow_buffer']
36+
slow.add_(group['lookahead_alpha'], fast_p.data - slow)
3537
fast_p.data.copy_(slow)
3638

3739
def sync_lookahead(self):
3840
for group in self.param_groups:
39-
self.update_slow_weights(group)
41+
self.update_slow(group)
4042

4143
def step(self, closure=None):
44+
#assert id(self.param_groups) == id(self.base_optimizer.param_groups)
4245
loss = self.base_optimizer.step(closure)
4346
for group in self.param_groups:
44-
group['step_counter'] += 1
45-
if group['step_counter'] % self.k == 0:
46-
self.update_slow_weights(group)
47+
group['lookahead_step'] += 1
48+
if group['lookahead_step'] % group['lookahead_k'] == 0:
49+
self.update_slow(group)
4750
return loss
4851

4952
def state_dict(self):
@@ -52,37 +55,36 @@ def state_dict(self):
5255
(id(k) if isinstance(k, torch.Tensor) else k): v
5356
for k, v in self.state.items()
5457
}
55-
fast_state = fast_state_dict["state"]
56-
param_groups = fast_state_dict["param_groups"]
58+
fast_state = fast_state_dict['state']
59+
param_groups = fast_state_dict['param_groups']
5760
return {
58-
"state": fast_state,
59-
"slow_state": slow_state,
60-
"param_groups": param_groups,
61+
'state': fast_state,
62+
'slow_state': slow_state,
63+
'param_groups': param_groups,
6164
}
6265

6366
def load_state_dict(self, state_dict):
67+
fast_state_dict = {
68+
'state': state_dict['state'],
69+
'param_groups': state_dict['param_groups'],
70+
}
71+
self.base_optimizer.load_state_dict(fast_state_dict)
72+
73+
# We want to restore the slow state, but share param_groups reference
74+
# with base_optimizer. This is a bit redundant but least code
75+
slow_state_new = False
6476
if 'slow_state' not in state_dict:
65-
print('Loading state_dict from optimizer without Lookahead applied')
77+
print('Loading state_dict from optimizer without Lookahead applied.')
6678
state_dict['slow_state'] = defaultdict(dict)
79+
slow_state_new = True
6780
slow_state_dict = {
68-
"state": state_dict["slow_state"],
69-
"param_groups": state_dict["param_groups"],
70-
}
71-
fast_state_dict = {
72-
"state": state_dict["state"],
73-
"param_groups": state_dict["param_groups"],
81+
'state': state_dict['slow_state'],
82+
'param_groups': state_dict['param_groups'], # this is pointless but saves code
7483
}
7584
super(Lookahead, self).load_state_dict(slow_state_dict)
76-
self.base_optimizer.load_state_dict(fast_state_dict)
77-
78-
def add_param_group(self, param_group):
79-
r"""Add a param group to the :class:`Optimizer` s `param_groups`.
80-
This can be useful when fine tuning a pre-trained network as frozen
81-
layers can be made trainable and added to the :class:`Optimizer` as
82-
training progresses.
83-
Args:
84-
param_group (dict): Specifies what Tensors should be optimized along
85-
with group specific optimization options.
86-
"""
87-
param_group['step_counter'] = 0
88-
self.base_optimizer.add_param_group(param_group)
85+
self.param_groups = self.base_optimizer.param_groups # make both ref same container
86+
if slow_state_new:
87+
# reapply defaults to catch missing lookahead specific ones
88+
for name, default in self.defaults.items():
89+
for group in self.param_groups:
90+
group.setdefault(name, default)

timm/optim/nvnovograd.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
""" Nvidia NovoGrad Optimizer.
2+
Original impl by Nvidia from Jasper example:
3+
- https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
4+
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
5+
- https://arxiv.org/abs/1905.11286
6+
"""
7+
8+
import torch
9+
from torch.optim.optimizer import Optimizer
10+
import math
11+
12+
13+
class NvNovoGrad(Optimizer):
14+
"""
15+
Implements Novograd algorithm.
16+
17+
Args:
18+
params (iterable): iterable of parameters to optimize or dicts defining
19+
parameter groups
20+
lr (float, optional): learning rate (default: 1e-3)
21+
betas (Tuple[float, float], optional): coefficients used for computing
22+
running averages of gradient and its square (default: (0.95, 0.98))
23+
eps (float, optional): term added to the denominator to improve
24+
numerical stability (default: 1e-8)
25+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
26+
grad_averaging: gradient averaging
27+
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
28+
algorithm from the paper `On the Convergence of Adam and Beyond`_
29+
(default: False)
30+
"""
31+
32+
def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
33+
weight_decay=0, grad_averaging=False, amsgrad=False):
34+
if not 0.0 <= lr:
35+
raise ValueError("Invalid learning rate: {}".format(lr))
36+
if not 0.0 <= eps:
37+
raise ValueError("Invalid epsilon value: {}".format(eps))
38+
if not 0.0 <= betas[0] < 1.0:
39+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
40+
if not 0.0 <= betas[1] < 1.0:
41+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
42+
defaults = dict(lr=lr, betas=betas, eps=eps,
43+
weight_decay=weight_decay,
44+
grad_averaging=grad_averaging,
45+
amsgrad=amsgrad)
46+
47+
super(NvNovoGrad, self).__init__(params, defaults)
48+
49+
def __setstate__(self, state):
50+
super(NvNovoGrad, self).__setstate__(state)
51+
for group in self.param_groups:
52+
group.setdefault('amsgrad', False)
53+
54+
def step(self, closure=None):
55+
"""Performs a single optimization step.
56+
57+
Arguments:
58+
closure (callable, optional): A closure that reevaluates the model
59+
and returns the loss.
60+
"""
61+
loss = None
62+
if closure is not None:
63+
loss = closure()
64+
65+
for group in self.param_groups:
66+
for p in group['params']:
67+
if p.grad is None:
68+
continue
69+
grad = p.grad.data
70+
if grad.is_sparse:
71+
raise RuntimeError('Sparse gradients are not supported.')
72+
amsgrad = group['amsgrad']
73+
74+
state = self.state[p]
75+
76+
# State initialization
77+
if len(state) == 0:
78+
state['step'] = 0
79+
# Exponential moving average of gradient values
80+
state['exp_avg'] = torch.zeros_like(p.data)
81+
# Exponential moving average of squared gradient values
82+
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
83+
if amsgrad:
84+
# Maintains max of all exp. moving avg. of sq. grad. values
85+
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
86+
87+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
88+
if amsgrad:
89+
max_exp_avg_sq = state['max_exp_avg_sq']
90+
beta1, beta2 = group['betas']
91+
92+
state['step'] += 1
93+
94+
norm = torch.sum(torch.pow(grad, 2))
95+
96+
if exp_avg_sq == 0:
97+
exp_avg_sq.copy_(norm)
98+
else:
99+
exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
100+
101+
if amsgrad:
102+
# Maintains the maximum of all 2nd moment running avg. till now
103+
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
104+
# Use the max. for normalizing running avg. of gradient
105+
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
106+
else:
107+
denom = exp_avg_sq.sqrt().add_(group['eps'])
108+
109+
grad.div_(denom)
110+
if group['weight_decay'] != 0:
111+
grad.add_(group['weight_decay'], p.data)
112+
if group['grad_averaging']:
113+
grad.mul_(1 - beta1)
114+
exp_avg.mul_(beta1).add_(grad)
115+
116+
p.data.add_(-group['lr'], exp_avg)
117+
118+
return loss

timm/optim/optim_factory.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
import torch
12
from torch import optim as optim
2-
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, Lookahead
3+
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead
4+
try:
5+
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
6+
has_apex = True
7+
except ImportError:
8+
has_apex = False
39

410

511
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
@@ -20,22 +26,25 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
2026
def create_optimizer(args, model, filter_bias_and_bn=True):
2127
opt_lower = args.opt.lower()
2228
weight_decay = args.weight_decay
23-
if opt_lower == 'adamw' or opt_lower == 'radam':
24-
# compensate for the way current AdamW and RAdam optimizers
25-
# apply the weight-decay
29+
if 'adamw' in opt_lower or 'radam' in opt_lower:
30+
# Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
31+
# I don't believe they follow the paper or original Torch7 impl which schedules weight
32+
# decay based on the ratio of current_lr/initial_lr
2633
weight_decay /= args.lr
2734
if weight_decay and filter_bias_and_bn:
2835
parameters = add_weight_decay(model, weight_decay)
2936
weight_decay = 0.
3037
else:
3138
parameters = model.parameters()
3239

40+
if 'fused' in opt_lower:
41+
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
42+
3343
opt_split = opt_lower.split('_')
3444
opt_lower = opt_split[-1]
3545
if opt_lower == 'sgd':
3646
optimizer = optim.SGD(
37-
parameters, lr=args.lr,
38-
momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
47+
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
3948
elif opt_lower == 'adam':
4049
optimizer = optim.Adam(
4150
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
@@ -61,6 +70,22 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
6170
momentum=args.momentum, weight_decay=weight_decay)
6271
elif opt_lower == 'novograd':
6372
optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
73+
elif opt_lower == 'nvnovograd':
74+
optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
75+
elif opt_lower == 'fusedsgd':
76+
optimizer = FusedSGD(
77+
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
78+
elif opt_lower == 'fusedadam':
79+
optimizer = FusedAdam(
80+
parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
81+
elif opt_lower == 'fusedadamw':
82+
optimizer = FusedAdam(
83+
parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps)
84+
elif opt_lower == 'fusedlamb':
85+
optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
86+
elif opt_lower == 'fusednovograd':
87+
optimizer = FusedNovoGrad(
88+
parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps)
6489
else:
6590
assert False and "Invalid optimizer"
6691
raise ValueError

0 commit comments

Comments
 (0)