Skip to content

Commit 07297f3

Browse files
committed
dmc and cfg as pydantic basemodel
1 parent 104d32d commit 07297f3

File tree

1 file changed

+37
-41
lines changed

1 file changed

+37
-41
lines changed

src/model_constructor/model_constructor.py

Lines changed: 37 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,11 @@
1-
from dataclasses import dataclass, field, asdict
2-
31
from collections import OrderedDict
4-
# from functools import partial
52
from typing import Callable, List, Optional, Type, Union
63

74
import torch.nn as nn
5+
from pydantic import BaseModel
86

97
from .layers import ConvBnAct, SEModule, SimpleSelfAttention
108

11-
129
__all__ = [
1310
"init_cnn",
1411
"act_fn",
@@ -122,17 +119,16 @@ def forward(self, x):
122119
return self.act_fn(self.convs(x) + identity)
123120

124121

125-
@dataclass
126-
class CfgMC:
122+
class CfgMC(BaseModel):
127123
"""Model constructor Config. As default - xresnet18"""
128124

129125
name: str = "MC"
130126
in_chans: int = 3
131127
num_classes: int = 1000
132128
block: Type[nn.Module] = ResBlock
133129
conv_layer: Type[nn.Module] = ConvBnAct
134-
block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
135-
layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2])
130+
block_sizes: List[int] = [64, 128, 256, 512]
131+
layers: List[int] = [2, 2, 2, 2]
136132
norm: Type[nn.Module] = nn.BatchNorm2d
137133
act_fn: nn.Module = nn.ReLU(inplace=True)
138134
pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True)
@@ -142,19 +138,22 @@ class CfgMC:
142138
div_groups: Union[int, None] = None
143139
sa: Union[bool, int, Type[nn.Module]] = False
144140
se: Union[bool, int, Type[nn.Module]] = False
145-
se_module = None
146-
se_reduction = None
141+
se_module: Union[bool, None] = None
142+
se_reduction: Union[int, None] = None
147143
bn_1st: bool = True
148144
zero_bn: bool = True
149145
stem_stride_on: int = 0
150-
stem_sizes: List[int] = field(default_factory=lambda: [32, 32, 64])
146+
stem_sizes: List[int] = [32, 32, 64]
151147
stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # type: ignore
152148
stem_bn_end: bool = False
153-
_init_cnn: Optional[Callable[[nn.Module], None]] = field(repr=False, default=None)
154-
_make_stem: Optional[Callable] = field(repr=False, default=None)
155-
_make_layer: Optional[Callable] = field(repr=False, default=None)
156-
_make_body: Optional[Callable] = field(repr=False, default=None)
157-
_make_head: Optional[Callable] = field(repr=False, default=None)
149+
init_cnn: Optional[Callable[[nn.Module], None]] = None
150+
make_stem: Optional[Callable] = None
151+
make_layer: Optional[Callable] = None
152+
make_body: Optional[Callable] = None
153+
make_head: Optional[Callable] = None
154+
155+
class Config:
156+
arbitrary_types_allowed = True
158157

159158

160159
def init_cnn(module: nn.Module):
@@ -230,7 +229,7 @@ def make_body(cfg: CfgMC) -> nn.Sequential:
230229
[
231230
(
232231
f"l_{layer_num}",
233-
cfg._make_layer(cfg, layer_num) # type: ignore
232+
cfg.make_layer(cfg, layer_num) # type: ignore
234233
)
235234
for layer_num in range(len(cfg.layers))
236235
]
@@ -247,21 +246,21 @@ def make_head(cfg: CfgMC) -> nn.Sequential:
247246
return nn.Sequential(OrderedDict(head))
248247

249248

250-
@dataclass
251249
class ModelConstructor(CfgMC):
252250
"""Model constructor. As default - xresnet18"""
253251

254-
def __post_init__(self):
255-
if self._init_cnn is None:
256-
self._init_cnn = init_cnn
257-
if self._make_stem is None:
258-
self._make_stem = make_stem
259-
if self._make_layer is None:
260-
self._make_layer = make_layer
261-
if self._make_body is None:
262-
self._make_body = make_body
263-
if self._make_head is None:
264-
self._make_head = make_head
252+
def __init__(self, **data):
253+
super().__init__(**data)
254+
if self.init_cnn is None:
255+
self.init_cnn = init_cnn
256+
if self.make_stem is None:
257+
self.make_stem = make_stem
258+
if self.make_layer is None:
259+
self.make_layer = make_layer
260+
if self.make_body is None:
261+
self.make_body = make_body
262+
if self.make_head is None:
263+
self.make_head = make_head
265264

266265
if self.stem_sizes[0] != self.in_chans:
267266
self.stem_sizes = [self.in_chans] + self.stem_sizes
@@ -276,30 +275,30 @@ def __post_init__(self):
276275

277276
@property
278277
def stem(self):
279-
return self._make_stem(self) # type: ignore
278+
return self.make_stem(self) # type: ignore
280279

281280
@property
282281
def head(self):
283-
return self._make_head(self) # type: ignore
282+
return self.make_head(self) # type: ignore
284283

285284
@property
286285
def body(self):
287-
return self._make_body(self) # type: ignore
286+
return self.make_body(self) # type: ignore
288287

289288
@classmethod
290289
def from_cfg(cls, cfg: CfgMC):
291-
return cls(**asdict(cfg))
290+
return cls(**cfg.dict())
292291

293292
def __call__(self):
294293
model = nn.Sequential(
295294
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
296295
)
297-
self._init_cnn(model) # type: ignore
296+
self.init_cnn(model) # type: ignore
298297
model.extra_repr = lambda: f"{self.name}"
299298
return model
300299

301-
def print_cfg(self):
302-
print(
300+
def __repr__(self):
301+
return (
303302
f"{self.name} constructor\n"
304303
f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
305304
f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
@@ -310,14 +309,11 @@ def print_cfg(self):
310309
)
311310

312311

313-
@dataclass
314312
class XResNet34(ModelConstructor):
315313
name: str = "xresnet34"
316-
layers: list[int] = field(default_factory=lambda: [3, 4, 6, 3])
314+
layers: list[int] = [3, 4, 6, 3]
317315

318316

319-
@dataclass
320-
class XResNet50(ModelConstructor):
317+
class XResNet50(XResNet34):
321318
name: str = "xresnet50"
322319
expansion: int = 4
323-
layers: list[int] = field(default_factory=lambda: [3, 4, 6, 3])

0 commit comments

Comments
 (0)