Commit 87fec3d

Update experimental vit model configs
1 parent 7d3c2dc commit 87fec3d

1 file changed: 31 additions, 4 deletions

timm/models/vision_transformer.py

```diff
@@ -1723,7 +1723,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 256, 256)),
     'vit_medium_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
-    'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
+    'vit_base_patch16_reg4_gap_256': _cfg(
+        input_size=(3, 256, 256)),
+    'vit_so150m_patch16_reg4_gap_256': _cfg(
+        input_size=(3, 256, 256)),
+    'vit_so150m_patch16_reg4_map_256': _cfg(
+        input_size=(3, 256, 256)),
 }
 
 _quick_gelu_cfgs = [
@@ -2623,13 +2628,35 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
 
 
 @register_model
-def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
-        no_embed_class=True, global_pool='avg', reg_tokens=8,
+        no_embed_class=True, global_pool='avg', reg_tokens=4,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
+        class_token=False, reg_tokens=4, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
+        class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
     )
     model = _create_vision_transformer(
-        'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
```
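The two `vit_so150m` variants share a shape-optimized trunk (`embed_dim=896`, `depth=18`, `num_heads=14`, following timm's SoViT-style naming); the odd-looking `mlp_ratio=2.572` presumably targets an MLP hidden dim of `int(896 * 2.572) = 2304`, since timm's transformer block computes the hidden width that way. These are experimental configs with no pretrained weights yet, so a smoke test against a checkout that includes this commit might look like the sketch below (parameter-count printing is illustrative, not from the source):

```python
import timm
import torch

# All three new registrations default to 256x256 inputs (input_size=(3, 256, 256)).
for name in (
    'vit_base_patch16_reg4_gap_256',
    'vit_so150m_patch16_reg4_gap_256',
    'vit_so150m_patch16_reg4_map_256',
):
    model = timm.create_model(name, pretrained=False)  # experimental: no weights yet
    n_params = sum(p.numel() for p in model.parameters())
    print(f'{name}: {n_params / 1e6:.1f}M params')

x = torch.randn(1, 3, 256, 256)
out = timm.create_model('vit_so150m_patch16_reg4_gap_256')(x)
print(out.shape)  # torch.Size([1, 1000]) with the default 1000-class head
```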
0 commit comments

Comments
 (0)
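The only difference between the two so150m registrations is the pooling head: `global_pool='avg'` (here with `fc_norm=False`) takes a plain mean over the patch tokens, while `global_pool='map'` uses a learned attention-pooling head. A toy illustration of that distinction, simplified from what timm actually does (the class below is a hypothetical sketch, not timm's implementation):

```python
import torch
import torch.nn as nn

class ToyMAPHead(nn.Module):
    """Toy attention pooling: one learned query attends over all tokens.
    Hypothetical sketch of the idea behind global_pool='map'; not timm's code."""
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.query = nn.Parameter(torch.zeros(1, 1, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        q = self.query.expand(tokens.shape[0], -1, -1)
        pooled, _ = self.attn(q, tokens, tokens)
        return pooled[:, 0]

# 256px / patch16 -> 16x16 = 256 patch tokens, plus reg_tokens=4 prefix tokens.
tokens = torch.randn(2, 4 + 256, 896)

gap = tokens[:, 4:].mean(dim=1)           # 'avg': mean over patch tokens only
map_pooled = ToyMAPHead(896, 14)(tokens)  # 'map': learned attention pooling
print(gap.shape, map_pooled.shape)        # both torch.Size([2, 896])
```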