@@ -117,8 +117,8 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
117117 act_fn = nn .ReLU (inplace = True ),
118118 pool = nn .AvgPool2d (2 , ceil_mode = True ),
119119 expansion = 1 , groups = 1 , dw = False , div_groups = None ,
120- sa = False ,
121- se : Union [bool , Callable ] = False , # se can be bool or nn.Module
120+ sa : Union [ bool , int , Callable ] = False ,
121+ se : Union [bool , int , Callable ] = False , # se can be bool, int (0, 1) or nn.Module
122122 se_module = None , se_reduction = None , # deprecated. Leaved for worning and checks.
123123 bn_1st = True ,
124124 zero_bn = True ,
@@ -142,10 +142,10 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
142142 if self .stem_sizes [0 ] != self .c_in :
143143 self .stem_sizes = [self .c_in ] + self .stem_sizes
144144 if self .se : # TODO add check issubclass or isinstance of nn.Module
145- if type (self .se ) == bool :
146- self .se = SEModule # if se=1
147- if self .sa :
148- if type (self .sa ) == bool :
145+ if type (self .se ) in ( bool , int ): # if se=1 or se=True
146+ self .se = SEModule
147+ if self .sa : # if sa=1 or sa=True
148+ if type (self .sa ) in ( bool , int ) :
149149 self .sa = SimpleSelfAttention # default: ks=1, sym=sym
150150 if self .se_module or se_reduction :
151151 print ("Deprecated. Pass se_module as se argument, se_reduction as arg to se." ) # add deprecation worning.
0 commit comments