Skip to content

Commit 36bdd35

Browse files
authored
Merge pull request #62 from ayasyrev/cfg
Cfg
2 parents 2832f38 + 4757237 commit 36bdd35

File tree

8 files changed

+1461
-471
lines changed

8 files changed

+1461
-471
lines changed

Nbs/00_ModelConstructor.ipynb

Lines changed: 416 additions & 45 deletions
Large diffs are not rendered by default.

Nbs/index.ipynb

Lines changed: 310 additions & 102 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 169 additions & 121 deletions
Large diffs are not rendered by default.

docs/00_ModelConstructor.md

Lines changed: 302 additions & 10 deletions
Large diffs are not rendered by default.

docs/index.md

Lines changed: 145 additions & 50 deletions
Large diffs are not rendered by default.

src/model_constructor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from model_constructor.convmixer import ConvMixer # noqa F401
2-
from model_constructor.model_constructor import ModelConstructor, ResBlock # noqa F401
2+
from model_constructor.model_constructor import ModelConstructor, ResBlock, CfgMC # noqa F401
33

44
from model_constructor.version import __version__ # noqa F401
Lines changed: 113 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from dataclasses import dataclass, field, asdict
2+
13
from collections import OrderedDict
2-
from functools import partial
3-
from typing import Callable, List, Type, Union
4+
# from functools import partial
5+
from typing import Callable, List, Optional, Type, Union
46

57
import torch.nn as nn
68

@@ -12,24 +14,14 @@
1214
"act_fn",
1315
"ResBlock",
1416
"ModelConstructor",
15-
"xresnet34",
16-
"xresnet50",
17+
# "xresnet34",
18+
# "xresnet50",
1719
]
1820

1921

2022
act_fn = nn.ReLU(inplace=True)
2123

2224

