Skip to content

Commit 4f41bef

Browse files
authored
Merge pull request #49 from ayasyrev/se_sa_fix
fix se, sa settings at MC
2 parents b00815a + f18730f commit 4f41bef

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

model_constructor/model_constructor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,16 @@ def __init__(self, name='MC', in_chans=3, num_classes=1000,
150150
self._block_sizes = params['block_sizes']
151151
if self.stem_sizes[0] != self.in_chans:
152152
self.stem_sizes = [self.in_chans] + self.stem_sizes
153-
if self.se: # TODO add check issubclass or isinstance of nn.Module
153+
if self.se:
154154
if type(self.se) in (bool, int): # if se=1 or se=True
155155
self.se = SEModule
156+
else:
157+
self.se = se # TODO add check issubclass or isinstance of nn.Module
156158
if self.sa: # if sa=1 or sa=True
157159
if type(self.sa) in (bool, int):
158160
self.sa = SimpleSelfAttention # default: ks=1, sym=sym
161+
else:
162+
self.sa = sa
159163
if self.se_module or se_reduction: # pragma: no cover
160164
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.") # add deprecation worning.
161165

tests/test_mc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from model_constructor import ModelConstructor
4-
from model_constructor.layers import SEModule, SimpleSelfAttention
4+
from model_constructor.layers import SEModule, SEModuleConv, SimpleSelfAttention
55

66

77
bs_test = 4
@@ -30,3 +30,9 @@ def test_MC():
3030
model = mc()
3131
pred = model(xb)
3232
assert pred.shape == torch.Size([bs_test, num_classes])
33+
mc = ModelConstructor(sa=SimpleSelfAttention, se=SEModuleConv, num_classes=num_classes)
34+
assert mc.se is SEModuleConv
35+
assert mc.sa is SimpleSelfAttention
36+
model = mc()
37+
pred = model(xb)
38+
assert pred.shape == torch.Size([bs_test, num_classes])

0 commit comments

Comments
 (0)