Skip to content

Commit 6c55cb0

Browse files
authored
Merge pull request #76 from ayasyrev/typing
v0.3.1 pydantic, typing, tests
2 parents c731d5a + 60b0d6c commit 6c55cb0

File tree

15 files changed

+836
-236
lines changed

15 files changed

+836
-236
lines changed

.pylintrc

Lines changed: 573 additions & 0 deletions
Large diffs are not rendered by default.

src/model_constructor/layers.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
class Flatten(nn.Module):
2222
"""flat x to vector"""
2323

24-
def __init__(self):
25-
super().__init__()
26-
2724
def forward(self, x):
2825
return x.view(x.size(0), -1)
2926

@@ -36,9 +33,6 @@ def noop(x):
3633
class Noop(nn.Module):
3734
"""Dummy module"""
3835

39-
def __init__(self):
40-
super().__init__()
41-
4236
def forward(self, x):
4337
return x
4438

src/model_constructor/model_constructor.py

Lines changed: 85 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Any, Callable, List, Optional, Type, Union
3+
from typing import Any, Callable, List, Type, TypeVar, Union
44

55
import torch.nn as nn
66
from pydantic import BaseModel, root_validator
@@ -16,6 +16,19 @@
1616
]
1717

1818

19+
TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")
20+
21+
22+
def init_cnn(module: nn.Module):
23+
"Init module - kaiming_normal for Conv2d and 0 for biases."
24+
if getattr(module, "bias", None) is not None:
25+
nn.init.constant_(module.bias, 0) # type: ignore
26+
if isinstance(module, (nn.Conv2d, nn.Linear)):
27+
nn.init.kaiming_normal_(module.weight)
28+
for layer in module.children():
29+
init_cnn(layer)
30+
31+
1932
class ResBlock(nn.Module):
2033
"""Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
2134

@@ -116,86 +129,28 @@ def forward(self, x):
116129
return self.act_fn(self.convs(x) + identity)
117130

118131

119-
class ModelCfg(BaseModel):
120-
"""Model constructor Config. As default - xresnet18"""
121-
122-
name: str = "MC"
123-
in_chans: int = 3
124-
num_classes: int = 1000
125-
block: Type[nn.Module] = ResBlock
126-
conv_layer: Type[nn.Module] = ConvBnAct
127-
block_sizes: List[int] = [64, 128, 256, 512]
128-
layers: List[int] = [2, 2, 2, 2]
129-
norm: Type[nn.Module] = nn.BatchNorm2d
130-
act_fn: Type[nn.Module] = nn.ReLU
131-
pool: Callable[[Any], nn.Module] = partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)
132-
expansion: int = 1
133-
groups: int = 1
134-
dw: bool = False
135-
div_groups: Union[int, None] = None
136-
sa: Union[bool, int, Type[nn.Module]] = False
137-
se: Union[bool, int, Type[nn.Module]] = False
138-
se_module: Union[bool, None] = None
139-
se_reduction: Union[int, None] = None
140-
bn_1st: bool = True
141-
zero_bn: bool = True
142-
stem_stride_on: int = 0
143-
stem_sizes: List[int] = [32, 32, 64]
144-
stem_pool: Union[Callable[[], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
145-
stem_bn_end: bool = False
146-
init_cnn: Optional[Callable[[nn.Module], None]] = None
147-
make_stem: Optional[Callable[["ModelCfg"], nn.Module]] = None
148-
make_layer: Optional[Callable[["ModelCfg"], nn.Module]] = None
149-
make_body: Optional[Callable[["ModelCfg"], nn.Module]] = None
150-
make_head: Optional[Callable[["ModelCfg"], nn.Module]] = None
151-
152-
class Config:
153-
arbitrary_types_allowed = True
154-
extra = "forbid"
155-
156-
def extra_repr(self) -> str:
157-
res = ""
158-
for k, v in self.dict().items():
159-
if v is not None:
160-
res += f"{k}: {v}\n"
161-
return res
162-
163-
def pprint(self) -> None:
164-
print(self.extra_repr())
165-
166-
167-
def init_cnn(module: nn.Module):
168-
"Init module - kaiming_normal for Conv2d and 0 for biases."
169-
if getattr(module, "bias", None) is not None:
170-
nn.init.constant_(module.bias, 0) # type: ignore
171-
if isinstance(module, (nn.Conv2d, nn.Linear)):
172-
nn.init.kaiming_normal_(module.weight)
173-
for layer in module.children():
174-
init_cnn(layer)
175-
176-
177-
def make_stem(self: ModelCfg) -> nn.Sequential:
132+
def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
178133
stem: List[tuple[str, nn.Module]] = [
179-
(f"conv_{i}", self.conv_layer(
180-
self.stem_sizes[i], # type: ignore
181-
self.stem_sizes[i + 1],
182-
stride=2 if i == self.stem_stride_on else 1,
183-
bn_layer=(not self.stem_bn_end)
184-
if i == (len(self.stem_sizes) - 2)
134+
(f"conv_{i}", cfg.conv_layer(
135+
cfg.stem_sizes[i], # type: ignore
136+
cfg.stem_sizes[i + 1],
137+
stride=2 if i == cfg.stem_stride_on else 1,
138+
bn_layer=(not cfg.stem_bn_end)
139+
if i == (len(cfg.stem_sizes) - 2)
185140
else True,
186-
act_fn=self.act_fn,
187-
bn_1st=self.bn_1st,
141+
act_fn=cfg.act_fn,
142+
bn_1st=cfg.bn_1st,
188143
),)
189-
for i in range(len(self.stem_sizes) - 1)
144+
for i in range(len(cfg.stem_sizes) - 1)
190145
]
191-
if self.stem_pool:
192-
stem.append(("stem_pool", self.stem_pool()))
193-
if self.stem_bn_end:
194-
stem.append(("norm", self.norm(self.stem_sizes[-1]))) # type: ignore
146+
if cfg.stem_pool:
147+
stem.append(("stem_pool", cfg.stem_pool()))
148+
if cfg.stem_bn_end:
149+
stem.append(("norm", cfg.norm(cfg.stem_sizes[-1]))) # type: ignore
195150
return nn.Sequential(OrderedDict(stem))
196151

197152

198-
def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:
153+
def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
199154
# expansion, in_channels, out_channels, blocks, stride, sa):
200155
# if no pool on stem - stride = 2 for first layer block in body
201156
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
@@ -231,7 +186,7 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:
231186
)
232187

233188

234-
def make_body(cfg: ModelCfg) -> nn.Sequential:
189+
def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
235190
return nn.Sequential(
236191
OrderedDict(
237192
[
@@ -245,7 +200,7 @@ def make_body(cfg: ModelCfg) -> nn.Sequential:
245200
)
246201

247202

248-
def make_head(cfg: ModelCfg) -> nn.Sequential:
203+
def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
249204
head = [
250205
("pool", nn.AdaptiveAvgPool2d(1)),
251206
("flat", nn.Flatten()),
@@ -254,22 +209,59 @@ def make_head(cfg: ModelCfg) -> nn.Sequential:
254209
return nn.Sequential(OrderedDict(head))
255210

256211

212+
class ModelCfg(BaseModel):
213+
"""Model constructor Config. As default - xresnet18"""
214+
215+
name: str = "MC"
216+
in_chans: int = 3
217+
num_classes: int = 1000
218+
block: Type[nn.Module] = ResBlock
219+
conv_layer: Type[nn.Module] = ConvBnAct
220+
block_sizes: List[int] = [64, 128, 256, 512]
221+
layers: List[int] = [2, 2, 2, 2]
222+
norm: Type[nn.Module] = nn.BatchNorm2d
223+
act_fn: Type[nn.Module] = nn.ReLU
224+
pool: Callable[[Any], nn.Module] = partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)
225+
expansion: int = 1
226+
groups: int = 1
227+
dw: bool = False
228+
div_groups: Union[int, None] = None
229+
sa: Union[bool, int, Type[nn.Module]] = False
230+
se: Union[bool, int, Type[nn.Module]] = False
231+
se_module: Union[bool, None] = None
232+
se_reduction: Union[int, None] = None
233+
bn_1st: bool = True
234+
zero_bn: bool = True
235+
stem_stride_on: int = 0
236+
stem_sizes: List[int] = [32, 32, 64]
237+
stem_pool: Union[Callable[[], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
238+
stem_bn_end: bool = False
239+
init_cnn: Callable[[nn.Module], None] = init_cnn
240+
make_stem: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore
241+
make_layer: Callable[[TModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer # type: ignore
242+
make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
243+
make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore
244+
245+
class Config:
246+
arbitrary_types_allowed = True
247+
extra = "forbid"
248+
249+
def extra_repr(self) -> str:
250+
res = ""
251+
for k, v in self.dict().items():
252+
if v is not None:
253+
res += f"{k}: {v}\n"
254+
return res
255+
256+
def pprint(self) -> None:
257+
print(self.extra_repr())
258+
259+
257260
class ModelConstructor(ModelCfg):
258261
"""Model constructor. As default - xresnet18"""
259262

260263
@root_validator
261-
def post_init(cls, values):
262-
if values["init_cnn"] is None:
263-
values["init_cnn"] = init_cnn
264-
if values["make_stem"] is None:
265-
values["make_stem"] = make_stem
266-
if values["make_layer"] is None:
267-
values["make_layer"] = make_layer
268-
if values["make_body"] is None:
269-
values["make_body"] = make_body
270-
if values["make_head"] is None:
271-
values["make_head"] = make_head
272-
264+
def post_init(cls, values): # pylint: disable=E0213
273265
if values["stem_sizes"][0] != values["in_chans"]:
274266
values["stem_sizes"] = [values["in_chans"]] + values["stem_sizes"]
275267
if values["se"] and isinstance(values["se"], (bool, int)): # if se=1 or se=True
@@ -284,15 +276,15 @@ def post_init(cls, values):
284276

285277
@property
286278
def stem(self):
287-
return self.make_stem(self) # type: ignore
279+
return self.make_stem(self) # pylint: disable=too-many-function-args
288280

289281
@property
290282
def head(self):
291-
return self.make_head(self) # type: ignore
283+
return self.make_head(self) # pylint: disable=too-many-function-args
292284

293285
@property
294286
def body(self):
295-
return self.make_body(self) # type: ignore
287+
return self.make_body(self) # pylint: disable=too-many-function-args
296288

297289
@classmethod
298290
def from_cfg(cls, cfg: ModelCfg):
@@ -302,12 +294,12 @@ def __call__(self):
302294
model = nn.Sequential(
303295
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
304296
)
305-
self.init_cnn(model) # type: ignore
297+
self.init_cnn(model) # pylint: disable=too-many-function-args
306298
model.extra_repr = lambda: f"{self.name}"
307299
return model
308300

309301
def __repr__(self):
310-
se_repr = self.se.__name__ if self.se else "False"
302+
se_repr = self.se.__name__ if self.se else "False" # type: ignore
311303
return (
312304
f"{self.name} constructor\n"
313305
f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"

src/model_constructor/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3"
1+
__version__ = "0.3.1"

src/model_constructor/xresnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import torch.nn as nn
21
from collections import OrderedDict
32

3+
import torch.nn as nn
4+
45
from .base_constructor import Net
56
from .layers import ConvLayer, Noop, act
67

7-
88
__all__ = ['DownsampleLayer', 'XResBlock', 'xresnet18', 'xresnet34', 'xresnet50']
99

1010

src/model_constructor/yaresnet.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Yet another ResNet.
33

44
from collections import OrderedDict
5-
from typing import Any, Callable, List, Type, Union
5+
from typing import Callable, List, Type, Union
66

77
import torch.nn as nn
88
from torch.nn import Mish
@@ -11,13 +11,12 @@
1111
from .model_constructor import ModelConstructor
1212

1313
__all__ = [
14-
'YaResBlock',
14+
"YaResBlock",
15+
"YaResNet34",
16+
"YaResNet50",
1517
]
1618

1719

18-
# act_fn = nn.ReLU(inplace=True)
19-
20-
2120
class YaResBlock(nn.Module):
2221
'''YaResBlock. Reduce by pool instead of stride 2'''
2322

@@ -129,7 +128,7 @@ class YaResNet34(ModelConstructor):
129128
block: Type[nn.Module] = YaResBlock
130129
expansion: int = 1
131130
layers: List[int] = [3, 4, 6, 3]
132-
act_fn: nn.Module = Mish()
131+
act_fn: Type[nn.Module] = Mish
133132

134133

135134
class YaResNet50(YaResNet34):

tests/parameters.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from torch import nn
2+
3+
4+
def value_name(value) -> str: # pragma: no cover
5+
name = getattr(value, "__name__", None)
6+
if name is not None:
7+
return name
8+
if isinstance(value, nn.Module):
9+
return value._get_name() # pylint: disable=W0212
10+
return value
11+
12+
13+
def ids_fn(key, value):
14+
return [f"{key[:2]}_{value_name(v)}" for v in value]

0 commit comments

Comments (0)