23-
def init_cnn(module: nn.Module):
24-
"Init module - kaiming_normal for Conv2d and 0 for biases."
25-
if getattr(module, "bias", None) is not None:
26-
nn.init.constant_(module.bias, 0) # type: ignore
27-
if isinstance(module, (nn.Conv2d, nn.Linear)):
28-
nn.init.kaiming_normal_(module.weight)
29-
for layer in module.children():
30-
init_cnn(layer)
31-
32-
3325
class ResBlock(nn.Module):
3426
"""Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
3527

@@ -130,10 +122,55 @@ def forward(self, x):
130122
return self.act_fn(self.convs(x) + identity)
131123

132124

133-
def _make_stem(self):
134-
stem = [
@dataclass
class CfgMC:
    """Model constructor Config. As default - xresnet18"""

    name: str = "MC"
    in_chans: int = 3
    num_classes: int = 1000
    block: Type[nn.Module] = ResBlock
    conv_layer: Type[nn.Module] = ConvBnAct
    block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
    layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2])
    norm: Type[nn.Module] = nn.BatchNorm2d
    # default_factory so every config owns its own module instance instead of
    # sharing one mutable nn.Module object across all CfgMC instances.
    act_fn: nn.Module = field(default_factory=lambda: nn.ReLU(inplace=True))
    pool: nn.Module = field(default_factory=lambda: nn.AvgPool2d(2, ceil_mode=True))
    expansion: int = 1
    groups: int = 1
    dw: bool = False
    div_groups: Union[int, None] = None
    sa: Union[bool, int, Type[nn.Module]] = False
    se: Union[bool, int, Type[nn.Module]] = False
    # Annotated so they are real dataclass fields (present in __init__, repr
    # and asdict — without annotations they were plain class attributes and
    # silently dropped by from_cfg/asdict). Both are deprecated pass-throughs.
    se_module: Optional[Type[nn.Module]] = None
    se_reduction: Optional[int] = None
    bn_1st: bool = True
    zero_bn: bool = True
    stem_stride_on: int = 0
    stem_sizes: List[int] = field(default_factory=lambda: [32, 32, 64])
    stem_pool: Union[nn.Module, None] = field(
        default_factory=lambda: nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )  # type: ignore
    stem_bn_end: bool = False
    # Hook slots: filled with the module-level defaults in
    # ModelConstructor.__post_init__ when left as None.
    _init_cnn: Optional[Callable[[nn.Module], None]] = field(repr=False, default=None)
    _make_stem: Optional[Callable] = field(repr=False, default=None)
    _make_layer: Optional[Callable] = field(repr=False, default=None)
    _make_body: Optional[Callable] = field(repr=False, default=None)
    _make_head: Optional[Callable] = field(repr=False, default=None)
def init_cnn(module: nn.Module):
    "Init module - kaiming_normal for Conv2d and 0 for biases."
    bias = getattr(module, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0)  # type: ignore
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight)
    # Recurse into children so a whole model can be initialized in one call.
    for child in module.children():
        init_cnn(child)
169+
170+
def make_stem(self: CfgMC) -> nn.Sequential:
171+
stem: List[tuple[str, nn.Module]] = [
135172
(f"conv_{i}", self.conv_layer(
136-
self.stem_sizes[i],
173+
self.stem_sizes[i], # type: ignore
137174
self.stem_sizes[i + 1],
138175
stride=2 if i == self.stem_stride_on else 1,
139176
bn_layer=(not self.stem_bn_end)
@@ -147,39 +184,38 @@ def _make_stem(self):
147184
if self.stem_pool:
148185
stem.append(("stem_pool", self.stem_pool))
149186
if self.stem_bn_end:
150-
stem.append(("norm", self.norm(self.stem_sizes[-1])))
187+
stem.append(("norm", self.norm(self.stem_sizes[-1]))) # type: ignore
151188
return nn.Sequential(OrderedDict(stem))
152189

153190

def make_layer(cfg: CfgMC, layer_num: int) -> nn.Sequential:
    """Build one body stage: ``cfg.layers[layer_num]`` blocks as nn.Sequential."""
    # expansion, in_channels, out_channels, blocks, stride, sa):
    # if no pool on stem - stride = 2 for first layer block in body
    stride = 1 if cfg.stem_pool and layer_num == 0 else 2
    num_blocks = cfg.layers[layer_num]
    # Channel plan: first entry derived from the stem output (divided by
    # expansion), then the configured body sizes.
    block_chs = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
    return nn.Sequential(
        OrderedDict(
            [
                (
                    f"bl_{block_num}",
                    cfg.block(
                        cfg.expansion,  # type: ignore
                        # first block of a stage changes channels; the rest keep them
                        block_chs[layer_num] if block_num == 0 else block_chs[layer_num + 1],
                        block_chs[layer_num + 1],
                        # only the first block of a stage downsamples
                        stride if block_num == 0 else 1,
                        # self-attention only on the last block of the first stage
                        sa=cfg.sa
                        if (block_num == num_blocks - 1) and layer_num == 0
                        else None,
                        conv_layer=cfg.conv_layer,
                        act_fn=cfg.act_fn,
                        pool=cfg.pool,
                        zero_bn=cfg.zero_bn,
                        bn_1st=cfg.bn_1st,
                        groups=cfg.groups,
                        div_groups=cfg.div_groups,
                        dw=cfg.dw,
                        se=cfg.se,
                    ),
                )
                for block_num in range(num_blocks)
            ]
        )
    )
def make_body(cfg: CfgMC) -> nn.Sequential:
    """Assemble all body stages, named l_0..l_N, into one nn.Sequential."""
    named_layers = [
        (f"l_{layer_num}", cfg._make_layer(cfg, layer_num))  # type: ignore
        for layer_num in range(len(cfg.layers))
    ]
    return nn.Sequential(OrderedDict(named_layers))
def make_head(cfg: CfgMC) -> nn.Sequential:
    """Classification head: global average pool -> flatten -> linear."""
    layers = OrderedDict()
    layers["pool"] = nn.AdaptiveAvgPool2d(1)
    layers["flat"] = nn.Flatten()
    # fc input width is the last body size scaled by the block expansion
    layers["fc"] = nn.Linear(cfg.block_sizes[-1] * cfg.expansion, cfg.num_classes)
    return nn.Sequential(layers)
@dataclass
class ModelConstructor(CfgMC):
    """Model constructor. As default - xresnet18"""

    def __post_init__(self):
        # Fill the callable hook slots left as None with the module-level
        # default implementations, so subclasses/configs can override any step.
        if self._init_cnn is None:
            self._init_cnn = init_cnn
        if self._make_stem is None:
            self._make_stem = make_stem
        if self._make_layer is None:
            self._make_layer = make_layer
        if self._make_body is None:
            self._make_body = make_body
        if self._make_head is None:
            self._make_head = make_head

        # Prepend the input channel count to the stem plan if missing.
        if self.stem_sizes[0] != self.in_chans:
            self.stem_sizes = [self.in_chans] + self.stem_sizes
        if self.se and isinstance(self.se, (bool, int)):  # if se=1 or se=True
            self.se = SEModule
        if self.sa and isinstance(self.sa, (bool, int)):  # if sa=1 or sa=True
            self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
        if self.se_module or self.se_reduction:  # pragma: no cover
            print(
                "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
            )  # add deprecation warning.

    @property
    def stem(self):
        # Built lazily on every access via the hook.
        return self._make_stem(self)  # type: ignore

    @property
    def head(self):
        return self._make_head(self)  # type: ignore

    @property
    def body(self):
        return self._make_body(self)  # type: ignore

    @classmethod
    def from_cfg(cls, cfg: CfgMC):
        # Alternate constructor: copy all dataclass fields from a CfgMC.
        return cls(**asdict(cfg))

    def __call__(self):
        # Assemble stem/body/head into a model and apply weight init.
        model = nn.Sequential(
            OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
        )
        self._init_cnn(model)  # type: ignore
        model.extra_repr = lambda: f"{self.name}"
        return model

    def print_cfg(self):
        # Human-readable summary (repr is provided by the dataclass).
        print(
            f"{self.name} constructor\n"
            f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
            f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
            f" sa: {self.sa}, se: {self.se}\n"
            f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
            f" body sizes {self.block_sizes}\n"
            f" layers: {self.layers}"
        )
# Ready-made xresnet34 constructor: basic blocks, stage depths 3-4-6-3.
xresnet34 = ModelConstructor.from_cfg(CfgMC(name="xresnet34", expansion=1, layers=[3, 4, 6, 3]))
# Ready-made xresnet50 constructor: bottleneck blocks (expansion 4),
# stage depths 3-4-6-3. Fix: name was copy-pasted as "xresnet34".
xresnet50 = ModelConstructor.from_cfg(
    CfgMC(name="xresnet50", expansion=4, layers=[3, 4, 6, 3])
)
tests/test_mc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ def test_MC():
1111
"""test ModelConstructor"""
1212
img_size = 16
1313
mc = ModelConstructor()
14-
assert "MC constructor" in str(mc)
14+
assert "name='MC'" in str(mc)
1515
model = mc()
1616
xb = torch.randn(bs_test, 3, img_size, img_size)
1717
pred = model(xb)
1818
assert pred.shape == torch.Size([bs_test, 1000])
19+
mc.expansion = 2
20+
model = mc()
21+
pred = model(xb)
22+
assert pred.shape == torch.Size([bs_test, 1000])
1923
num_classes = 10
2024
mc.num_classes = num_classes
2125
mc.se = SEModule

0 commit comments

Comments
 (0)