4545
4646from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , OPENAI_CLIP_MEAN , OPENAI_CLIP_STD
4747from timm .layers import trunc_normal_ , AvgPool2dSame , DropPath , Mlp , GlobalResponseNormMlp , \
48- LayerNorm2d , LayerNorm , create_conv2d , get_act_layer , make_divisible , to_ntuple
48+ LayerNorm2d , LayerNorm , RmsNorm2d , RmsNorm , create_conv2d , get_act_layer , get_norm_layer , make_divisible , to_ntuple
4949from timm .layers import NormMlpClassifierHead , ClassifierHead
5050from ._builder import build_model_with_cfg
5151from ._features import feature_take_indices
@@ -289,24 +289,27 @@ def __init__(
289289 super ().__init__ ()
290290 assert output_stride in (8 , 16 , 32 )
291291 kernel_sizes = to_ntuple (4 )(kernel_sizes )
292- if norm_layer is None :
293- norm_layer = LayerNorm2d
294- norm_layer_cl = norm_layer if conv_mlp else LayerNorm
292+ use_rms = isinstance (norm_layer , str ) and norm_layer .startswith ('rmsnorm' )
293+ if norm_layer is None or use_rms :
294+ norm_layer = RmsNorm2d if use_rms else LayerNorm2d
295+ norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm )
295296 if norm_eps is not None :
296297 norm_layer = partial (norm_layer , eps = norm_eps )
297298 norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
298299 else :
299300 assert conv_mlp ,\
300301 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
302+ norm_layer = get_norm_layer (norm_layer )
301303 norm_layer_cl = norm_layer
302304 if norm_eps is not None :
303305 norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
306+ act_layer = get_act_layer (act_layer )
304307
305308 self .num_classes = num_classes
306309 self .drop_rate = drop_rate
307310 self .feature_info = []
308311
309- assert stem_type in ('patch' , 'overlap' , 'overlap_tiered' )
312+ assert stem_type in ('patch' , 'overlap' , 'overlap_tiered' , 'overlap_act' )
310313 if stem_type == 'patch' :
311314 # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
312315 self .stem = nn .Sequential (
@@ -316,11 +319,12 @@ def __init__(
316319 stem_stride = patch_size
317320 else :
318321 mid_chs = make_divisible (dims [0 ] // 2 ) if 'tiered' in stem_type else dims [0 ]
319- self .stem = nn .Sequential (
322+ self .stem = nn .Sequential (* filter ( None , [
320323 nn .Conv2d (in_chans , mid_chs , kernel_size = 3 , stride = 2 , padding = 1 , bias = conv_bias ),
324+ act_layer () if 'act' in stem_type else None ,
321325 nn .Conv2d (mid_chs , dims [0 ], kernel_size = 3 , stride = 2 , padding = 1 , bias = conv_bias ),
322326 norm_layer (dims [0 ]),
323- )
327+ ]) )
324328 stem_stride = 4
325329
326330 self .stages = nn .Sequential ()
@@ -592,6 +596,13 @@ def _cfgv2(url='', **kwargs):
592596 hf_hub_id = 'timm/' ,
593597 crop_pct = 0.95 , test_input_size = (3 , 288 , 288 ), test_crop_pct = 1.0 ),
594598
599+ 'convnext_zepto_rms.ra4_e3600_r224_in1k' : _cfg (
600+ hf_hub_id = 'timm/' ,
601+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 )),
602+ 'convnext_zepto_rms_ols.untrained' : _cfg (
603+ # hf_hub_id='timm/',
604+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ),
605+ test_input_size = (3 , 256 , 256 ), test_crop_pct = 0.95 ),
595606 'convnext_atto.d2_in1k' : _cfg (
596607 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth' ,
597608 hf_hub_id = 'timm/' ,
@@ -600,6 +611,9 @@ def _cfgv2(url='', **kwargs):
600611 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth' ,
601612 hf_hub_id = 'timm/' ,
602613 test_input_size = (3 , 288 , 288 ), test_crop_pct = 0.95 ),
614+ 'convnext_atto_rms.untrained' : _cfg (
615+ #hf_hub_id='timm/',
616+ test_input_size = (3 , 256 , 256 ), test_crop_pct = 0.95 ),
603617 'convnext_femto.d1_in1k' : _cfg (
604618 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth' ,
605619 hf_hub_id = 'timm/' ,
@@ -968,6 +982,23 @@ def _cfgv2(url='', **kwargs):
968982})
969983
970984
@register_model
def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Zepto w/ RMSNorm — smallest timm ConvNeXt variant, using RmsNorm2d layers."""
    # NOTE(review): depths/dims still being tuned upstream; current config ~2M params.
    default_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
    merged = {**default_args, **kwargs}
    return _create_convnext('convnext_zepto_rms', pretrained=pretrained, **merged)
991+
992+
@register_model
def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Zepto w/ RMSNorm and overlapping stem w/ activation ('overlap_act')."""
    # NOTE(review): depths/dims still being tuned upstream.
    default_args = dict(
        depths=(2, 2, 4, 2),
        dims=(32, 64, 128, 256),
        conv_mlp=True,
        norm_layer='rmsnorm2d',
        stem_type='overlap_act',
    )
    merged = {**default_args, **kwargs}
    return _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **merged)
1000+
1001+
9711002@register_model
9721003def convnext_atto (pretrained = False , ** kwargs ) -> ConvNeXt :
9731004 # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
@@ -984,6 +1015,14 @@ def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
9841015 return model
9851016
9861017
@register_model
def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Atto w/ RMSNorm — timm atto variant (~3-4M params) using RmsNorm2d layers."""
    default_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
    merged = {**default_args, **kwargs}
    return _create_convnext('convnext_atto_rms', pretrained=pretrained, **merged)
1024+
1025+
9871026@register_model
9881027def convnext_femto (pretrained = False , ** kwargs ) -> ConvNeXt :
9891028 # timm femto variant
0 commit comments