@@ -43,7 +43,7 @@ def _cfg(url='', **kwargs):
4343 'url' : url , 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : (7 , 7 ),
4444 'crop_pct' : 0.875 , 'interpolation' : 'bicubic' ,
4545 'mean' : IMAGENET_DEFAULT_MEAN , 'std' : IMAGENET_DEFAULT_STD ,
46- 'first_conv' : 'stem.conv ' , 'classifier' : 'head.fc' ,
46+ 'first_conv' : 'stem.conv1 ' , 'classifier' : 'head.fc' ,
4747 'fixed_input_size' : True ,
4848 ** kwargs
4949 }
@@ -106,7 +106,7 @@ def __init__(
106106 dim_out = None ,
107107 reduction = 'conv' ,
108108 act_layer = nn .GELU ,
109- norm_layer = LayerNorm2d ,
109+ norm_layer = LayerNorm2d , # NOTE in NCHW
110110 ):
111111 super ().__init__ ()
112112 dim_out = dim_out or dim
@@ -163,12 +163,10 @@ def __init__(
163163 self ,
164164 in_chs : int = 3 ,
165165 out_chs : int = 96 ,
166- act_layer : str = 'gelu' ,
167- norm_layer : str = 'layernorm2d' , # NOTE norm for NCHW
166+ act_layer : Callable = nn . GELU ,
167+ norm_layer : Callable = LayerNorm2d , # NOTE stem in NCHW
168168 ):
169169 super ().__init__ ()
170- act_layer = get_act_layer (act_layer )
171- norm_layer = get_norm_layer (norm_layer )
172170 self .conv1 = nn .Conv2d (in_chs , out_chs , kernel_size = 3 , stride = 2 , padding = 1 )
173171 self .down = Downsample2d (out_chs , act_layer = act_layer , norm_layer = norm_layer )
174172
@@ -333,15 +331,11 @@ def __init__(
333331 proj_drop : float = 0. ,
334332 attn_drop : float = 0. ,
335333 drop_path : Union [List [float ], float ] = 0.0 ,
336- act_layer : str = 'gelu' ,
337- norm_layer : str = 'layernorm2d' ,
338- norm_layer_cl : str = 'layernorm' ,
334+ act_layer : Callable = nn . GELU ,
335+ norm_layer : Callable = LayerNorm2d ,
336+ norm_layer_cl : Callable = nn . LayerNorm ,
339337 ):
340338 super ().__init__ ()
341- act_layer = get_act_layer (act_layer )
342- norm_layer = get_norm_layer (norm_layer )
343- norm_layer_cl = get_norm_layer (norm_layer_cl )
344-
345339 if downsample :
346340 self .downsample = Downsample2d (
347341 dim = dim ,
@@ -421,8 +415,13 @@ def __init__(
421415 act_layer : str = 'gelu' ,
422416 norm_layer : str = 'layernorm2d' ,
423417 norm_layer_cl : str = 'layernorm' ,
418+ norm_eps : float = 1e-5 ,
424419 ):
425420 super ().__init__ ()
421+ act_layer = get_act_layer (act_layer )
422+ norm_layer = partial (get_norm_layer (norm_layer ), eps = norm_eps )
423+ norm_layer_cl = partial (get_norm_layer (norm_layer_cl ), eps = norm_eps )
424+
426425 img_size = to_2tuple (img_size )
427426 feat_size = tuple (d // 4 for d in img_size ) # stem reduction by 4
428427 self .global_pool = global_pool
@@ -432,7 +431,11 @@ def __init__(
432431 self .num_features = int (embed_dim * 2 ** (num_stages - 1 ))
433432
434433 self .stem = Stem (
435- in_chs = in_chans , out_chs = embed_dim , act_layer = act_layer , norm_layer = norm_layer )
434+ in_chs = in_chans ,
435+ out_chs = embed_dim ,
436+ act_layer = act_layer ,
437+ norm_layer = norm_layer
438+ )
436439
437440 dpr = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
438441 stages = []