@@ -1,6 +1,6 @@
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Sequence, Union
+from typing import Callable, List, Type, Union
 
 import torch.nn as nn
 
@@ -16,7 +16,7 @@
 def init_cnn(module: nn.Module):
     "Init module - kaiming_normal for Conv2d and 0 for biases."
     if getattr(module, 'bias', None) is not None:
-        nn.init.constant_(module.bias, 0)
+        nn.init.constant_(module.bias, 0)  # type: ignore
     if isinstance(module, (nn.Conv2d, nn.Linear)):
         nn.init.kaiming_normal_(module.weight)
     for layer in module.children():
@@ -32,7 +32,7 @@ def __init__(
         in_channels: int,
         mid_channels: int,
         stride: int = 1,
-        conv_layer: Union[nn.Module, nn.Sequential] = ConvBnAct,
+        conv_layer=ConvBnAct,
         act_fn: nn.Module = act_fn,
         zero_bn: bool = True,
         bn_1st: bool = True,
@@ -49,7 +49,7 @@ def __init__(
         if div_groups is not None:  # check if groups != 1 and div_groups
             groups = int(mid_channels / div_groups)
         if expansion == 1:
-            layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride,
+            layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride,  # type: ignore
                                             act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
                       ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
                                             act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
@@ -99,7 +99,8 @@ def _make_stem(self):
 
 def _make_layer(self, layer_num: int) -> nn.Module:
     # expansion, in_channels, out_channels, blocks, stride, sa):
-    stride = 1 if self.stem_pool and layer_num == 0 else 2  # if no pool on stem - stride = 2 for first layer block in body
+    # if no pool on stem - stride = 2 for first layer block in body
+    stride = 1 if self.stem_pool and layer_num == 0 else 2
     num_blocks = self.layers[layer_num]
     return nn.Sequential(OrderedDict([
         (f"bl_{block_num}", self.block(
@@ -144,22 +145,22 @@ def __init__(
         conv_layer=ConvBnAct,
         block_sizes: List[int] = [64, 128, 256, 512],
         layers: List[int] = [2, 2, 2, 2],
-        norm: nn.Module = nn.BatchNorm2d,
+        norm: Type[nn.Module] = nn.BatchNorm2d,
         act_fn: nn.Module = nn.ReLU(inplace=True),
         pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True),
         expansion: int = 1,
         groups: int = 1,
         dw: bool = False,
-        div_groups: Union[int, None]=None,
-        sa: Union[bool, int, Callable] = False,
-        se: Union[bool, int, Callable] = False,
+        div_groups: Union[int, None] = None,
+        sa: Union[bool, int, Type[nn.Module]] = False,
+        se: Union[bool, int, Type[nn.Module]] = False,
         se_module=None,
         se_reduction=None,
         bn_1st: bool = True,
         zero_bn: bool = True,
         stem_stride_on: int = 0,
         stem_sizes: List[int] = [32, 32, 64],
-        stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+        stem_pool: Union[Type[nn.Module], None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # type: ignore
         stem_bn_end: bool = False,
         _init_cnn: Callable = init_cnn,
         _make_stem: Callable = _make_stem,
@@ -172,24 +173,54 @@ def __init__(
         # se_module - deprecated. Left for warning and checks.
         # if stem_pool is False - no pool at stem
 
-        params = locals()
-        del params['self']
-        self.__dict__ = params
-
-        self._block_sizes = params['block_sizes']
+        self.name = name
+        self.in_chans = in_chans
+        self.num_classes = num_classes
+        self.block = block
+        self.conv_layer = conv_layer
+        self._block_sizes = block_sizes
+        self.layers = layers
+        self.norm = norm
+        self.act_fn = act_fn
+        self.pool = pool
+        self.expansion = expansion
+        self.groups = groups
+        self.dw = dw
+        self.div_groups = div_groups
+        # se_module
+        # se_reduction
+        self.bn_1st = bn_1st
+        self.zero_bn = zero_bn
+        self.stem_stride_on = stem_stride_on
+        self.stem_pool = stem_pool
+        self.stem_bn_end = stem_bn_end
+        self._init_cnn = _init_cnn
+        self._make_stem = _make_stem
+        self._make_layer = _make_layer
+        self._make_body = _make_body
+        self._make_head = _make_head
+
+        # params = locals()
+        # del params['self']
+        # self.__dict__ = params
+
+        # self._block_sizes = params['block_sizes']
+        self.stem_sizes = stem_sizes
         if self.stem_sizes[0] != self.in_chans:
             self.stem_sizes = [self.in_chans] + self.stem_sizes
+        self.se = se
         if self.se:
             if type(self.se) in (bool, int):  # if se=1 or se=True
                 self.se = SEModule
             else:
                 self.se = se  # TODO add check issubclass or isinstance of nn.Module
+        self.sa = sa
         if self.sa:  # if sa=1 or sa=True
             if type(self.sa) in (bool, int):
                 self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
             else:
                 self.sa = sa
-        if self.se_module or se_reduction:  # pragma: no cover
+        if se_module or se_reduction:  # pragma: no cover
             print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")  # add deprecation warning.
 
     @property
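Why the large hunk above drops the `locals()` trick: assigning `locals()` to `self.__dict__` sets every constructor argument as an attribute in one line, but static type checkers (and readers) can no longer see which attributes exist, which is part of what forced the scattered `# type: ignore` comments. Explicit assignment trades a few lines for checkable attributes. A minimal standalone sketch of the two patterns; the class names `DictInit` and `ExplicitInit` are illustrative, not from this repo:

class DictInit:
    # old pattern: opaque to type checkers
    def __init__(self, expansion: int = 1, groups: int = 1):
        params = locals()        # captures all arguments, plus 'self'
        del params["self"]
        self.__dict__ = params   # mypy/pyright cannot infer self.expansion

class ExplicitInit:
    # new pattern: every attribute is statically visible
    def __init__(self, expansion: int = 1, groups: int = 1):
        self.expansion = expansion
        self.groups = groups

cfg = ExplicitInit(expansion=4)
print(cfg.expansion)  # 4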
0 commit comments
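The `Type[nn.Module]` edits fix a class-versus-instance mix-up in the annotations: `nn.BatchNorm2d`, the default for `norm`, is a class that gets called later, not a module instance, so the old `nn.Module` annotation never described what was actually passed. The `# type: ignore` on `stem_pool` remains because its default really is an instance (`nn.MaxPool2d(...)`), which contradicts its new `Type[nn.Module]` annotation. A rough sketch of the distinction; `build_norm` is a hypothetical helper, not from this repo:

from typing import Type

import torch.nn as nn

def build_norm(norm: Type[nn.Module], num_features: int) -> nn.Module:
    # 'norm' is a class such as nn.BatchNorm2d; calling it creates the instance.
    # A stricter annotation would be Callable[[int], nn.Module], since
    # Type[nn.Module] does not promise a num_features argument.
    return norm(num_features)

bn = build_norm(nn.BatchNorm2d, 64)
print(type(bn).__name__)  # BatchNorm2d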