
Commit 85f894e

Fix ViT in21k representation (pre_logits) layer handling across old and new npz checkpoints
1 parent b41cffa commit 85f894e
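
The substance of the change: when mapping the original .npz checkpoints onto the PyTorch modules, the representation (pre_logits) weights are now copied only if the model was actually built with that layer and the checkpoint actually contains the corresponding arrays, which is not true of every checkpoint generation. A minimal standalone sketch of that guard pattern (hypothetical helper names, not the timm code itself; the real change is in the diff below):

import torch
import torch.nn as nn

@torch.no_grad()
def copy_pre_logits(model, w, prefix=''):
    # w is a dict-like of numpy arrays, e.g. the result of np.load('checkpoint.npz').
    # Copy the representation layer only when both sides have it:
    #  - the model exposes pre_logits.fc as a Linear (when built without a
    #    representation layer, pre_logits is typically an nn.Identity with no 'fc'), and
    #  - the .npz archive actually stores pre_logits arrays.
    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
        # JAX stores Dense kernels as [in, out]; PyTorch Linear weights are [out, in].
        model.pre_logits.fc.weight.copy_(torch.from_numpy(w[f'{prefix}pre_logits/kernel']).T)
        model.pre_logits.fc.bias.copy_(torch.from_numpy(w[f'{prefix}pre_logits/bias']))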

File tree: 2 files changed (+16, -6 lines)


timm/models/resnetv2.py

Lines changed: 1 addition & 1 deletion
@@ -424,7 +424,7 @@ def t2p(conv_weights):
     model.stem.conv.weight.copy_(stem_conv_w)
     model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
     model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
-    if isinstance(model.head.fc, nn.Conv2d) and \
+    if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \
             model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
         model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
         model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
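
The getattr guard matters because model.head is not guaranteed to carry an fc classifier; for example, a model configured for feature extraction or with its classifier reset may have nothing to map the head weights onto, and a bare model.head.fc access would raise AttributeError before the isinstance check could skip the copy. A toy illustration with a hypothetical head class (not the timm one):

import torch.nn as nn

class PoolOnlyHead(nn.Module):
    """A head with no 'fc' classifier, e.g. a feature-extraction configuration."""
    def __init__(self):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)

head = PoolOnlyHead()
# isinstance(head.fc, nn.Conv2d)                         -> raises AttributeError
print(isinstance(getattr(head, 'fc', None), nn.Conv2d))  # False, so the copy is skipped cleanly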

timm/models/vision_transformer.py

Lines changed: 15 additions & 5 deletions
@@ -6,7 +6,7 @@
     - https://arxiv.org/abs/2010.11929

 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.TODO
+    - https://arxiv.org/abs/2106.10270

 The official jax code is released and available at https://github.com/google-research/vision_transformer

@@ -451,6 +451,9 @@ def _n2p(w, t=True):
     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
@@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -683,6 +687,7 @@ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -693,6 +698,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model

@@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model

@@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
@@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model

@@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
-    NOTE: converted weights not currently available, too large for github release hosting.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
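
As a quick way to see the intended end state, a usage sketch (assuming a timm build from around this commit; the module layout may differ in later releases):

import timm
import torch.nn as nn

# B/16 in21k no longer requests a representation layer, so pre_logits should be an Identity.
b16 = timm.create_model('vit_base_patch16_224_in21k', pretrained=False)
print(isinstance(b16.pre_logits, nn.Identity))   # expected: True

# L/32 in21k keeps representation_size=1024, so pre_logits should expose an fc Linear.
l32 = timm.create_model('vit_large_patch32_224_in21k', pretrained=False)
print(getattr(l32.pre_logits, 'fc', None))       # expected: Linear(in_features=1024, out_features=1024, ...)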
