 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import load_pretrained
 from .layers import DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
 from .registry import register_model
@@ -48,7 +48,9 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = {
     # patch models
-    'vit_small_patch16_224': _cfg(),
+    'vit_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+    ),
     'vit_base_patch16_224': _cfg(),
     'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
     'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
@@ -271,6 +273,9 @@ def forward(self, x, attn_mask=None):
 def vit_small_patch16_224(pretrained=False, **kwargs):
     model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
     model.default_cfg = default_cfgs['vit_small_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
     return model
 
 
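For context, a minimal usage sketch of what this change enables (assuming the entrypoint is registered via `@register_model` so that `timm.create_model` can resolve it): with the checkpoint URL now set in `default_cfgs`, passing `pretrained=True` routes through `load_pretrained()` to fetch and load the weights.

```python
import torch
import timm

# pretrained=True now triggers load_pretrained(), which downloads the
# vit_small_p16_224 checkpoint from the URL added to default_cfgs above.
model = timm.create_model('vit_small_patch16_224', pretrained=True)
model.eval()

# A 224x224 input matches the patch16_224 config: (224 / 16)^2 = 196 patches.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) with the default ImageNet-1k head
```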