1- from collections import OrderedDict
21from typing import Callable , Union
32
43from torch import nn
54
5+ from .helpers import nn_seq
66from .layers import ConvBnAct , get_act
77from .model_constructor import ModelCfg , ModelConstructor
88
99__all__ = [
1010 "XResBlock" ,
11- "ModelConstructor" ,
1211 "XResNet34" ,
1312 "XResNet50" ,
13+ "YaResNet" ,
14+ "YaResNet34" ,
15+ "YaResNet50" ,
1416]
1517
1618
17- # TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")
18-
19-
2019class XResBlock (nn .Module ):
2120 """Universal XResnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
2221
@@ -109,7 +108,7 @@ def __init__(
109108 layers .append (("se" , se (out_channels )))
110109 if sa :
111110 layers .append (("sa" , sa (out_channels )))
112- self .convs = nn . Sequential ( OrderedDict ( layers ) )
111+ self .convs = nn_seq ( layers )
113112 if stride != 1 or in_channels != out_channels :
114113 id_layers = []
115114 if (
@@ -129,7 +128,7 @@ def __init__(
129128 ),
130129 )
131130 ]
132- self .id_conv = nn . Sequential ( OrderedDict ( id_layers ) )
131+ self .id_conv = nn_seq ( id_layers )
133132 else :
134133 self .id_conv = None
135134 self .act_fn = get_act (act_fn )
@@ -240,7 +239,7 @@ def __init__(
240239 layers .append (("se" , se (out_channels ))) # type: ignore
241240 if sa :
242241 layers .append (("sa" , sa (out_channels ))) # type: ignore
243- self .convs = nn . Sequential ( OrderedDict ( layers ) )
242+ self .convs = nn_seq ( layers )
244243 if in_channels != out_channels :
245244 self .id_conv = conv_layer (
246245 in_channels ,
@@ -281,7 +280,7 @@ def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
281280 stem .append (("stem_pool" , cfg .stem_pool ()))
282281 if cfg .stem_bn_end :
283282 stem .append (("norm" , cfg .norm (cfg .stem_sizes [- 1 ]))) # type: ignore
284- return nn . Sequential ( OrderedDict ( stem ) )
283+ return nn_seq ( stem )
285284
286285
287286def make_layer (cfg : ModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
@@ -290,47 +289,39 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
290289 stride = 1 if cfg .stem_pool and layer_num == 0 else 2
291290 num_blocks = cfg .layers [layer_num ]
292291 block_chs = [cfg .stem_sizes [- 1 ] // cfg .expansion ] + cfg .block_sizes
293- return nn .Sequential (
294- OrderedDict (
295- [
296- (
297- f"bl_{ block_num } " ,
298- cfg .block (
299- cfg .expansion , # type: ignore
300- block_chs [layer_num ]
301- if block_num == 0
302- else block_chs [layer_num + 1 ],
303- block_chs [layer_num + 1 ],
304- stride if block_num == 0 else 1 ,
305- sa = cfg .sa
306- if (block_num == num_blocks - 1 ) and layer_num == 0
307- else None ,
308- conv_layer = cfg .conv_layer ,
309- act_fn = cfg .act_fn ,
310- pool = cfg .pool ,
311- zero_bn = cfg .zero_bn ,
312- bn_1st = cfg .bn_1st ,
313- groups = cfg .groups ,
314- div_groups = cfg .div_groups ,
315- dw = cfg .dw ,
316- se = cfg .se ,
317- ),
318- )
319- for block_num in range (num_blocks )
320- ]
292+ return nn_seq (
293+ (
294+ f"bl_{ block_num } " ,
295+ cfg .block (
296+ cfg .expansion , # type: ignore
297+ block_chs [layer_num ]
298+ if block_num == 0
299+ else block_chs [layer_num + 1 ],
300+ block_chs [layer_num + 1 ],
301+ stride if block_num == 0 else 1 ,
302+ sa = cfg .sa
303+ if (block_num == num_blocks - 1 ) and layer_num == 0
304+ else None ,
305+ conv_layer = cfg .conv_layer ,
306+ act_fn = cfg .act_fn ,
307+ pool = cfg .pool ,
308+ zero_bn = cfg .zero_bn ,
309+ bn_1st = cfg .bn_1st ,
310+ groups = cfg .groups ,
311+ div_groups = cfg .div_groups ,
312+ dw = cfg .dw ,
313+ se = cfg .se ,
314+ ),
321315 )
316+ for block_num in range (num_blocks )
322317 )
323318
324319
325320def make_body (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
326321 """Create model body."""
327- return nn .Sequential (
328- OrderedDict (
329- [
330- (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
331- for layer_num in range (len (cfg .layers ))
332- ]
333- )
322+ return nn_seq (
323+ (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
324+ for layer_num in range (len (cfg .layers ))
334325 )
335326
336327
0 commit comments