Skip to content

Commit c73d6f6

Browse files
committed
block_sizes, repr, tests
1 parent d5cc2f5 commit c73d6f6

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

src/model_constructor/model_constructor.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class ModelConstructorCfg:
131131
num_classes: int = 1000
132132
block: Type[nn.Module] = ResBlock
133133
conv_layer: Type[nn.Module] = ConvBnAct
134-
_block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
134+
block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
135135
layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2])
136136
norm: Type[nn.Module] = nn.BatchNorm2d
137137
act_fn: nn.Module = nn.ReLU(inplace=True)
@@ -150,11 +150,11 @@ class ModelConstructorCfg:
150150
stem_sizes: List[int] = field(default_factory=lambda: [32, 32, 64])
151151
stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # type: ignore
152152
stem_bn_end: bool = False
153-
_init_cnn: Optional[Callable[[nn.Module], None]] = None
154-
_make_stem: Optional[Callable] = None
155-
_make_layer: Optional[Callable] = None
156-
_make_body: Optional[Callable] = None
157-
_make_head: Optional[Callable] = None
153+
_init_cnn: Optional[Callable[[nn.Module], None]] = field(repr=False, default=None)
154+
_make_stem: Optional[Callable] = field(repr=False, default=None)
155+
_make_layer: Optional[Callable] = field(repr=False, default=None)
156+
_make_body: Optional[Callable] = field(repr=False, default=None)
157+
_make_head: Optional[Callable] = field(repr=False, default=None)
158158

159159

160160
def init_cnn(module: nn.Module):
@@ -188,22 +188,21 @@ def _make_stem(self: ModelConstructorCfg) -> nn.Sequential:
188188
return nn.Sequential(OrderedDict(stem))
189189

190190

191-
def _make_layer(self, layer_num: int) -> nn.Module:
191+
def _make_layer(self: ModelConstructorCfg, layer_num: int) -> nn.Sequential:
192192
# expansion, in_channels, out_channels, blocks, stride, sa):
193193
# if no pool on stem - stride = 2 for first layer block in body
194194
stride = 1 if self.stem_pool and layer_num == 0 else 2
195195
num_blocks = self.layers[layer_num]
196+
block_chs = [self.stem_sizes[-1] // self.expansion] + self.block_sizes
196197
return nn.Sequential(
197198
OrderedDict(
198199
[
199200
(
200201
f"bl_{block_num}",
201202
self.block(
202-
self.expansion,
203-
self.block_sizes[layer_num]
204-
if block_num == 0
205-
else self.block_sizes[layer_num + 1],
206-
self.block_sizes[layer_num + 1],
203+
self.expansion, # type: ignore
204+
block_chs[layer_num] if block_num == 0 else block_chs[layer_num + 1],
205+
block_chs[layer_num + 1],
207206
stride if block_num == 0 else 1,
208207
sa=self.sa
209208
if (block_num == num_blocks - 1) and layer_num == 0
@@ -225,7 +224,7 @@ def _make_layer(self, layer_num: int) -> nn.Module:
225224
)
226225

227226

228-
def _make_body(self):
227+
def _make_body(self: ModelConstructorCfg) -> nn.Sequential:
229228
return nn.Sequential(
230229
OrderedDict(
231230
[
@@ -275,9 +274,9 @@ def __post_init__(self):
275274
"Deprecated. Pass se_module as se argument, se_reduction as arg to se."
276275
) # add deprecation warning.
277276

278-
@property
279-
def block_sizes(self):
280-
return [self.stem_sizes[-1] // self.expansion] + self._block_sizes
277+
# @property
278+
# def block_sizes(self):
279+
# return [self.stem_sizes[-1] // self.expansion] + self._block_sizes
281280

282281
@property
283282
def stem(self):
@@ -303,18 +302,17 @@ def __call__(self):
303302
model.extra_repr = lambda: f"{self.name}"
304303
return model
305304

306-
def __repr__(self):
307-
return (
305+
def simple_cfg(self):
306+
print(
308307
f"{self.name} constructor\n"
309308
f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
310309
f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
311310
f" sa: {self.sa}, se: {self.se}\n"
312311
f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
313-
f" body sizes {self._block_sizes}\n"
312+
f" body sizes {self.block_sizes}\n"
314313
f" layers: {self.layers}"
315314
)
316315

317-
318316
# xresnet34 = partial(
319317
# ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
320318
# )

tests/test_mc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ def test_MC():
1111
"""test ModelConstructor"""
1212
img_size = 16
1313
mc = ModelConstructor()
14-
assert "MC constructor" in str(mc)
14+
assert "name='MC'" in str(mc)
1515
model = mc()
1616
xb = torch.randn(bs_test, 3, img_size, img_size)
1717
pred = model(xb)
1818
assert pred.shape == torch.Size([bs_test, 1000])
19+
mc.expansion = 2
20+
model = mc()
21+
pred = model(xb)
22+
assert pred.shape == torch.Size([bs_test, 1000])
1923
num_classes = 10
2024
mc.num_classes = num_classes
2125
mc.se = SEModule

0 commit comments

Comments
 (0)