@@ -70,7 +70,7 @@ def __init__(
7070 self .dim = dim
7171
7272 # FIXME not working, bn layer outputs are incorrect
73- '''
73+
7474 self .conv_q = ConvNormAct (
7575 dim ,
7676 dim ,
@@ -143,7 +143,8 @@ def __init__(
143143 groups=dim
144144 )),
145145 ('bn', nn.BatchNorm2d(dim)),]))
146-
146+ '''
147+
147148 def forward (self , x : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
148149 B , C , H , W = x .shape
149150 # [B, C, H, W] -> [B, H*W, C]
@@ -170,7 +171,7 @@ def __init__(
170171 self .num_heads = num_heads
171172 self .head_dim = dim // num_heads
172173 self .scale = dim ** - 0.5
173- self .fused_attn = False # use_fused_attn()
174+ self .fused_attn = use_fused_attn ()
174175
175176 self .proj_q = nn .Linear (dim , dim , bias = qkv_bias )
176177 self .proj_k = nn .Linear (dim , dim , bias = qkv_bias )
@@ -534,11 +535,36 @@ def _cfg(url='', **kwargs):
534535 }
535536
# Pretrained configs keyed by '<model>.<tag>'. Every checkpoint URL points at
# the same GitHub release. Entries whose checkpoint file name contains 384x384
# override input_size and pool_size; the others use whatever _cfg supplies.
default_cfgs = generate_default_cfgs({
    # CvT-13
    'cvt_13.msft_in1k': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-224x224-IN-1k.pth',
    ),
    'cvt_13.msft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-1k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),
    'cvt_13.msft_in22k_ft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-22k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),

    # CvT-21
    'cvt_21.msft_in1k': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-224x224-IN-1k.pth',
    ),
    'cvt_21.msft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-1k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),
    'cvt_21.msft_in22k_ft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-22k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),

    # CvT-w24
    'cvt_w24.msft_in22k_ft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-w24-384x384-IN-22k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),
})
539555
540556
@register_model
def cvt_13(pretrained=False, **kwargs) -> CvT:
    """Create the ``cvt_13`` model variant.

    Args:
        pretrained: Passed through to ``_create_cvt`` as ``pretrained``.
        **kwargs: Additional model arguments. A key given here replaces the
            variant default of the same name.

    Returns:
        The result of ``_create_cvt`` for the ``'cvt_13'`` variant.
    """
    variant_defaults = {
        'depths': (1, 2, 10),
        'dims': (64, 192, 384),
        'num_heads': (1, 3, 6),
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    merged_args = {**variant_defaults, **kwargs}
    return _create_cvt('cvt_13', pretrained=pretrained, **merged_args)
561+
@register_model
def cvt_21(pretrained=False, **kwargs) -> CvT:
    """Create the ``cvt_21`` model variant.

    Args:
        pretrained: Passed through to ``_create_cvt`` as ``pretrained``.
        **kwargs: Additional model arguments. A key given here replaces the
            variant default of the same name.

    Returns:
        The result of ``_create_cvt`` for the ``'cvt_21'`` variant.
    """
    variant_defaults = {
        'depths': (1, 4, 16),
        'dims': (64, 192, 384),
        'num_heads': (1, 3, 6),
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    merged_args = {**variant_defaults, **kwargs}
    return _create_cvt('cvt_21', pretrained=pretrained, **merged_args)
566+
@register_model
def cvt_w24(pretrained=False, **kwargs) -> CvT:
    """Create the ``cvt_w24`` model variant.

    Args:
        pretrained: Passed through to ``_create_cvt`` as ``pretrained``.
        **kwargs: Additional model arguments. A key given here replaces the
            variant default of the same name.

    Returns:
        The result of ``_create_cvt`` for the ``'cvt_w24'`` variant.
    """
    variant_defaults = {
        'depths': (2, 2, 20),
        'dims': (192, 768, 1024),
        'num_heads': (3, 12, 16),
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    merged_args = {**variant_defaults, **kwargs}
    return _create_cvt('cvt_w24', pretrained=pretrained, **merged_args)
0 commit comments