@@ -56,51 +56,66 @@ def __init__(
5656 groups = int (mid_channels / div_groups )
5757 if expansion == 1 :
5858 layers = [
59- ("conv_0" , conv_layer (
60- in_channels ,
61- mid_channels ,
62- 3 ,
63- stride = stride , # type: ignore
64- act_fn = act_fn ,
65- bn_1st = bn_1st ,
66- groups = in_channels if dw else groups ,
67- ),),
68- ("conv_1" , conv_layer (
69- mid_channels ,
70- out_channels ,
71- 3 ,
72- zero_bn = zero_bn ,
73- act_fn = False ,
74- bn_1st = bn_1st ,
75- groups = mid_channels if dw else groups ,
76- ),),
59+ (
60+ "conv_0" ,
61+ conv_layer (
62+ in_channels ,
63+ mid_channels ,
64+ 3 ,
65+ stride = stride , # type: ignore
66+ act_fn = act_fn ,
67+ bn_1st = bn_1st ,
68+ groups = in_channels if dw else groups ,
69+ ),
70+ ),
71+ (
72+ "conv_1" ,
73+ conv_layer (
74+ mid_channels ,
75+ out_channels ,
76+ 3 ,
77+ zero_bn = zero_bn ,
78+ act_fn = False ,
79+ bn_1st = bn_1st ,
80+ groups = mid_channels if dw else groups ,
81+ ),
82+ ),
7783 ]
7884 else :
7985 layers = [
80- ("conv_0" , conv_layer (
81- in_channels ,
82- mid_channels ,
83- 1 ,
84- act_fn = act_fn ,
85- bn_1st = bn_1st ,
86- ),),
87- ("conv_1" , conv_layer (
88- mid_channels ,
89- mid_channels ,
90- 3 ,
91- stride = stride ,
92- act_fn = act_fn ,
93- bn_1st = bn_1st ,
94- groups = mid_channels if dw else groups ,
95- ),),
96- ("conv_2" , conv_layer (
97- mid_channels ,
98- out_channels ,
99- 1 ,
100- zero_bn = zero_bn ,
101- act_fn = False ,
102- bn_1st = bn_1st ,
103- ),), # noqa E501
86+ (
87+ "conv_0" ,
88+ conv_layer (
89+ in_channels ,
90+ mid_channels ,
91+ 1 ,
92+ act_fn = act_fn ,
93+ bn_1st = bn_1st ,
94+ ),
95+ ),
96+ (
97+ "conv_1" ,
98+ conv_layer (
99+ mid_channels ,
100+ mid_channels ,
101+ 3 ,
102+ stride = stride ,
103+ act_fn = act_fn ,
104+ bn_1st = bn_1st ,
105+ groups = mid_channels if dw else groups ,
106+ ),
107+ ),
108+ (
109+ "conv_2" ,
110+ conv_layer (
111+ mid_channels ,
112+ out_channels ,
113+ 1 ,
114+ zero_bn = zero_bn ,
115+ act_fn = False ,
116+ bn_1st = bn_1st ,
117+ ),
118+ ), # noqa E501
104119 ]
105120 if se :
106121 layers .append (("se" , se (out_channels )))
@@ -109,16 +124,23 @@ def __init__(
109124 self .convs = nn .Sequential (OrderedDict (layers ))
110125 if stride != 1 or in_channels != out_channels :
111126 id_layers = []
112- if stride != 1 and pool is not None : # if pool - reduce by pool else stride 2 art id_conv
127+ if (
128+ stride != 1 and pool is not None
129+ ): # if pool - reduce by pool else stride 2 art id_conv
113130 id_layers .append (("pool" , pool ()))
114131 if in_channels != out_channels or (stride != 1 and pool is None ):
115- id_layers += [("id_conv" , conv_layer (
116- in_channels ,
117- out_channels ,
118- 1 ,
119- stride = 1 if pool else stride ,
120- act_fn = False ,
121- ),)]
132+ id_layers += [
133+ (
134+ "id_conv" ,
135+ conv_layer (
136+ in_channels ,
137+ out_channels ,
138+ 1 ,
139+ stride = 1 if pool else stride ,
140+ act_fn = False ,
141+ ),
142+ )
143+ ]
122144 self .id_conv = nn .Sequential (OrderedDict (id_layers ))
123145 else :
124146 self .id_conv = None
@@ -132,16 +154,17 @@ def forward(self, x):
132154def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
133155 len_stem = len (cfg .stem_sizes )
134156 stem : List [tuple [str , nn .Module ]] = [
135- (f"conv_{ i } " , cfg .conv_layer (
136- cfg .stem_sizes [i - 1 ] if i else cfg .in_chans , # type: ignore
137- cfg .stem_sizes [i ],
138- stride = 2 if i == cfg .stem_stride_on else 1 ,
139- bn_layer = (not cfg .stem_bn_end )
140- if i == (len_stem - 1 )
141- else True ,
142- act_fn = cfg .act_fn ,
143- bn_1st = cfg .bn_1st ,
144- ),)
157+ (
158+ f"conv_{ i } " ,
159+ cfg .conv_layer (
160+ cfg .stem_sizes [i - 1 ] if i else cfg .in_chans , # type: ignore
161+ cfg .stem_sizes [i ],
162+ stride = 2 if i == cfg .stem_stride_on else 1 ,
163+ bn_layer = (not cfg .stem_bn_end ) if i == (len_stem - 1 ) else True ,
164+ act_fn = cfg .act_fn ,
165+ bn_1st = cfg .bn_1st ,
166+ ),
167+ )
145168 for i in range (len_stem )
146169 ]
147170 if cfg .stem_pool :
@@ -164,7 +187,9 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
164187 f"bl_{ block_num } " ,
165188 cfg .block (
166189 cfg .expansion , # type: ignore
167- block_chs [layer_num ] if block_num == 0 else block_chs [layer_num + 1 ],
190+ block_chs [layer_num ]
191+ if block_num == 0
192+ else block_chs [layer_num + 1 ],
168193 block_chs [layer_num + 1 ],
169194 stride if block_num == 0 else 1 ,
170195 sa = cfg .sa
@@ -191,10 +216,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
191216 return nn .Sequential (
192217 OrderedDict (
193218 [
194- (
195- f"l_{ layer_num } " ,
196- cfg .make_layer (cfg , layer_num ) # type: ignore
197- )
219+ (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
198220 for layer_num in range (len (cfg .layers ))
199221 ]
200222 )
@@ -222,7 +244,9 @@ class ModelCfg(BaseModel):
222244 layers : List [int ] = [2 , 2 , 2 , 2 ]
223245 norm : Type [nn .Module ] = nn .BatchNorm2d
224246 act_fn : Type [nn .Module ] = nn .ReLU
225- pool : Callable [[Any ], nn .Module ] = partial (nn .AvgPool2d , kernel_size = 2 , ceil_mode = True )
247+ pool : Callable [[Any ], nn .Module ] = partial (
248+ nn .AvgPool2d , kernel_size = 2 , ceil_mode = True
249+ )
226250 expansion : int = 1
227251 groups : int = 1
228252 dw : bool = False
@@ -235,7 +259,9 @@ class ModelCfg(BaseModel):
235259 zero_bn : bool = True
236260 stem_stride_on : int = 0
237261 stem_sizes : List [int ] = [32 , 32 , 64 ]
238- stem_pool : Union [Callable [[], nn .Module ], None ] = partial (nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
262+ stem_pool : Union [Callable [[], nn .Module ], None ] = partial (
263+ nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
264+ )
239265 stem_bn_end : bool = False
240266 init_cnn : Callable [[nn .Module ], None ] = init_cnn
241267 make_stem : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_stem # type: ignore
@@ -301,7 +327,7 @@ def from_cfg(cls, cfg: ModelCfg):
301327
302328 def __call__ (self ):
303329 model_name = self .name or self .__class__ .__name__
304- named_sequential = type (model_name , (nn .Sequential , ), {})
330+ named_sequential = type (model_name , (nn .Sequential ,), {})
305331 model = named_sequential (
306332 OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
307333 )
@@ -314,7 +340,8 @@ def __call__(self):
314340 def _get_extra_repr (self ) -> str :
315341 return " " .join (
316342 f"{ field } : { self ._get_str_value (field )} ,"
317- for field in self .__fields_set__ if field != "name"
343+ for field in self .__fields_set__
344+ if field != "name"
318345 )[:- 1 ]
319346
320347 def __repr__ (self ):
0 commit comments