1212import torch
1313import torch .nn as nn
1414
15- from timm .layers import SpaceToDepth , BlurPool2d , ClassifierHead , SEModule ,\
16- ConvNormActAa , ConvNormAct , DropPath
15+ from timm .layers import SpaceToDepth , BlurPool2d , ClassifierHead , SEModule , ConvNormAct , DropPath
1716from ._builder import build_model_with_cfg
1817from ._manipulate import checkpoint_seq
1918from ._registry import register_model , generate_default_cfgs , register_model_deprecations
@@ -39,13 +38,8 @@ def __init__(
3938 self .stride = stride
4039 act_layer = partial (nn .LeakyReLU , negative_slope = 1e-3 )
4140
42- if stride == 1 :
43- self .conv1 = ConvNormAct (inplanes , planes , kernel_size = 3 , stride = 1 , act_layer = act_layer )
44- else :
45- self .conv1 = ConvNormActAa (
46- inplanes , planes , kernel_size = 3 , stride = 2 , act_layer = act_layer , aa_layer = aa_layer )
47-
48- self .conv2 = ConvNormAct (planes , planes , kernel_size = 3 , stride = 1 , apply_act = False , act_layer = None )
41+ self .conv1 = ConvNormAct (inplanes , planes , kernel_size = 3 , stride = stride , act_layer = act_layer , aa_layer = aa_layer )
42+ self .conv2 = ConvNormAct (planes , planes , kernel_size = 3 , stride = 1 , apply_act = False )
4943 self .act = nn .ReLU (inplace = True )
5044
5145 rd_chs = max (planes * self .expansion // 4 , 64 )
@@ -87,18 +81,14 @@ def __init__(
8781
8882 self .conv1 = ConvNormAct (
8983 inplanes , planes , kernel_size = 1 , stride = 1 , act_layer = act_layer )
90- if stride == 1 :
91- self .conv2 = ConvNormAct (
92- planes , planes , kernel_size = 3 , stride = 1 , act_layer = act_layer )
93- else :
94- self .conv2 = ConvNormActAa (
95- planes , planes , kernel_size = 3 , stride = 2 , act_layer = act_layer , aa_layer = aa_layer )
84+ self .conv2 = ConvNormAct (
85+ planes , planes , kernel_size = 3 , stride = stride , act_layer = act_layer , aa_layer = aa_layer )
9686
9787 reduction_chs = max (planes * self .expansion // 8 , 64 )
9888 self .se = SEModule (planes , rd_channels = reduction_chs ) if use_se else None
9989
10090 self .conv3 = ConvNormAct (
101- planes , planes * self .expansion , kernel_size = 1 , stride = 1 , apply_act = False , act_layer = None )
91+ planes , planes * self .expansion , kernel_size = 1 , stride = 1 , apply_act = False )
10292
10393 self .drop_path = DropPath (drop_path_rate ) if drop_path_rate > 0 else nn .Identity ()
10494 self .act = nn .ReLU (inplace = True )
@@ -204,7 +194,7 @@ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=Non
204194 # avg pooling before 1x1 conv
205195 layers .append (nn .AvgPool2d (kernel_size = 2 , stride = 2 , ceil_mode = True , count_include_pad = False ))
206196 layers += [ConvNormAct (
207- self .inplanes , planes * block .expansion , kernel_size = 1 , stride = 1 , apply_act = False , act_layer = None )]
197+ self .inplanes , planes * block .expansion , kernel_size = 1 , stride = 1 , apply_act = False )]
208198 downsample = nn .Sequential (* layers )
209199
210200 layers = []
0 commit comments