@@ -56,7 +56,7 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
5656 self .convs = nn .Sequential (OrderedDict (layers ))
5757 if stride != 1 or in_channels != out_channels :
5858 id_layers = []
59- if stride != 1 and pool is not None :
59+ if stride != 1 and pool is not None : # if pool - reduce by pool else stride 2 art id_conv
6060 id_layers .append (("pool" , pool ))
6161 if in_channels != out_channels or (stride != 1 and pool is None ):
6262 id_layers += [("id_conv" , conv_layer (
@@ -86,25 +86,33 @@ def _make_stem(self):
8686 return nn .Sequential (OrderedDict (stem ))
8787
8888
89- def _make_layer (self , expansion , in_channels , out_channels , blocks , stride , sa ):
90- layers = [(f"bl_{ i } " , self .block (expansion , in_channels if i == 0 else out_channels , out_channels ,
91- stride if i == 0 else 1 , sa = sa if i == blocks - 1 else None ,
92- conv_layer = self .conv_layer , act_fn = self .act_fn , pool = self .pool ,
93- zero_bn = self .zero_bn , bn_1st = self .bn_1st ,
94- groups = self .groups , div_groups = self .div_groups ,
95- dw = self .dw , se = self .se ))
96- for i in range (blocks )]
97- return nn .Sequential (OrderedDict (layers ))
89+ def _make_layer (self , layer_id : int ) -> nn .Module :
90+ # expansion, in_channels, out_channels, blocks, stride, sa):
91+ stride = 1 if self .stem_pool and layer_id == 0 else 2 # if no pool on stem - stride = 2 for first layer block in body
92+ num_blocks = self .layers [layer_id ]
93+ return nn .Sequential (OrderedDict ([
94+ (f"bl_{ block_num } " , self .block (
95+ self .expansion ,
96+ self .block_sizes [layer_id ] if block_num == 0 else self .block_sizes [layer_id + 1 ],
97+ self .block_sizes [layer_id + 1 ],
98+ stride if block_num == 0 else 1 ,
99+ sa = self .sa if block_num == num_blocks - 1 else None ,
100+ conv_layer = self .conv_layer ,
101+ act_fn = self .act_fn ,
102+ pool = self .pool ,
103+ zero_bn = self .zero_bn , bn_1st = self .bn_1st ,
104+ groups = self .groups , div_groups = self .div_groups ,
105+ dw = self .dw , se = self .se
106+ ))
107+ for block_num in range (num_blocks )
108+ ]))
98109
99110
100111def _make_body (self ):
101- stride = 1 if self .stem_pool else 1 # if no pool on stem - stride = 2 for first block in body
102- blocks = [(f"l_{ i } " , self ._make_layer (self , self .expansion ,
103- in_channels = self .block_sizes [i ], out_channels = self .block_sizes [i + 1 ],
104- blocks = l , stride = stride if i == 0 else 2 ,
105- sa = self .sa if i == 0 else None ))
106- for i , l in enumerate (self .layers )]
107- return nn .Sequential (OrderedDict (blocks ))
112+ return nn .Sequential (OrderedDict ([
113+ (f"l_{ layer_num } " , self ._make_layer (self , layer_num ))
114+ for layer_num in range (len (self .layers ))
115+ ]))
108116
109117
110118def _make_head (self ):
@@ -140,7 +148,7 @@ def __init__(self, name='MC', in_chans=3, num_classes=1000,
140148 ):
141149 super ().__init__ ()
142150 # se can be bool, int (0, 1) or nn.Module
143- # se_module - deprecated. Leaved for worning and checks.
151+ # se_module - deprecated. Leaved for warning and checks.
144152 # if stem_pool is False - no pool at stem
145153
146154 params = locals ()
0 commit comments