     - https://arxiv.org/abs/2010.11929

 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.TODO
+    - https://arxiv.org/abs/2106.10270

 The official jax code is released and available at https://github.com/google-research/vision_transformer

@@ -451,6 +451,9 @@ def _n2p(w, t=True):
     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
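The new guard above only copies pre-logits weights when both the built model and the `.npz` checkpoint actually contain a representation layer. A minimal, hedged sketch of that pattern (the `PreLogits` class and `maybe_load_pre_logits` helper are hypothetical names for illustration, not timm code; it assumes Flax Dense kernels are stored as `[in, out]` and need transposing into PyTorch's `[out, in]` `Linear.weight`):

```python
import numpy as np
import torch
import torch.nn as nn


class PreLogits(nn.Module):
    """Hypothetical stand-in for a ViT representation layer: Linear + Tanh."""
    def __init__(self, embed_dim, representation_size):
        super().__init__()
        self.fc = nn.Linear(embed_dim, representation_size)
        self.act = nn.Tanh()

    def forward(self, x):
        return self.act(self.fc(x))


@torch.no_grad()
def maybe_load_pre_logits(pre_logits, w, prefix=''):
    """Copy pre-logits weights only when both the module and the checkpoint have them."""
    fc = getattr(pre_logits, 'fc', None)
    if isinstance(fc, nn.Linear) and f'{prefix}pre_logits/bias' in w:
        # Flax Dense kernels are [in, out]; PyTorch Linear.weight is [out, in], so transpose.
        fc.weight.copy_(torch.from_numpy(w[f'{prefix}pre_logits/kernel']).T)
        fc.bias.copy_(torch.from_numpy(w[f'{prefix}pre_logits/bias']))
        return True
    return False


# A model built with a representation layer gets its weights copied; an Identity is skipped.
w = {'pre_logits/kernel': np.zeros((192, 256), dtype=np.float32),
     'pre_logits/bias': np.zeros((256,), dtype=np.float32)}
print(maybe_load_pre_logits(PreLogits(192, 256), w))   # True
print(maybe_load_pre_logits(nn.Identity(), w))         # False: no 'fc' attribute, copy skipped
```

Models built without `representation_size` use `nn.Identity()` for `pre_logits`, so the `getattr(..., 'fc', None)` check skips the copy instead of raising, which is why the added lines stay safe for the in21k variants below that drop the representation layer.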
@@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (ViT-Ti/16).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
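In practice the NOTE means the default build keeps the full ImageNet-21k classifier rather than a zeroed one. A hedged usage sketch (assuming timm's usual `create_model` entry point and its 21843-class default for the `_in21k` configs; the same applies to the other `_in21k` variants touched below):

```python
import timm

# Default build exposes the full 21k head (21843 classes in timm's _in21k default configs).
vit_21k = timm.create_model('vit_tiny_patch16_224_in21k', pretrained=False)
print(vit_21k.head.out_features)

# For fine-tuning, pass num_classes to get a freshly initialized task-specific head instead.
vit_ft = timm.create_model('vit_tiny_patch16_224_in21k', pretrained=False, num_classes=10)
print(vit_ft.head.out_features)  # 10
```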
@@ -683,6 +687,7 @@ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/32)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -693,6 +698,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model

@@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model

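Removing `representation_size` here changes the module layout, not just a hyperparameter: the pre-logits block collapses to an identity and the 21k head reads the transformer embedding directly. An illustrative sketch of that effect (`build_pre_logits` is a hypothetical helper mirroring the behaviour described in the NOTEs, not timm code):

```python
import torch.nn as nn

def build_pre_logits(embed_dim, representation_size=None):
    """Return (pre-logits module, input width seen by the classifier head)."""
    if representation_size:
        # With a representation size, a Linear + Tanh pre-logits block feeds the head.
        return nn.Sequential(nn.Linear(embed_dim, representation_size), nn.Tanh()), representation_size
    # Without it, features flow straight from the embedding to the head.
    return nn.Identity(), embed_dim

pre_b16, head_in_b16 = build_pre_logits(768)           # ViT-B/16 in21k after this change
pre_l32, head_in_l32 = build_pre_logits(1024, 1024)    # ViT-L/32 in21k keeps the layer
print(head_in_b16, head_in_l32)  # 768 1024
```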
@@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a representation (pre-logits) layer but the 21k classifier head is zeroed out in the original weights
     """
     model_kwargs = dict(
         patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
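For the variants that keep the representation layer (L/32 here, H/14 below), the NOTE warns that the released 21k classifier weights are zeroed, so the stock head should be replaced rather than fine-tuned from its zero init. A hedged sketch, assuming timm's `create_model` and `reset_classifier` interfaces:

```python
import timm

# Build with a task-specific head up front (pretrained=False here just to avoid a download).
vit_l32 = timm.create_model('vit_large_patch32_224_in21k', pretrained=False, num_classes=1000)

# Or swap the classifier on an existing model; reset_classifier re-creates the head layer.
vit_l32.reset_classifier(num_classes=10)
print(vit_l32.head.out_features)  # 10
```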
@@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model

@@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
-    NOTE: converted weights not currently available, too large for github release hosting.
+    NOTE: this model has a representation (pre-logits) layer but the 21k classifier head is zeroed out in the original weights
     """
     model_kwargs = dict(
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)