@@ -25,35 +25,35 @@ def init_cnn(module: nn.Module):
 class ResBlock(nn.Module):
     '''Resnet block'''

-    def __init__(self, expansion, ni, nh, stride=1,
+    def __init__(self, expansion, in_channels, mid_channels, stride=1,
                  conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
                  pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False,
                  groups=1, dw=False, div_groups=None,
                  se_module=SEModule, se=False, se_reduction=16
                  ):
         super().__init__()
-        nf, ni = nh * expansion, ni * expansion
+        out_channels, in_channels = mid_channels * expansion, in_channels * expansion
         if div_groups is not None:  # if div_groups is set, derive groups from mid_channels
-            groups = int(nh / div_groups)
+            groups = int(mid_channels / div_groups)
         if expansion == 1:
-            layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride,
-                       act_fn=act_fn, bn_1st=bn_1st, groups=ni if dw else groups)),
-                      ("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn,
-                       act=False, bn_1st=bn_1st, groups=nh if dw else groups))
+            layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride,
+                       act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
+                      ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
+                       act=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
                       ]
         else:
-            layers = [("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
-                      ("conv_1", conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
-                       groups=nh if dw else groups)),
-                      ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
+            layers = [("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
+                      ("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
+                       groups=mid_channels if dw else groups)),
+                      ("conv_2", conv_layer(mid_channels, out_channels, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
                       ]
         if se:
-            layers.append(('se', se_module(nf, se_reduction)))
+            layers.append(('se', se_module(out_channels, se_reduction)))
         if sa:
-            layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
+            layers.append(('sa', SimpleSelfAttention(out_channels, ks=1, sym=sym)))
         self.convs = nn.Sequential(OrderedDict(layers))
         self.pool = noop if stride == 1 else pool
-        self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)
+        self.idconv = noop if in_channels == out_channels else conv_layer(in_channels, out_channels, 1, act=False)
         self.act_fn = act_fn

     def forward(self, x):
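The hunk ends at the `forward` signature, whose body lies outside this diff. As orientation only, a minimal sketch of the residual combination the attributes above set up, assuming the standard pattern of summing the conv path with the pooled, projected identity:

    def forward(self, x):
        # Sketch, not part of this commit: conv path plus identity path.
        # The identity is average-pooled when stride != 1 (self.pool) and
        # projected by a 1x1 conv when channel counts differ (self.idconv).
        return self.act_fn(self.convs(x) + self.idconv(self.pool(x)))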
@@ -73,8 +73,8 @@ def _make_stem(self):
     return nn.Sequential(OrderedDict(stem))


-def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
-    layers = [(f"bl_{i}", self.block(expansion, ni if i == 0 else nf, nf,
+def _make_layer(self, expansion, in_channels, out_channels, blocks, stride, sa):
+    layers = [(f"bl_{i}", self.block(expansion, in_channels if i == 0 else out_channels, out_channels,
                           stride if i == 0 else 1, sa=sa if i == blocks - 1 else False,
                           conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
                           zero_bn=self.zero_bn, bn_1st=self.bn_1st,
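Read concretely, the comprehension above gives only block 0 the channel change and the stride, and reserves self-attention for the last block: with blocks=3, stride=2, sa=True it builds bl_0 (in_channels -> out_channels, stride 2), bl_1 (out_channels -> out_channels, stride 1), and bl_2 (out_channels -> out_channels, stride 1, sa=True).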
@@ -87,7 +87,7 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
 def _make_body(self):
     stride = 2 if self.stem_pool is None else 1  # if no pool on stem - stride = 2 for first block in body
     blocks = [(f"l_{i}", self._make_layer(self, self.expansion,
-                                          ni=self.block_sizes[i], nf=self.block_sizes[i + 1],
+                                          in_channels=self.block_sizes[i], out_channels=self.block_sizes[i + 1],
                                           blocks=l, stride=stride if i == 0 else 2,
                                           sa=self.sa if i == 0 else False))
               for i, l in enumerate(self.layers)]
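The rename across these hunks is mechanical: `ni`/`nh`/`nf` become `in_channels`/`mid_channels`/`out_channels` everywhere, with no behavioral change. A minimal usage sketch of the renamed ResBlock; the import path is a placeholder for wherever the class lives in this repo, and the shapes follow from the `* expansion` scaling in `__init__`:

    import torch
    from model_constructor import ResBlock  # placeholder import path, an assumption

    # expansion=1: tensor channels equal the arguments as given.
    block = ResBlock(1, in_channels=64, mid_channels=64)
    x = torch.randn(2, 64, 32, 32)
    print(block(x).shape)   # torch.Size([2, 64, 32, 32])

    # expansion=4 (bottleneck): actual tensor channels are scaled by expansion,
    # i.e. 64 * 4 = 256 in and 128 * 4 = 512 out; stride=2 halves spatial size.
    bottleneck = ResBlock(4, in_channels=64, mid_channels=128, stride=2)
    y = torch.randn(2, 256, 32, 32)
    print(bottleneck(y).shape)  # torch.Size([2, 512, 16, 16])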