4646from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , OPENAI_CLIP_MEAN , OPENAI_CLIP_STD
4747from timm .layers import trunc_normal_ , AvgPool2dSame , DropPath , Mlp , GlobalResponseNormMlp , \
4848 LayerNorm2d , LayerNorm , RmsNorm2d , RmsNorm , create_conv2d , get_act_layer , get_norm_layer , make_divisible , to_ntuple
49+ from timm .layers import SimpleNorm2d , SimpleNorm
4950from timm .layers import NormMlpClassifierHead , ClassifierHead
5051from ._builder import build_model_with_cfg
5152from ._features import feature_take_indices
@@ -233,6 +234,34 @@ def forward(self, x):
233234 x = self .blocks (x )
234235 return x
235236
237+ # map of norm layers with NCHW (2D) and channels last variants
238+ _NORM_MAP = {
239+ 'layernorm' : (LayerNorm2d , LayerNorm ),
240+ 'layernorm2d' : (LayerNorm2d , LayerNorm ),
241+ 'simplenorm' : (SimpleNorm2d , SimpleNorm ),
242+ 'simplenorm2d' : (SimpleNorm2d , SimpleNorm ),
243+ 'rmsnorm' : (RmsNorm2d , RmsNorm ),
244+ 'rmsnorm2d' : (RmsNorm2d , RmsNorm ),
245+ }
246+
247+
248+ def _get_norm_layers (norm_layer : Union [Callable , str ], conv_mlp : bool , norm_eps : float ):
249+ norm_layer = norm_layer or 'layernorm'
250+ if norm_layer in _NORM_MAP :
251+ norm_layer_cl = _NORM_MAP [norm_layer ][0 ] if conv_mlp else _NORM_MAP [norm_layer ][1 ]
252+ norm_layer = _NORM_MAP [norm_layer ][0 ]
253+ if norm_eps is not None :
254+ norm_layer = partial (norm_layer , eps = norm_eps )
255+ norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
256+ else :
257+ assert conv_mlp , \
258+ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
259+ norm_layer = get_norm_layer (norm_layer )
260+ norm_layer_cl = norm_layer
261+ if norm_eps is not None :
262+ norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
263+ return norm_layer , norm_layer_cl
264+
236265
237266class ConvNeXt (nn .Module ):
238267 r""" ConvNeXt
@@ -289,20 +318,7 @@ def __init__(
289318 super ().__init__ ()
290319 assert output_stride in (8 , 16 , 32 )
291320 kernel_sizes = to_ntuple (4 )(kernel_sizes )
292- use_rms = isinstance (norm_layer , str ) and norm_layer .startswith ('rmsnorm' )
293- if norm_layer is None or use_rms :
294- norm_layer = RmsNorm2d if use_rms else LayerNorm2d
295- norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm )
296- if norm_eps is not None :
297- norm_layer = partial (norm_layer , eps = norm_eps )
298- norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
299- else :
300- assert conv_mlp ,\
301- 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
302- norm_layer = get_norm_layer (norm_layer )
303- norm_layer_cl = norm_layer
304- if norm_eps is not None :
305- norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
321+ norm_layer , norm_layer_cl = _get_norm_layers (norm_layer , conv_mlp , norm_eps )
306322 act_layer = get_act_layer (act_layer )
307323
308324 self .num_classes = num_classes
@@ -975,7 +991,7 @@ def _cfgv2(url='', **kwargs):
975991@register_model
976992def convnext_zepto_rms (pretrained = False , ** kwargs ) -> ConvNeXt :
977993 # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
978- model_args = dict (depths = (2 , 2 , 4 , 2 ), dims = (32 , 64 , 128 , 256 ), conv_mlp = True , norm_layer = 'rmsnorm2d ' )
994+ model_args = dict (depths = (2 , 2 , 4 , 2 ), dims = (32 , 64 , 128 , 256 ), conv_mlp = True , norm_layer = 'simplenorm ' )
979995 model = _create_convnext ('convnext_zepto_rms' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
980996 return model
981997
@@ -984,7 +1000,7 @@ def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
9841000def convnext_zepto_rms_ols (pretrained = False , ** kwargs ) -> ConvNeXt :
9851001 # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
9861002 model_args = dict (
987- depths = (2 , 2 , 4 , 2 ), dims = (32 , 64 , 128 , 256 ), conv_mlp = True , norm_layer = 'rmsnorm2d ' , stem_type = 'overlap_act' )
1003+ depths = (2 , 2 , 4 , 2 ), dims = (32 , 64 , 128 , 256 ), conv_mlp = True , norm_layer = 'simplenorm ' , stem_type = 'overlap_act' )
9881004 model = _create_convnext ('convnext_zepto_rms_ols' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
9891005 return model
9901006
0 commit comments