
Commit e9d2ec4

Merge pull request #31 from rwightman/opt
Optimizers and more
2 parents 81875d5 + fac58f6 commit e9d2ec4

File tree

13 files changed: +550 −40 lines changed


README.md

Lines changed: 4 additions & 1 deletion
@@ -13,7 +13,10 @@ The work of many others is present here. I've tried to make sure all source mate
 * [Myself](https://github.com/rwightman/pytorch-dpn-pretrained)
 * LR scheduler ideas from [AllenNLP](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers), [FAIRseq](https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler), and SGDR: Stochastic Gradient Descent with Warm Restarts (https://arxiv.org/abs/1608.03983)
 * Random Erasing from [Zhun Zhong](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py) (https://arxiv.org/abs/1708.04896)
-
+* Optimizers:
+    * RAdam by [Liyuan Liu](https://github.com/LiyuanLucasLiu/RAdam) (https://arxiv.org/abs/1908.03265)
+    * NovoGrad by [Masashi Kimura](https://github.com/convergence-lab/novograd) (https://arxiv.org/abs/1905.11286)
+    * Lookahead adapted from impl by [Liam](https://github.com/alphadl/lookahead.pytorch) (https://arxiv.org/abs/1907.08610)
 ## Models
 
 I've included a few of my favourite models, but this is not an exhaustive collection. You can't do better than Cadene's collection in that regard. Most models do have pretrained weights from their respective sources or original authors.

timm/data/loader.py

Lines changed: 4 additions & 1 deletion
@@ -20,6 +20,7 @@ def __init__(self,
                  loader,
                  rand_erase_prob=0.,
                  rand_erase_mode='const',
+                 rand_erase_count=1,
                  mean=IMAGENET_DEFAULT_MEAN,
                  std=IMAGENET_DEFAULT_STD,
                  fp16=False):
@@ -32,7 +33,7 @@ def __init__(self,
             self.std = self.std.half()
         if rand_erase_prob > 0.:
             self.random_erasing = RandomErasing(
-                probability=rand_erase_prob, mode=rand_erase_mode)
+                probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
         else:
             self.random_erasing = None
 
@@ -135,6 +136,7 @@ def create_loader(
         use_prefetcher=True,
         rand_erase_prob=0.,
         rand_erase_mode='const',
+        rand_erase_count=1,
         color_jitter=0.4,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
@@ -184,6 +186,7 @@ def create_loader(
             loader,
             rand_erase_prob=rand_erase_prob if is_training else 0.,
             rand_erase_mode=rand_erase_mode,
+            rand_erase_count=rand_erase_count,
            mean=mean,
            std=std,
            fp16=fp16)
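
A rough usage sketch, not part of the commit, showing the new rand_erase_count being passed through create_loader. Only the random-erasing keyword arguments are confirmed by this diff; the dataset, input_size, batch_size and is_training arguments are assumed from timm's existing create_loader interface, and train_dataset is a placeholder for an ImageFolder-style dataset.

from timm.data.loader import create_loader

loader = create_loader(
    train_dataset,                 # placeholder: an ImageFolder-style dataset
    input_size=(3, 224, 224),      # assumed existing argument
    batch_size=64,                 # assumed existing argument
    is_training=True,              # erasing is only applied when training (see diff above)
    use_prefetcher=True,
    rand_erase_prob=0.5,           # probability that erasing happens at all
    rand_erase_mode='pixel',       # per-pixel random (normal) fill
    rand_erase_count=2)            # up to 2 erase boxes per image, area split between them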

timm/data/random_erasing.py

Lines changed: 21 additions & 13 deletions
@@ -33,16 +33,20 @@ class RandomErasing:
             'const' - erase block is constant color of 0 for all channels
             'rand' - erase block is same per-channel random (normal) color
             'pixel' - erase block is per-pixel random (normal) color
+        max_count: maximum number of erasing blocks per image, area per box is scaled by count.
+            per-image count is randomly chosen between 1 and this value.
     """
 
     def __init__(
             self,
             probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
-            mode='const', device='cuda'):
+            mode='const', max_count=1, device='cuda'):
         self.probability = probability
         self.sl = sl
         self.sh = sh
         self.min_aspect = min_aspect
+        self.min_count = 1
+        self.max_count = max_count
         mode = mode.lower()
         self.rand_color = False
         self.per_pixel = False
@@ -58,18 +62,22 @@ def _erase(self, img, chan, img_h, img_w, dtype):
         if random.random() > self.probability:
             return
         area = img_h * img_w
-        for attempt in range(100):
-            target_area = random.uniform(self.sl, self.sh) * area
-            aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
-            h = int(round(math.sqrt(target_area * aspect_ratio)))
-            w = int(round(math.sqrt(target_area / aspect_ratio)))
-            if w < img_w and h < img_h:
-                top = random.randint(0, img_h - h)
-                left = random.randint(0, img_w - w)
-                img[:, top:top + h, left:left + w] = _get_pixels(
-                    self.per_pixel, self.rand_color, (chan, h, w),
-                    dtype=dtype, device=self.device)
-                break
+        count = self.min_count if self.min_count == self.max_count else \
+            random.randint(self.min_count, self.max_count)
+        for _ in range(count):
+            for attempt in range(10):
+                target_area = random.uniform(self.sl, self.sh) * area / count
+                log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect))
+                aspect_ratio = math.exp(random.uniform(*log_ratio))
+                h = int(round(math.sqrt(target_area * aspect_ratio)))
+                w = int(round(math.sqrt(target_area / aspect_ratio)))
+                if w < img_w and h < img_h:
+                    top = random.randint(0, img_h - h)
+                    left = random.randint(0, img_w - w)
+                    img[:, top:top + h, left:left + w] = _get_pixels(
+                        self.per_pixel, self.rand_color, (chan, h, w),
+                        dtype=dtype, device=self.device)
+                    break
 
     def __call__(self, input):
         if len(input.size()) == 3:
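
A rough sketch, not part of the commit, of what the new max_count argument does in isolation. The constructor arguments are exactly those shown above; the import path follows the file path in this diff, and applying it to an already-normalized NCHW float batch mirrors how PrefetchLoader invokes it.

import torch
from timm.data.random_erasing import RandomErasing

erase = RandomErasing(probability=1.0, mode='pixel', max_count=3, device='cpu')
imgs = torch.randn(8, 3, 224, 224)   # already-normalized batch, NCHW
erase(imgs)                          # in-place: each image gets 1-3 boxes, per-box area scaled down by the count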

timm/data/transforms.py

Lines changed: 15 additions & 8 deletions
@@ -107,24 +107,31 @@ def get_params(img, scale, ratio):
 
         for attempt in range(10):
             target_area = random.uniform(*scale) * area
-            aspect_ratio = random.uniform(*ratio)
+            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+            aspect_ratio = math.exp(random.uniform(*log_ratio))
 
             w = int(round(math.sqrt(target_area * aspect_ratio)))
             h = int(round(math.sqrt(target_area / aspect_ratio)))
 
-            if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio):
-                w, h = h, w
-
             if w <= img.size[0] and h <= img.size[1]:
                 i = random.randint(0, img.size[1] - h)
                 j = random.randint(0, img.size[0] - w)
                 return i, j, h, w
 
-        # Fallback
-        w = min(img.size[0], img.size[1])
-        i = (img.size[1] - w) // 2
+        # Fallback to central crop
+        in_ratio = img.size[0] / img.size[1]
+        if in_ratio < min(ratio):
+            w = img.size[0]
+            h = int(round(w / min(ratio)))
+        elif in_ratio > max(ratio):
+            h = img.size[1]
+            w = int(round(h * max(ratio)))
+        else:  # whole image
+            w = img.size[0]
+            h = img.size[1]
+        i = (img.size[1] - h) // 2
         j = (img.size[0] - w) // 2
-        return i, j, w, w
+        return i, j, h, w
 
     def __call__(self, img):
         """

timm/models/resnet.py

Lines changed: 25 additions & 3 deletions
@@ -44,6 +44,9 @@ def _cfg(url='', **kwargs):
     'resnet50': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
         interpolation='bicubic'),
+    'resnet50d': _cfg(
+        url='',
+        interpolation='bicubic'),
     'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
     'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
     'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
@@ -259,7 +262,7 @@ class ResNet(nn.Module):
     def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
                  cardinality=1, base_width=64, stem_width=64, deep_stem=False,
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
-                 norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'):
+                 norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', zero_init_last_bn=True):
         self.num_classes = num_classes
         self.inplanes = stem_width * 2 if deep_stem else 64
         self.cardinality = cardinality
@@ -296,11 +299,16 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
         self.num_features = 512 * block.expansion
         self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
 
-        for m in self.modules():
+        last_bn_name = 'bn3' if 'Bottleneck' in block.__name__ else 'bn2'
+        for n, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1.)
+                if zero_init_last_bn and 'layer' in n and last_bn_name in n:
+                    # Initialize weight/gamma of last BN in each residual block to zero
+                    nn.init.constant_(m.weight, 0.)
+                else:
+                    nn.init.constant_(m.weight, 1.)
                 nn.init.constant_(m.bias, 0.)
 
     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
@@ -434,6 +442,20 @@ def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model
 
 
+@register_model
+def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a ResNet-50-D model.
+    """
+    default_cfg = default_cfgs['resnet50d']
+    model = ResNet(
+        Bottleneck, [3, 4, 6, 3], stem_width=32, deep_stem=True, avg_down=True,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 @register_model
 def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a ResNet-101 model.

timm/optim/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
 from .nadam import Nadam
 from .rmsprop_tf import RMSpropTF
+from .adamw import AdamW
+from .radam import RAdam
+from .novograd import NovoGrad
+from .lookahead import Lookahead
 from .optim_factory import create_optimizer
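
A sketch, not part of the commit, of wrapping one of the newly exported optimizers with Lookahead. The Lookahead constructor is not shown in this diff; the wrap-a-base-optimizer convention is assumed from the referenced alphadl/lookahead.pytorch implementation.

import torch
from timm.optim import AdamW, Lookahead

model = torch.nn.Linear(10, 2)
base = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
optimizer = Lookahead(base)   # assumed convention: slow weights synced every k fast steps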

timm/optim/adamw.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+""" AdamW Optimizer
+Impl copied from PyTorch master
+"""
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class AdamW(Optimizer):
+    r"""Implements AdamW algorithm.
+
+    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=1e-2, amsgrad=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad)
+        super(AdamW, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(AdamW, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                # Perform stepweight decay
+                p.data.mul_(1 - group['lr'] * group['weight_decay'])
+
+                # Perform optimization step
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsgrad:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                else:
+                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+
+                step_size = group['lr'] / bias_correction1
+
+                p.data.addcdiv_(-step_size, exp_avg, denom)
+
+        return loss
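
A minimal usage sketch, not part of the commit, using only the constructor arguments documented above. The defining difference from plain Adam is visible in step(): p.data.mul_(1 - lr * weight_decay) applies decoupled weight decay directly to the weights before the Adam update, rather than folding an L2 term into the gradient.

import torch
from timm.optim import AdamW

model = torch.nn.Linear(4, 1)
opt = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2, amsgrad=False)

x, y = torch.randn(16, 4), torch.randn(16, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

opt.zero_grad()
loss.backward()
opt.step()   # weights decayed first, then the bias-corrected Adam update is applied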
