Commit 0d253e2

Fix issue with nfnet tests, bit more cleanup.

1 parent: cb06c7a

2 files changed: +24, -33 lines

tests/test_models.py

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,9 @@
 # exclude models that cause specific test failures
 if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
-    EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS
+    EXCLUDE_FILTERS = [
+        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
+        'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS
 else:
     EXCLUDE_FILTERS = NON_STD_FILTERS
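For context, EXCLUDE_FILTERS entries are shell-style wildcards, so 'nfnet_f4*' skips both 'nfnet_f4' and 'nfnet_f4s'. A minimal sketch of how such patterns are typically applied when pruning the test model list (filter_models is an illustrative helper, not timm's actual test code, which matches via fnmatch-style patterns):

    import fnmatch

    def filter_models(model_names, exclude_filters):
        # Drop any model whose name matches one of the wildcard patterns.
        return [n for n in model_names
                if not any(fnmatch.fnmatch(n, pat) for pat in exclude_filters)]

    names = ['nfnet_f0', 'nfnet_f4', 'nfnet_f4s', 'nf_resnet50']
    print(filter_models(names, ['nfnet_f4*', 'nfnet_f5*']))
    # -> ['nfnet_f0', 'nf_resnet50']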

timm/models/nfnet.py

Lines changed: 21 additions & 32 deletions
@@ -3,16 +3,15 @@
 Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
     - https://arxiv.org/abs/2101.08692
 
-
 Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
     - https://arxiv.org/abs/2102.06171
 
 Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
 
 Status:
 * These models are a work in progress, experiments ongoing.
-* Two pretrained weights so far, more to come.
-* Model details update to closer match official JAX code now that it's released
+* Pretrained weights for two models so far, more to come.
+* Model details updated to closer match official JAX code now that it's released
 * NF-ResNet, NF-RegNet-B, and NFNet-F models supported
 
 Hacked together by / copyright Ross Wightman, 2021.
@@ -150,7 +149,7 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
     num_features = channels[-1] * 2
     attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8)
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='nff', group_size=128, bottle_ratio=0.5, extra_conv=True,
+        depths=depths, channels=channels, stem_type='deep_quad', group_size=128, bottle_ratio=0.5, extra_conv=True,
         num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
     return cfg

@@ -176,9 +175,6 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
     nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
     nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),
 
-    # NFNet-F models w/ SiLU (much faster in PyTorch)
-    # FIXME add remainder if silu vs gelu proves worthwhile
-
     # EffNet influenced RegNet defs.
     # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
     nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
@@ -194,9 +190,9 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
     nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
     nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
 
-    nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
+    nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
+    nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
+    nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
 
     nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
     nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
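Note the semantics of reduction_ratio here: it multiplies the channel count to get the SE bottleneck width, so 1/16 corresponds to the classic SE-ResNet reduction factor of 16 (the earlier 0.25 gave a much wider bottleneck). A minimal sketch of a squeeze-and-excitation block parameterized this way (illustrative only, not timm's exact SEModule):

    import torch
    import torch.nn as nn

    class SqueezeExcite(nn.Module):
        # reduction_ratio is multiplicative: hidden = max(1, int(channels * ratio))
        def __init__(self, channels, reduction_ratio=1/16):
            super().__init__()
            hidden = max(1, int(channels * reduction_ratio))
            self.fc1 = nn.Conv2d(channels, hidden, 1)
            self.act = nn.ReLU(inplace=True)
            self.fc2 = nn.Conv2d(hidden, channels, 1)

        def forward(self, x):
            s = x.mean((2, 3), keepdim=True)     # squeeze: global average pool
            s = self.fc2(self.act(self.fc1(s)))  # excite: bottleneck MLP
            return x * s.sigmoid()               # gate the input channels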
@@ -315,38 +311,26 @@ def forward(self, x):
         return out
 
 
-def stem_info(stem_type):
-    stem_stride = 2
-    if 'nff' in stem_type or 'pool' in stem_type:
-        stem_stride = 4
-    stem_feat = ''
-    if 'nff' in stem_type:
-        stem_feat = 'stem.act3'
-    elif 'deep' in stem_type and not 'pool' in stem_type:
-        stem_feat = 'stem.act2'
-    return stem_stride, stem_feat
-
-
 def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
     stem_stride = 2
-    stem_feature = ''
+    stem_feature = dict(num_chs=out_chs, reduction=2, module='')
     stem = OrderedDict()
-    assert stem_type in ('', 'nff', 'deep', 'deep_tiered', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
+    assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
     if 'deep' in stem_type or 'nff' in stem_type:
         # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
-        if 'nff' in stem_type:
+        if 'quad' in stem_type:
             assert not 'pool' in stem_type
             stem_chs = (16, 32, 64, out_chs)
             strides = (2, 1, 1, 2)
             stem_stride = 4
-            stem_feature = 'stem.act4'
+            stem_feature = dict(num_chs=64, reduction=2, module='stem.act4')
         else:
             if 'tiered' in stem_type:
-                stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)
+                stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)  # like 'T' resnets in resnet.py
             else:
-                stem_chs = (out_chs // 2, out_chs // 2, out_chs)
+                stem_chs = (out_chs // 2, out_chs // 2, out_chs)  # 'D' ResNets
             strides = (2, 1, 1)
-            stem_feature = 'stem.act3'
+            stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
         last_idx = len(stem_chs) - 1
         for i, (c, s) in enumerate(zip(stem_chs, strides)):
             stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
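For intuition on the renamed 'deep_quad' stem: four 3x3 convs with channels (16, 32, 64, out_chs) and strides (2, 1, 1, 2) give an overall stride of 4, with 'stem.act4' as the stride-2 feature tap reported in stem_feature. A standalone sketch using plain nn.Conv2d/nn.GELU (the real code routes through conv_layer/act_layer, e.g. a ScaledStdConv variant):

    import torch
    import torch.nn as nn

    out_chs = 128
    stem_chs, strides = (16, 32, 64, out_chs), (2, 1, 1, 2)
    layers, in_c = [], 3
    for c, s in zip(stem_chs, strides):
        layers += [nn.Conv2d(in_c, c, kernel_size=3, stride=s, padding=1), nn.GELU()]
        in_c = c
    stem = nn.Sequential(*layers)
    print(stem(torch.randn(1, 3, 224, 224)).shape)
    # torch.Size([1, 128, 56, 56]) -> overall stride 4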
@@ -401,7 +385,7 @@ class NormFreeNet(nn.Module):
     * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
         impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
     * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
-        apply it in each activation. This is slightly slower, and yields slightly different results.
+        apply it in each activation. This is slightly slower, numerically different, but matches official impl.
     * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
         for what it is/does. Approx 8-10% throughput loss.
     """
@@ -424,7 +408,7 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
         self.stem, stem_stride, stem_feat = create_stem(
             in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
 
-        self.feature_info = [dict(num_chs=stem_chs, reduction=2, module=stem_feat)] if stem_stride == 4 else []
+        self.feature_info = [stem_feat] if stem_stride == 4 else []
         drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
         prev_chs = stem_chs
         net_stride = stem_stride
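This simplification works because create_stem now returns the complete feature_info dict (num_chs, reduction, module) rather than just a module name, so __init__ no longer has to guess the channel count (the old dict wrongly used stem_chs for the quad stem's stride-2 tap). These dicts feed timm's feature-extraction machinery; a usage sketch, assuming a current timm install:

    import timm

    # features_only wraps the model using its feature_info metadata
    m = timm.create_model('nf_regnet_b0', features_only=True, pretrained=False)
    print(m.feature_info.channels())   # channels at each feature tap
    print(m.feature_info.reduction())  # stride (reduction) of each tap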
@@ -476,7 +460,6 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
             # The paper NFRegNet models have an EfficientNet-like final head convolution.
             self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
             self.final_conv = conv_layer(prev_chs, self.num_features, 1)
-            # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
@@ -554,10 +537,12 @@ def nfnet_f3(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)
 
 
+@register_model
 def nfnet_f4(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)
 
 
+@register_model
 def nfnet_f5(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)

@@ -567,6 +552,7 @@ def nfnet_f6(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)
 
 
+@register_model
 def nfnet_f7(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)

@@ -591,10 +577,12 @@ def nfnet_f3s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)
 
 
+@register_model
 def nfnet_f4s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)
 
 
+@register_model
 def nfnet_f5s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)

@@ -604,6 +592,7 @@ def nfnet_f6s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)
 
 
+@register_model
 def nfnet_f7s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
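With the previously missing @register_model decorators restored, all of the F4-F7 variants are visible to timm's model registry and can be created by name, which is also why the new test exclusions above are needed. Usage sketch:

    import timm

    print(timm.list_models('nfnet_f*'))  # now includes nfnet_f4 ... nfnet_f7s
    model = timm.create_model('nfnet_f4', pretrained=False)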
