Skip to content

Commit eb216ec

Browse files
committed
fix MC args
1 parent 64186e1 commit eb216ec

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

model_constructor/model_constructor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
117117
act_fn=nn.ReLU(inplace=True),
118118
pool=nn.AvgPool2d(2, ceil_mode=True),
119119
expansion=1, groups=1, dw=False, div_groups=None,
120-
sa=False,
121-
se: Union[bool, Callable] = False, # se can be bool or nn.Module
120+
sa: Union[bool, int, Callable] = False,
121+
se: Union[bool, int, Callable] = False, # se can be bool, int (0, 1) or nn.Module
122122
se_module=None, se_reduction=None, # deprecated. Leaved for worning and checks.
123123
bn_1st=True,
124124
zero_bn=True,
@@ -142,10 +142,10 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
142142
if self.stem_sizes[0] != self.c_in:
143143
self.stem_sizes = [self.c_in] + self.stem_sizes
144144
if self.se: # TODO add check issubclass or isinstance of nn.Module
145-
if type(self.se) == bool:
146-
self.se = SEModule # if se=1
147-
if self.sa:
148-
if type(self.sa) == bool:
145+
if type(self.se) in (bool, int): # if se=1 or se=True
146+
self.se = SEModule
147+
if self.sa: # if sa=1 or sa=True
148+
if type(self.sa) in (bool, int):
149149
self.sa = SimpleSelfAttention # default: ks=1, sym=sym
150150
if self.se_module or se_reduction:
151151
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.") # add deprecation worning.

0 commit comments

Comments
 (0)