import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
- from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
+ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
    LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg

__all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this


+ class Downsample(nn.Module):
+ 
+     def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+         super().__init__()
+         avg_stride = stride if dilation == 1 else 1
+         if stride > 1 or dilation > 1:
+             avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+             self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+         else:
+             self.pool = nn.Identity()
+ 
+         if in_chs != out_chs:
+             self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+         else:
+             self.conv = nn.Identity()
+ 
+     def forward(self, x):
+         x = self.pool(x)
+         x = self.conv(x)
+         return x
+ 
+ 
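For readers skimming the diff: the new Downsample module is essentially "average-pool if the resolution changes, 1x1-conv if the channel count changes". A minimal standalone sketch of the common dilation == 1 case, using plain torch.nn in place of the timm helpers (make_downsample is a hypothetical name used for illustration only, not part of this commit):

import torch
import torch.nn as nn

def make_downsample(in_chs, out_chs, stride=1):
    # 2x2 average pooling handles any spatial reduction...
    pool = nn.AvgPool2d(2, stride, ceil_mode=True, count_include_pad=False) if stride > 1 else nn.Identity()
    # ...and a 1x1 convolution matches the channel count when it changes.
    proj = nn.Conv2d(in_chs, out_chs, 1) if in_chs != out_chs else nn.Identity()
    return nn.Sequential(pool, proj)

x = torch.randn(1, 96, 56, 56)
print(make_downsample(96, 192, stride=2)(x).shape)  # torch.Size([1, 192, 28, 28])
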
class ConvNeXtBlock(nn.Module):
    """ ConvNeXt Block
    There are two equivalent implementations:
@@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
    Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
    choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
    is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
- 
-     Args:
-         in_chs (int): Number of input channels.
-         drop_path (float): Stochastic depth rate. Default: 0.0
-         ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
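The "two equivalent implementations" in the docstring map onto the conv_mlp flag below: the norm/MLP can stay in NCHW using 1x1 convolutions, or permute to NHWC and use nn.Linear. A rough sketch of the equivalence (normalization omitted for brevity; shapes and names here are illustrative only, not taken from the commit):

import torch
import torch.nn as nn

x = torch.randn(2, 96, 14, 14)  # NCHW feature map, as produced by the depthwise conv

# (1) stay in NCHW: 1x1 convolutions act as the MLP
mlp_nchw = nn.Sequential(nn.Conv2d(96, 384, 1), nn.GELU(), nn.Conv2d(384, 96, 1))
y1 = mlp_nchw(x)

# (2) permute to NHWC, apply nn.Linear, permute back
mlp_nhwc = nn.Sequential(nn.Linear(96, 384), nn.GELU(), nn.Linear(384, 96))
y2 = mlp_nhwc(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(y1.shape == y2.shape)  # True; with matching weights the two paths compute the same function
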

    def __init__(
            self,
-             in_chs,
-             out_chs=None,
-             kernel_size=7,
-             stride=1,
-             dilation=1,
-             mlp_ratio=4,
-             conv_mlp=False,
-             conv_bias=True,
-             use_grn=False,
-             ls_init_value=1e-6,
-             act_layer='gelu',
-             norm_layer=None,
-             drop_path=0.,
+             in_chs: int,
+             out_chs: Optional[int] = None,
+             kernel_size: int = 7,
+             stride: int = 1,
+             dilation: Union[int, Tuple[int, int]] = (1, 1),
+             mlp_ratio: float = 4,
+             conv_mlp: bool = False,
+             conv_bias: bool = True,
+             use_grn: bool = False,
+             ls_init_value: Optional[float] = 1e-6,
+             act_layer: Union[str, Callable] = 'gelu',
+             norm_layer: Optional[Callable] = None,
+             drop_path: float = 0.,
    ):
+         """
+ 
+         Args:
+             in_chs: Block input channels.
+             out_chs: Block output channels (same as in_chs if None).
+             kernel_size: Depthwise convolution kernel size.
+             stride: Stride of depthwise convolution.
+             dilation: Tuple specifying input and output dilation of block.
+             mlp_ratio: MLP expansion ratio.
+             conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
+             conv_bias: Apply bias for all convolution (linear) layers.
+             use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
+             ls_init_value: Layer-scale init values, layer-scale applied if not None.
+             act_layer: Activation layer.
+             norm_layer: Normalization layer (defaults to LN if not specified).
+             drop_path: Stochastic depth probability.
+         """
        super().__init__()
        out_chs = out_chs or in_chs
+         dilation = to_ntuple(2)(dilation)
        act_layer = get_act_layer(act_layer)
        if not norm_layer:
            norm_layer = LayerNorm2d if conv_mlp else LayerNorm
        mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
        self.use_conv_mlp = conv_mlp
        self.conv_dw = create_conv2d(
-             in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+             in_chs,
+             out_chs,
+             kernel_size=kernel_size,
+             stride=stride,
+             dilation=dilation[0],
+             depthwise=True,
+             bias=conv_bias,
+         )
        self.norm = norm_layer(out_chs)
        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+             self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+         else:
+             self.shortcut = nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
@@ -116,7 +162,7 @@ def forward(self, x):
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))

-         x = self.drop_path(x) + shortcut
+         x = self.drop_path(x) + self.shortcut(shortcut)
        return x


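The one-line change in forward above is the point of this commit: the saved input now passes through self.shortcut, so a block that changes channels, stride, or dilation can still add a residual. A quick sanity check, assuming a timm build that includes this change and exposes ConvNeXtBlock with the signature shown above:

import torch
from timm.models.convnext import ConvNeXtBlock

# Doubles the channels and halves the resolution; a plain `+ shortcut` add would fail
# on the shape mismatch, but the Downsample shortcut lines the tensors up.
block = ConvNeXtBlock(in_chs=64, out_chs=128, stride=2)
out = block(torch.randn(1, 64, 32, 32))
print(out.shape)  # expected: torch.Size([1, 128, 16, 16])
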
@@ -148,8 +194,14 @@ def __init__(
            self.downsample = nn.Sequential(
                norm_layer(in_chs),
                create_conv2d(
-                     in_chs, out_chs, kernel_size=ds_ks, stride=stride,
-                     dilation=dilation[0], padding=pad, bias=conv_bias),
+                     in_chs,
+                     out_chs,
+                     kernel_size=ds_ks,
+                     stride=stride,
+                     dilation=dilation[0],
+                     padding=pad,
+                     bias=conv_bias,
+                 ),
            )
            in_chs = out_chs
        else: