@@ -1723,7 +1723,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 256, 256)),
     'vit_medium_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
-    'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
+    'vit_base_patch16_reg4_gap_256': _cfg(
+        input_size=(3, 256, 256)),
+    'vit_so150m_patch16_reg4_gap_256': _cfg(
+        input_size=(3, 256, 256)),
+    'vit_so150m_patch16_reg4_map_256': _cfg(
+        input_size=(3, 256, 256)),
 }
 
 _quick_gelu_cfgs = [
@@ -2623,13 +2628,35 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
 
 
 @register_model
-def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
-        no_embed_class=True, global_pool='avg', reg_tokens=8,
+        no_embed_class=True, global_pool='avg', reg_tokens=4,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
+        class_token=False, reg_tokens=4, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
+        class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
     )
     model = _create_vision_transformer(
-        'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
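For reference, a minimal sketch of how the entrypoints registered above would be exercised once this branch is installed. `timm.create_model` and the `@register_model` registry are existing timm APIs; the model name and `input_size=(3, 256, 256)` come from the diff itself, and the printed output shape assumes the library's default `num_classes=1000` since no pretrained weights or head overrides are implied here.

import timm
import torch

# Instantiate one of the newly registered SO150M variants through timm's
# factory. No pretrained weights accompany this diff, so pretrained=False.
model = timm.create_model('vit_so150m_patch16_reg4_gap_256', pretrained=False)
model.eval()

# The default cfg added above sets input_size=(3, 256, 256); with
# patch_size=16 that is a 16x16 grid of patch tokens plus the 4 register
# tokens configured via reg_tokens=4 (no class token, GAP pooling).
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000]) with the default classifier head

The `_map_` variant differs only in pooling: `global_pool='map'` attaches an attention-pooling head instead of averaging the patch tokens, while the `_gap_` variant averages them and additionally disables the pre-head norm with `fc_norm=False`.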