Skip to content

Commit 64966f6

Browse files
committed
Add Nvidia's NovogGrad impl from Jasper (cleaner/faster than current) and Apex Fused optimizers
1 parent 3d9c8a6 commit 64966f6

File tree

3 files changed

+150
-6
lines changed

3 files changed

+150
-6
lines changed

timm/optim/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
from .adamw import AdamW
44
from .radam import RAdam
55
from .novograd import NovoGrad
6+
from .nvnovograd import NvNovoGrad
67
from .lookahead import Lookahead
78
from .optim_factory import create_optimizer

timm/optim/nvnovograd.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
""" Nvidia NovoGrad Optimizer.
2+
Original impl by Nvidia from Jasper example:
3+
- https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
4+
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
5+
- https://arxiv.org/abs/1905.11286
6+
"""
7+
8+
import torch
9+
from torch.optim.optimizer import Optimizer
10+
import math
11+
12+
13+
class NvNovoGrad(Optimizer):
14+
"""
15+
Implements Novograd algorithm.
16+
17+
Args:
18+
params (iterable): iterable of parameters to optimize or dicts defining
19+
parameter groups
20+
lr (float, optional): learning rate (default: 1e-3)
21+
betas (Tuple[float, float], optional): coefficients used for computing
22+
running averages of gradient and its square (default: (0.95, 0.98))
23+
eps (float, optional): term added to the denominator to improve
24+
numerical stability (default: 1e-8)
25+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
26+
grad_averaging: gradient averaging
27+
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
28+
algorithm from the paper `On the Convergence of Adam and Beyond`_
29+
(default: False)
30+
"""
31+
32+
def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
33+
weight_decay=0, grad_averaging=False, amsgrad=False):
34+
if not 0.0 <= lr:
35+
raise ValueError("Invalid learning rate: {}".format(lr))
36+
if not 0.0 <= eps:
37+
raise ValueError("Invalid epsilon value: {}".format(eps))
38+
if not 0.0 <= betas[0] < 1.0:
39+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
40+
if not 0.0 <= betas[1] < 1.0:
41+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
42+
defaults = dict(lr=lr, betas=betas, eps=eps,
43+
weight_decay=weight_decay,
44+
grad_averaging=grad_averaging,
45+
amsgrad=amsgrad)
46+
47+
super(NvNovoGrad, self).__init__(params, defaults)
48+
49+
def __setstate__(self, state):
50+
super(NvNovoGrad, self).__setstate__(state)
51+
for group in self.param_groups:
52+
group.setdefault('amsgrad', False)
53+
54+
def step(self, closure=None):
55+
"""Performs a single optimization step.
56+
57+
Arguments:
58+
closure (callable, optional): A closure that reevaluates the model
59+
and returns the loss.
60+
"""
61+
loss = None
62+
if closure is not None:
63+
loss = closure()
64+
65+
for group in self.param_groups:
66+
for p in group['params']:
67+
if p.grad is None:
68+
continue
69+
grad = p.grad.data
70+
if grad.is_sparse:
71+
raise RuntimeError('Sparse gradients are not supported.')
72+
amsgrad = group['amsgrad']
73+
74+
state = self.state[p]
75+
76+
# State initialization
77+
if len(state) == 0:
78+
state['step'] = 0
79+
# Exponential moving average of gradient values
80+
state['exp_avg'] = torch.zeros_like(p.data)
81+
# Exponential moving average of squared gradient values
82+
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
83+
if amsgrad:
84+
# Maintains max of all exp. moving avg. of sq. grad. values
85+
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
86+
87+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
88+
if amsgrad:
89+
max_exp_avg_sq = state['max_exp_avg_sq']
90+
beta1, beta2 = group['betas']
91+
92+
state['step'] += 1
93+
94+
norm = torch.sum(torch.pow(grad, 2))
95+
96+
if exp_avg_sq == 0:
97+
exp_avg_sq.copy_(norm)
98+
else:
99+
exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
100+
101+
if amsgrad:
102+
# Maintains the maximum of all 2nd moment running avg. till now
103+
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
104+
# Use the max. for normalizing running avg. of gradient
105+
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
106+
else:
107+
denom = exp_avg_sq.sqrt().add_(group['eps'])
108+
109+
grad.div_(denom)
110+
if group['weight_decay'] != 0:
111+
grad.add_(group['weight_decay'], p.data)
112+
if group['grad_averaging']:
113+
grad.mul_(1 - beta1)
114+
exp_avg.mul_(beta1).add_(grad)
115+
116+
p.data.add_(-group['lr'], exp_avg)
117+
118+
return loss

timm/optim/optim_factory.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
import torch
12
from torch import optim as optim
2-
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, Lookahead
3+
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead
4+
try:
5+
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
6+
has_apex = True
7+
except ImportError:
8+
has_apex = False
39

410

511
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
@@ -20,22 +26,25 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
2026
def create_optimizer(args, model, filter_bias_and_bn=True):
2127
opt_lower = args.opt.lower()
2228
weight_decay = args.weight_decay
23-
if opt_lower == 'adamw' or opt_lower == 'radam':
24-
# compensate for the way current AdamW and RAdam optimizers
25-
# apply the weight-decay
29+
if 'adamw' in opt_lower or 'radam' in opt_lower:
30+
# Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
31+
# I don't believe they follow the paper or original Torch7 impl which schedules weight
32+
# decay based on the ratio of current_lr/initial_lr
2633
weight_decay /= args.lr
2734
if weight_decay and filter_bias_and_bn:
2835
parameters = add_weight_decay(model, weight_decay)
2936
weight_decay = 0.
3037
else:
3138
parameters = model.parameters()
3239

40+
if 'fused' in opt_lower:
41+
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
42+
3343
opt_split = opt_lower.split('_')
3444
opt_lower = opt_split[-1]
3545
if opt_lower == 'sgd':
3646
optimizer = optim.SGD(
37-
parameters, lr=args.lr,
38-
momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
47+
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
3948
elif opt_lower == 'adam':
4049
optimizer = optim.Adam(
4150
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
@@ -61,6 +70,22 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
6170
momentum=args.momentum, weight_decay=weight_decay)
6271
elif opt_lower == 'novograd':
6372
optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
73+
elif opt_lower == 'nvnovograd':
74+
optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
75+
elif opt_lower == 'fusedsgd':
76+
optimizer = FusedSGD(
77+
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
78+
elif opt_lower == 'fusedadam':
79+
optimizer = FusedAdam(
80+
parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
81+
elif opt_lower == 'fusedadamw':
82+
optimizer = FusedAdam(
83+
parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps)
84+
elif opt_lower == 'fusedlamb':
85+
optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
86+
elif opt_lower == 'fusednovograd':
87+
optimizer = FusedNovoGrad(
88+
parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps)
6489
else:
6590
assert False and "Invalid optimizer"
6691
raise ValueError

0 commit comments

Comments
 (0)