Skip to content

Commit d5cc2f5

Browse files
committed
add dataclass cfg
1 parent 2832f38 commit d5cc2f5

File tree

1 file changed

+85
-109
lines changed

1 file changed

+85
-109
lines changed

src/model_constructor/model_constructor.py

Lines changed: 85 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from dataclasses import dataclass, field, asdict
2+
13
from collections import OrderedDict
2-
from functools import partial
3-
from typing import Callable, List, Type, Union
4+
# from functools import partial
5+
from typing import Callable, List, Optional, Type, Union
46

57
import torch.nn as nn
68

@@ -12,24 +14,14 @@
1214
"act_fn",
1315
"ResBlock",
1416
"ModelConstructor",
15-
"xresnet34",
16-
"xresnet50",
17+
# "xresnet34",
18+
# "xresnet50",
1719
]
1820

1921

2022
act_fn = nn.ReLU(inplace=True)
2123

2224

23-
def init_cnn(module: nn.Module):
24-
"Init module - kaiming_normal for Conv2d and 0 for biases."
25-
if getattr(module, "bias", None) is not None:
26-
nn.init.constant_(module.bias, 0) # type: ignore
27-
if isinstance(module, (nn.Conv2d, nn.Linear)):
28-
nn.init.kaiming_normal_(module.weight)
29-
for layer in module.children():
30-
init_cnn(layer)
31-
32-
3325
class ResBlock(nn.Module):
3426
"""Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
3527

@@ -130,10 +122,55 @@ def forward(self, x):
130122
return self.act_fn(self.convs(x) + identity)
131123

132124

133-
def _make_stem(self):
134-
stem = [
125+
@dataclass
126+
class ModelConstructorCfg:
127+
"""Model constructor. As default - xresnet18"""
128+
129+
name: str = "MC"
130+
in_chans: int = 3
131+
num_classes: int = 1000
132+
block: Type[nn.Module] = ResBlock
133+
conv_layer: Type[nn.Module] = ConvBnAct
134+
_block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
135+
layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2])
136+
norm: Type[nn.Module] = nn.BatchNorm2d
137+
act_fn: nn.Module = nn.ReLU(inplace=True)
138+
pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True)
139+
expansion: int = 1
140+
groups: int = 1
141+
dw: bool = False
142+
div_groups: Union[int, None] = None
143+
sa: Union[bool, int, Type[nn.Module]] = False
144+
se: Union[bool, int, Type[nn.Module]] = False
145+
se_module = None
146+
se_reduction = None
147+
bn_1st: bool = True
148+
zero_bn: bool = True
149+
stem_stride_on: int = 0
150+
stem_sizes: List[int] = field(default_factory=lambda: [32, 32, 64])
151+
stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # type: ignore
152+
stem_bn_end: bool = False
153+
_init_cnn: Optional[Callable[[nn.Module], None]] = None
154+
_make_stem: Optional[Callable] = None
155+
_make_layer: Optional[Callable] = None
156+
_make_body: Optional[Callable] = None
157+
_make_head: Optional[Callable] = None
158+
159+
160+
def init_cnn(module: nn.Module):
161+
"Init module - kaiming_normal for Conv2d and 0 for biases."
162+
if getattr(module, "bias", None) is not None:
163+
nn.init.constant_(module.bias, 0) # type: ignore
164+
if isinstance(module, (nn.Conv2d, nn.Linear)):
165+
nn.init.kaiming_normal_(module.weight)
166+
for layer in module.children():
167+
init_cnn(layer)
168+
169+
170+
def _make_stem(self: ModelConstructorCfg) -> nn.Sequential:
171+
stem: List[tuple[str, nn.Module]] = [
135172
(f"conv_{i}", self.conv_layer(
136-
self.stem_sizes[i],
173+
self.stem_sizes[i], # type: ignore
137174
self.stem_sizes[i + 1],
138175
stride=2 if i == self.stem_stride_on else 1,
139176
bn_layer=(not self.stem_bn_end)
@@ -147,7 +184,7 @@ def _make_stem(self):
147184
if self.stem_pool:
148185
stem.append(("stem_pool", self.stem_pool))
149186
if self.stem_bn_end:
150-
stem.append(("norm", self.norm(self.stem_sizes[-1])))
187+
stem.append(("norm", self.norm(self.stem_sizes[-1]))) # type: ignore
151188
return nn.Sequential(OrderedDict(stem))
152189

153190

@@ -202,7 +239,7 @@ def _make_body(self):
202239
)
203240

204241

205-
def _make_head(self):
242+
def _make_head(self: ModelConstructorCfg) -> nn.Sequential:
206243
head = [
207244
("pool", nn.AdaptiveAvgPool2d(1)),
208245
("flat", nn.Flatten()),
@@ -211,94 +248,29 @@ def _make_head(self):
211248
return nn.Sequential(OrderedDict(head))
212249

213250

214-
class ModelConstructor:
251+
@dataclass
252+
class ModelConstructor(ModelConstructorCfg):
215253
"""Model constructor. As default - xresnet18"""
216254

217-
def __init__(
218-
self,
219-
name: str = "MC",
220-
in_chans: int = 3,
221-
num_classes: int = 1000,
222-
block=ResBlock,
223-
conv_layer=ConvBnAct,
224-
block_sizes: List[int] = [64, 128, 256, 512],
225-
layers: List[int] = [2, 2, 2, 2],
226-
norm: Type[nn.Module] = nn.BatchNorm2d,
227-
act_fn: nn.Module = nn.ReLU(inplace=True),
228-
pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True),
229-
expansion: int = 1,
230-
groups: int = 1,
231-
dw: bool = False,
232-
div_groups: Union[int, None] = None,
233-
sa: Union[bool, int, Type[nn.Module]] = False,
234-
se: Union[bool, int, Type[nn.Module]] = False,
235-
se_module=None,
236-
se_reduction=None,
237-
bn_1st: bool = True,
238-
zero_bn: bool = True,
239-
stem_stride_on: int = 0,
240-
stem_sizes: List[int] = [32, 32, 64],
241-
stem_pool: Union[Type[nn.Module], None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # type: ignore
242-
stem_bn_end: bool = False,
243-
_init_cnn: Callable = init_cnn,
244-
_make_stem: Callable = _make_stem,
245-
_make_layer: Callable = _make_layer,
246-
_make_body: Callable = _make_body,
247-
_make_head: Callable = _make_head,
248-
):
249-
super().__init__()
250-
# se can be bool, int (0, 1) or nn.Module
251-
# se_module - deprecated. Leaved for warning and checks.
252-
# if stem_pool is False - no pool at stem
253-
254-
self.name = name
255-
self.in_chans = in_chans
256-
self.num_classes = num_classes
257-
self.block = block
258-
self.conv_layer = conv_layer
259-
self._block_sizes = block_sizes
260-
self.layers = layers
261-
self.norm = norm
262-
self.act_fn = act_fn
263-
self.pool = pool
264-
self.expansion = expansion
265-
self.groups = groups
266-
self.dw = dw
267-
self.div_groups = div_groups
268-
# se_module
269-
# se_reduction
270-
self.bn_1st = bn_1st
271-
self.zero_bn = zero_bn
272-
self.stem_stride_on = stem_stride_on
273-
self.stem_pool = stem_pool
274-
self.stem_bn_end = stem_bn_end
275-
self._init_cnn = _init_cnn
276-
self._make_stem = _make_stem
277-
self._make_layer = _make_layer
278-
self._make_body = _make_body
279-
self._make_head = _make_head
280-
281-
# params = locals()
282-
# del params['self']
283-
# self.__dict__ = params
284-
285-
# self._block_sizes = params['block_sizes']
286-
self.stem_sizes = stem_sizes
255+
def __post_init__(self):
256+
if self._init_cnn is None:
257+
self._init_cnn = init_cnn
258+
if self._make_stem is None:
259+
self._make_stem = _make_stem
260+
if self._make_layer is None:
261+
self._make_layer = _make_layer
262+
if self._make_body is None:
263+
self._make_body = _make_body
264+
if self._make_head is None:
265+
self._make_head = _make_head
266+
287267
if self.stem_sizes[0] != self.in_chans:
288268
self.stem_sizes = [self.in_chans] + self.stem_sizes
289-
self.se = se
290-
if self.se:
291-
if type(self.se) in (bool, int): # if se=1 or se=True
292-
self.se = SEModule
293-
else:
294-
self.se = se # TODO add check issubclass or isinstance of nn.Module
295-
self.sa = sa
296-
if self.sa: # if sa=1 or sa=True
297-
if type(self.sa) in (bool, int):
298-
self.sa = SimpleSelfAttention # default: ks=1, sym=sym
299-
else:
300-
self.sa = sa
301-
if se_module or se_reduction: # pragma: no cover
269+
if self.se and isinstance(self.se, (bool, int)): # if se=1 or se=True
270+
self.se = SEModule
271+
if self.sa and isinstance(self.sa, (bool, int)): # if sa=1 or sa=True
272+
self.sa = SimpleSelfAttention # default: ks=1, sym=sym
273+
if self.se_module or self.se_reduction: # pragma: no cover
302274
print(
303275
"Deprecated. Pass se_module as se argument, se_reduction as arg to se."
304276
) # add deprecation warning.
@@ -319,6 +291,10 @@ def head(self):
319291
def body(self):
320292
return self._make_body(self)
321293

294+
@classmethod
295+
def from_cfg(cls, cfg: ModelConstructorCfg):
296+
return cls(**asdict(cfg))
297+
322298
def __call__(self):
323299
model = nn.Sequential(
324300
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
@@ -339,9 +315,9 @@ def __repr__(self):
339315
)
340316

341317

342-
xresnet34 = partial(
343-
ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
344-
)
345-
xresnet50 = partial(
346-
ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]
347-
)
318+
# xresnet34 = partial(
319+
# ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
320+
# )
321+
# xresnet50 = partial(
322+
# ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]
323+
# )

0 commit comments

Comments
 (0)