
Commit 7606bdf

Merge pull request #714 from rwightman/vit_and_bit_test_fixes
Fix a few issues loading pretrained vit/bit npz weights...
2 parents dc42282 + 20a2be1 commit 7606bdf

8 files changed, +44 -12 lines changed

README.md

Lines changed: 3 additions & 0 deletions

@@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 
 ## What's New
 
+### June 23, 2021
+* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050).
+
 ### June 20, 2021
 * Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
 * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg)
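The .npz support mentioned in the README is the code path this commit hardens (the head/pre-logits handling in vision_transformer.py and resnetv2.py below). A minimal sketch of loading an AugReg checkpoint, assuming the `_load_weights` helper in `timm.models.vision_transformer` and a locally downloaded .npz file; the checkpoint filename is illustrative only:

import timm
from timm.models.vision_transformer import _load_weights

# build a ViT without a classifier; the patched loader now skips the absent head
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)

# path to a locally downloaded AugReg .npz checkpoint (illustrative filename)
_load_weights(model, 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz')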

tests/test_models.py

Lines changed: 17 additions & 0 deletions

@@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size):
         # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
 
+    if 'pruned' not in model_name:  # FIXME better pruned model handling
+        # test classifier + global pool deletion via __init__
+        model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
+        outputs = model.forward(input_tensor)
+        assert len(outputs.shape) == 4
+        if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
+            # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
+            assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+
     # check classifier name matches default_cfg
     classifier = cfg['classifier']
     if not isinstance(classifier, (tuple, list)):

@@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     assert len(outputs.shape) == 2
     assert outputs.shape[1] == model.num_features
 
+    model = create_model(model_name, pretrained=False, num_classes=0).eval()
+    outputs = model.forward(input_tensor)
+    if isinstance(outputs, tuple):
+        outputs = outputs[0]
+    assert len(outputs.shape) == 2
+    assert outputs.shape[1] == model.num_features
+
     # check classifier name matches default_cfg
     classifier = cfg['classifier']
     if not isinstance(classifier, (tuple, list)):

@@ -217,6 +233,7 @@ def test_model_load_pretrained(model_name, batch_size):
     """Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
     in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
     create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
+    create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0)
 
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
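These new test cases exercise classifier and global pool deletion via __init__ (as the added comment says). A short usage sketch of what they assert, with 'resnet50' chosen purely as an illustrative model name:

import torch
import timm

x = torch.randn(1, 3, 224, 224)

# num_classes=0 drops the classifier; forward() returns pooled features of shape (N, num_features)
pooled = timm.create_model('resnet50', pretrained=False, num_classes=0).eval()
print(pooled(x).shape)  # e.g. torch.Size([1, 2048])

# global_pool='' additionally disables pooling; forward() returns the spatial feature map (N, C, H, W)
unpooled = timm.create_model('resnet50', pretrained=False, num_classes=0, global_pool='').eval()
print(unpooled(x).shape)  # e.g. torch.Size([1, 2048, 7, 7])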

timm/models/ghostnet.py

Lines changed: 1 addition & 1 deletion

@@ -182,7 +182,7 @@ def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, o
         self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
         self.act2 = nn.ReLU(inplace=True)
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(out_chs, num_classes)
+        self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
 
     def get_classifier(self):
         return self.classifier
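With this change, a GhostNet built with num_classes=0 keeps an nn.Identity where the Linear classifier used to be, so it can be used directly as a feature extractor. A quick sketch, with 'ghostnet_100' chosen only as an example model name and the 1280-dim output assumed from its default head width:

import torch
import torch.nn as nn
import timm

model = timm.create_model('ghostnet_100', pretrained=False, num_classes=0).eval()
assert isinstance(model.classifier, nn.Identity)  # the attribute this diff changes
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # pooled, flattened features, e.g. torch.Size([1, 1280])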

timm/models/levit.py

Lines changed: 1 addition & 1 deletion

@@ -542,7 +542,7 @@ def checkpoint_filter_fn(state_dict, model):
         state_dict = state_dict['model']
     D = model.state_dict()
     for k in state_dict.keys():
-        if D[k].ndim == 4 and state_dict[k].ndim == 2:
+        if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2:
             state_dict[k] = state_dict[k][:, :, None, None]
     return state_dict
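The added `k in D` check keeps checkpoint_filter_fn from raising a KeyError when the model was built without a head (num_classes=0), since those checkpoint keys no longer exist in the model's state dict. A generic sketch of the same defensive pattern, not the LeViT code itself:

import torch
import torch.nn as nn

def filter_checkpoint(state_dict, model):
    # only reshape keys the target model still has; keys for modules replaced by
    # nn.Identity (e.g. a deleted classifier) are passed through untouched
    target = model.state_dict()
    out = {}
    for k, v in state_dict.items():
        if k in target and target[k].ndim == 4 and v.ndim == 2:
            v = v[:, :, None, None]  # stored as a linear weight, model expects a 1x1 conv
        out[k] = v
    return out

model = nn.Sequential(nn.Conv2d(3, 8, 1), nn.Identity())  # Identity stands in for a removed head
ckpt = {'0.weight': torch.randn(8, 3), '0.bias': torch.randn(8), '1.weight': torch.randn(5, 8)}
model.load_state_dict(filter_checkpoint(ckpt, model), strict=False)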

timm/models/mlp_mixer.py

Lines changed: 4 additions & 2 deletions

@@ -129,7 +129,9 @@ def _cfg(url='', **kwargs):
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
 
     gmlp_ti16_224=_cfg(),
-    gmlp_s16_224=_cfg(),
+    gmlp_s16_224=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
+    ),
     gmlp_b16_224=_cfg(),
 )

@@ -266,7 +268,7 @@ def __init__(
                 act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
             for _ in range(num_blocks)])
         self.norm = norm_layer(embed_dim)
-        self.head = nn.Linear(embed_dim, self.num_classes)  # zero init
+        self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
         self.init_weights(nlhb=nlhb)
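With the URL added to the gmlp_s16_224 config, the newly trained gMLP weights noted in the README can be fetched like any other pretrained timm model; a small sketch (the 79.6 top-1 figure comes from the README, not re-measured here):

import torch
import timm

# downloads gmlp_s16_224_raa-10536d42.pth on first use via the default_cfg url above
model = timm.create_model('gmlp_s16_224', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])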

timm/models/resnetv2.py

Lines changed: 2 additions & 1 deletion

@@ -424,7 +424,8 @@ def t2p(conv_weights):
     model.stem.conv.weight.copy_(stem_conv_w)
     model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
     model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
-    if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
+    if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \
+            model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
         model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
         model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
     for i, (sname, stage) in enumerate(model.stages.named_children()):
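The isinstance/getattr guard lets the BiT .npz loader skip the classifier copy when the model has no head.fc (e.g. num_classes=0). A sketch of the scenario this enables, using 'resnetv2_50x1_bitm' as an assumed example of a BiT model with npz-hosted pretrained weights:

import torch
import timm

# previously this could fail: with num_classes=0 there is no head.fc conv to copy the
# npz classifier kernel into; the new guard simply skips that copy
model = timm.create_model('resnetv2_50x1_bitm', pretrained=True, num_classes=0).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # pooled features, e.g. torch.Size([1, 2048])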

timm/models/visformer.py

Lines changed: 0 additions & 1 deletion

@@ -237,7 +237,6 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, in
         self.num_features = embed_dim if self.vit_stem else embed_dim * 2
         self.norm = norm_layer(self.num_features)
         self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
-        self.head = nn.Linear(self.num_features, num_classes)
 
         # weights init
         if self.pos_embed:

timm/models/vision_transformer.py

Lines changed: 16 additions & 6 deletions

@@ -6,7 +6,7 @@
     - https://arxiv.org/abs/2010.11929
 
 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.TODO
+    - https://arxiv.org/abs/2106.10270
 
 The official jax code is released and available at https://github.com/google-research/vision_transformer
 

@@ -448,9 +448,12 @@ def _n2p(w, t=True):
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'

@@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)

@@ -683,6 +687,7 @@ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)

@@ -693,6 +698,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)

@@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
 

@@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
 

@@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)

@@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
 

@@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
-    NOTE: converted weights not currently available, too large for github release hosting.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
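In practice the new docstring NOTEs describe what pretrained=True now loads: models flagged as having a valid 21k head get a usable ImageNet-21k classifier (assumed here to be the 21843-class default of the in21k configs), while models whose 21k head is zero'd out are mainly useful with num_classes=0 as feature extractors. A usage sketch under those assumptions:

import torch
import timm

# valid 21k classifier head per the updated docstring NOTE
model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected torch.Size([1, 21843])

# ViT-L/32 keeps its pre-logits layer but its 21k head is zero'd in the original weights,
# so it is most useful as a backbone
backbone = timm.create_model('vit_large_patch32_224_in21k', pretrained=True, num_classes=0)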
