
Commit d620b81

refactor model cfg
1 parent 5e4d304 commit d620b81

File tree

2 files changed (+109 −100 lines)


src/model_constructor/helpers.py

Lines changed: 61 additions & 1 deletion
@@ -1,9 +1,69 @@
 from collections import OrderedDict
-from typing import Iterable
+from functools import partial
+from typing import Iterable, Optional
+from pydantic import BaseModel

 from torch import nn


 def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:
     """return nn.Sequential from OrderedDict from list of tuples"""
     return nn.Sequential(OrderedDict(list_of_tuples)) #
+
+
+def init_cnn(module: nn.Module) -> None:
+    "Init module - kaiming_normal for Conv2d and 0 for biases."
+    if getattr(module, "bias", None) is not None:
+        nn.init.constant_(module.bias, 0) # type: ignore
+    if isinstance(module, (nn.Conv2d, nn.Linear)):
+        nn.init.kaiming_normal_(module.weight)
+    for layer in module.children():
+        init_cnn(layer)
+
+
+class Cfg(BaseModel):
+    """Base class for config."""
+
+    name: Optional[str] = None
+
+    def _get_str_value(self, field: str) -> str:
+        value = getattr(self, field)
+        if isinstance(value, type):
+            value = value.__name__
+        elif isinstance(value, partial):
+            value = f"{value.func.__name__} {value.keywords}"
+        elif callable(value):
+            value = value.__name__
+        return value
+
+    def __repr__(self) -> str:
+        return f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})"
+
+    def __repr_args__(self) -> list[tuple[str, str]]:
+        return [
+            (field, str_value)
+            for field in self.model_fields
+            if (str_value := self._get_str_value(field))
+        ]
+
+    def __repr_changed_args__(self) -> list[str]:
+        """Return list repr for changed fields"""
+        return [
+            f"{field}: {self._get_str_value(field)}"
+            for field in self.model_fields_set
+            if field != "name"
+        ]
+
+    def print_cfg(self) -> None:
+        """Print full config"""
+        print(f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})")
+
+    def print_changed(self) -> None:
+        """Print changed fields."""
+        changed_fields = self.__repr_changed_args__()
+        if changed_fields:
+            print("Changed fields:")
+            for i in changed_fields:
+                print(" ", i)
+        else:
+            print("Nothing changed")

src/model_constructor/model_constructor.py

Lines changed: 48 additions & 99 deletions
@@ -3,10 +3,10 @@
 from typing import Any, Callable, Optional, TypeVar, Union

 import torch
-from pydantic import BaseModel, field_validator
+from pydantic import field_validator
 from torch import nn

-from .helpers import nn_seq
+from .helpers import nn_seq, Cfg, init_cnn
 from .layers import ConvBnAct, SEModule, SimpleSelfAttention, get_act

 __all__ = [
@@ -23,16 +23,6 @@
 ListStrMod = list[tuple[str, nn.Module]]


-def init_cnn(module: nn.Module) -> None:
-    "Init module - kaiming_normal for Conv2d and 0 for biases."
-    if getattr(module, "bias", None) is not None:
-        nn.init.constant_(module.bias, 0) # type: ignore
-    if isinstance(module, (nn.Conv2d, nn.Linear)):
-        nn.init.kaiming_normal_(module.weight)
-    for layer in module.children():
-        init_cnn(layer)
-
-
 class BasicBlock(nn.Module):
     """Basic Resnet block.
     Configurable - can use pool to reduce at identity path, change act etc."""
@@ -212,6 +202,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
         return self.act_fn(self.convs(x) + identity)


+class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
+    """Model constructor Config. As default - xresnet18"""
+
+    name: Optional[str] = None
+    in_chans: int = 3
+    num_classes: int = 1000
+    block: type[nn.Module] = BasicBlock
+    conv_layer: type[nn.Module] = ConvBnAct
+    block_sizes: list[int] = [64, 128, 256, 512]
+    layers: list[int] = [2, 2, 2, 2]
+    norm: type[nn.Module] = nn.BatchNorm2d
+    act_fn: type[nn.Module] = nn.ReLU
+    pool: Optional[Callable[[Any], nn.Module]] = None
+    expansion: int = 1
+    groups: int = 1
+    dw: bool = False
+    div_groups: Union[int, None] = None
+    sa: Union[bool, type[nn.Module]] = False
+    se: Union[bool, type[nn.Module]] = False
+    se_module: Union[bool, None] = None
+    se_reduction: Union[int, None] = None
+    bn_1st: bool = True
+    zero_bn: bool = True
+    stem_stride_on: int = 0
+    stem_sizes: list[int] = [64]
+    stem_pool: Union[Callable[[], nn.Module], None] = partial(
+        nn.MaxPool2d, kernel_size=3, stride=2, padding=1
+    )
+    stem_bn_end: bool = False
+
+    def __repr__(self) -> str:
+        se_repr = self.se.__name__ if self.se else "False" # type: ignore
+        model_name = self.name or self.__class__.__name__
+        return (
+            f"{model_name}\n"
+            f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
+            f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
+            f" act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {se_repr}\n"
+            f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
+            f" body sizes {self.block_sizes}\n"
+            f" layers: {self.layers}"
+        )
+
+
 def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
     """Create Resnet stem."""
     stem: ListStrMod = [
@@ -285,87 +319,15 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
     return nn_seq(head)


-class ModelCfg(BaseModel, arbitrary_types_allowed=True, extra="forbid"):
-    """Model constructor Config. As default - xresnet18"""
+class ModelConstructor(ModelCfg):
+    """Model constructor. As default - resnet18"""

-    name: Optional[str] = None
-    in_chans: int = 3
-    num_classes: int = 1000
-    block: type[nn.Module] = BasicBlock
-    conv_layer: type[nn.Module] = ConvBnAct
-    block_sizes: list[int] = [64, 128, 256, 512]
-    layers: list[int] = [2, 2, 2, 2]
-    norm: type[nn.Module] = nn.BatchNorm2d
-    act_fn: type[nn.Module] = nn.ReLU
-    pool: Optional[Callable[[Any], nn.Module]] = None
-    expansion: int = 1
-    groups: int = 1
-    dw: bool = False
-    div_groups: Union[int, None] = None
-    sa: Union[bool, type[nn.Module]] = False
-    se: Union[bool, type[nn.Module]] = False
-    se_module: Union[bool, None] = None
-    se_reduction: Union[int, None] = None
-    bn_1st: bool = True
-    zero_bn: bool = True
-    stem_stride_on: int = 0
-    stem_sizes: list[int] = [64]
-    stem_pool: Union[Callable[[], nn.Module], None] = partial(
-        nn.MaxPool2d, kernel_size=3, stride=2, padding=1
-    )
-    stem_bn_end: bool = False
     init_cnn: Callable[[nn.Module], None] = init_cnn
     make_stem: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore
     make_layer: Callable[[TModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer # type: ignore
     make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
     make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore

-    def _get_str_value(self, field: str) -> str:
-        value = getattr(self, field)
-        if isinstance(value, type):
-            value = value.__name__
-        elif isinstance(value, partial):
-            value = f"{value.func.__name__} {value.keywords}"
-        elif callable(value):
-            value = value.__name__
-        return value
-
-    def __repr__(self) -> str:
-        return f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})"
-
-    def __repr_args__(self) -> list[tuple[str, str]]:
-        return [
-            (field, str_value)
-            for field in self.model_fields
-            if (str_value := self._get_str_value(field))
-        ]
-
-    def __repr_changed_args__(self) -> list[str]:
-        """Return list repr for changed fields"""
-        return [
-            f"{field}: {self._get_str_value(field)}"
-            for field in self.model_fields_set
-            if field != "name"
-        ]
-
-    def print_cfg(self) -> None:
-        """Print full config"""
-        print(f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})")
-
-    def print_changed(self) -> None:
-        """Print changed fields."""
-        changed_fields = self.__repr_changed_args__()
-        if changed_fields:
-            print("Changed fields:")
-            for i in changed_fields:
-                print(" ", i)
-        else:
-            print("Nothing changed")
-
-
-class ModelConstructor(ModelCfg):
-    """Model constructor. As default - resnet18"""
-
     @field_validator("se")
     def set_se( # pylint: disable=no-self-argument
         cls, value: Union[bool, type[nn.Module]]
@@ -430,19 +392,6 @@ def __call__(self) -> nn.Sequential:
         model.extra_repr = lambda: ", ".join(extra_repr)
         return model

-    def __repr__(self) -> str:
-        se_repr = self.se.__name__ if self.se else "False" # type: ignore
-        model_name = self.name or self.__class__.__name__
-        return (
-            f"{model_name}\n"
-            f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
-            f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
-            f" act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {se_repr}\n"
-            f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
-            f" body sizes {self.block_sizes}\n"
-            f" layers: {self.layers}"
-        )
-

 class ResNet34(ModelConstructor):
     layers: list[int] = [3, 4, 6, 3]
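
With ModelCfg now subclassing the shared Cfg helper and ModelConstructor inheriting from ModelCfg, configuration, reporting, and model assembly stay on one object. A hedged usage sketch based only on the classes shown in this diff (the import path assumes the package installs as model_constructor):

from model_constructor.model_constructor import ModelConstructor, ResNet34

mc = ModelConstructor(num_classes=10)  # fields are defined on ModelCfg
mc.print_changed()                     # inherited from Cfg: lists only the fields set above
model = mc()                           # __call__ assembles stem, body and head into an nn.Sequential

resnet34 = ResNet34()()                # preset from this file: layers [3, 4, 6, 3]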
