Commit 8f9d6ba

Merge pull request #72 from ayasyrev/mc_root_validate
Mc root validate
2 parents 13404b9 + ea44d85 commit 8f9d6ba

4 files changed: +46 additions, -46 deletions

src/model_constructor/layers.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 from collections import OrderedDict
-from typing import List, Optional, Union
+from typing import List, Optional, Type, Union
 
 import torch
 import torch.nn as nn
@@ -61,7 +61,7 @@ def __init__(
         padding: Optional[int] = None,
         bias: bool = False,
         groups: int = 1,
-        act_fn: Union[nn.Module, bool] = act_fn,
+        act_fn: Union[Type[nn.Module], bool] = nn.ReLU,
         pre_act: bool = False,
         bn_layer: bool = True,
         bn_1st: bool = True,
@@ -88,14 +88,14 @@ def __init__(
             bn = self.batchnorm_module(out_channels)
             nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0)
             layers.append(("bn", bn))
-        if isinstance(act_fn, nn.Module):  # act_fn either nn.Module or False
+        if act_fn:  # act_fn either nn.Module subclass or False
             if pre_act:
                 act_position = 0
             elif not bn_1st:
                 act_position = 1
             else:
                 act_position = len(layers)
-            layers.insert(act_position, ("act_fn", act_fn))
+            layers.insert(act_position, ("act_fn", act_fn(inplace=True)))  # type: ignore
         super().__init__(OrderedDict(layers))
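
With this change ConvBnAct takes the activation as a class (an nn.Module subclass) or False, and builds the instance itself via act_fn(inplace=True). A minimal usage sketch, not part of the commit; the leading positional arguments (in_channels, out_channels) are assumed from the rest of the class, which is not shown in these hunks:

import torch
import torch.nn as nn

from model_constructor.layers import ConvBnAct

# Pass the activation class, not an instance; the layer instantiates it itself.
conv = ConvBnAct(3, 32, act_fn=nn.Mish)
# Pass False to skip the activation entirely.
plain = ConvBnAct(3, 32, act_fn=False)

out = conv(torch.randn(1, 3, 16, 16))
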
src/model_constructor/model_constructor.py

Lines changed: 34 additions & 35 deletions
@@ -1,24 +1,21 @@
 from collections import OrderedDict
-from typing import Callable, List, Optional, Type, Union
+from functools import partial
+from typing import Any, Callable, List, Optional, Type, Union
 
 import torch.nn as nn
-from pydantic import BaseModel
+from pydantic import BaseModel, root_validator
 
 from .layers import ConvBnAct, SEModule, SimpleSelfAttention
 
 __all__ = [
     "init_cnn",
-    "act_fn",
     "ResBlock",
     "ModelConstructor",
     "XResNet34",
     "XResNet50",
 ]
 
 
-act_fn = nn.ReLU(inplace=True)
-
-
 class ResBlock(nn.Module):
     """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
 
@@ -29,13 +26,13 @@ def __init__(
         mid_channels: int,
         stride: int = 1,
         conv_layer=ConvBnAct,
-        act_fn: nn.Module = act_fn,
+        act_fn: Type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
         dw: bool = False,
         div_groups: Union[None, int] = None,
-        pool: Union[nn.Module, None] = None,
+        pool: Union[Callable[[Any], nn.Module], None] = None,
         se: Union[nn.Module, None] = None,
         sa: Union[nn.Module, None] = None,
     ):
@@ -100,7 +97,7 @@ def __init__(
         if stride != 1 or in_channels != out_channels:
             id_layers = []
             if stride != 1 and pool is not None:  # if pool - reduce by pool else stride 2 art id_conv
-                id_layers.append(("pool", pool))
+                id_layers.append(("pool", pool()))
             if in_channels != out_channels or (stride != 1 and pool is None):
                 id_layers += [("id_conv", conv_layer(
                     in_channels,
@@ -112,7 +109,7 @@ def __init__(
             self.id_conv = nn.Sequential(OrderedDict(id_layers))
         else:
             self.id_conv = None
-        self.act_fn = act_fn
+        self.act_fn = act_fn(inplace=True)  # type: ignore
 
     def forward(self, x):
         identity = self.id_conv(x) if self.id_conv is not None else x
@@ -130,8 +127,8 @@ class ModelCfg(BaseModel):
     block_sizes: List[int] = [64, 128, 256, 512]
     layers: List[int] = [2, 2, 2, 2]
     norm: Type[nn.Module] = nn.BatchNorm2d
-    act_fn: nn.Module = nn.ReLU(inplace=True)
-    pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True)
+    act_fn: Type[nn.Module] = nn.ReLU
+    pool: Callable[[Any], nn.Module] = partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)
     expansion: int = 1
     groups: int = 1
     dw: bool = False
@@ -144,7 +141,7 @@ class ModelCfg(BaseModel):
     zero_bn: bool = True
     stem_stride_on: int = 0
     stem_sizes: List[int] = [32, 32, 64]
-    stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # type: ignore
+    stem_pool: Union[Callable[[Any], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
     stem_bn_end: bool = False
     init_cnn: Optional[Callable[[nn.Module], None]] = None
     make_stem: Optional[Callable] = None
@@ -192,7 +189,7 @@ def make_stem(self: ModelCfg) -> nn.Sequential:
         for i in range(len(self.stem_sizes) - 1)
     ]
     if self.stem_pool:
-        stem.append(("stem_pool", self.stem_pool))
+        stem.append(("stem_pool", self.stem_pool()))
     if self.stem_bn_end:
         stem.append(("norm", self.norm(self.stem_sizes[-1])))  # type: ignore
     return nn.Sequential(OrderedDict(stem))
@@ -260,29 +257,30 @@ def make_head(cfg: ModelCfg) -> nn.Sequential:
 class ModelConstructor(ModelCfg):
     """Model constructor. As default - xresnet18"""
 
-    def __init__(self, **data):
-        super().__init__(**data)
-        if self.init_cnn is None:
-            self.init_cnn = init_cnn
-        if self.make_stem is None:
-            self.make_stem = make_stem
-        if self.make_layer is None:
-            self.make_layer = make_layer
-        if self.make_body is None:
-            self.make_body = make_body
-        if self.make_head is None:
-            self.make_head = make_head
-
-        if self.stem_sizes[0] != self.in_chans:
-            self.stem_sizes = [self.in_chans] + self.stem_sizes
-        if self.se and isinstance(self.se, (bool, int)):  # if se=1 or se=True
-            self.se = SEModule
-        if self.sa and isinstance(self.sa, (bool, int)):  # if sa=1 or sa=True
-            self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
-        if self.se_module or self.se_reduction:  # pragma: no cover
+    @root_validator
+    def post_init(cls, values):
+        if values["init_cnn"] is None:
+            values["init_cnn"] = init_cnn
+        if values["make_stem"] is None:
+            values["make_stem"] = make_stem
+        if values["make_layer"] is None:
+            values["make_layer"] = make_layer
+        if values["make_body"] is None:
+            values["make_body"] = make_body
+        if values["make_head"] is None:
+            values["make_head"] = make_head
+
+        if values["stem_sizes"][0] != values["in_chans"]:
+            values["stem_sizes"] = [values["in_chans"]] + values["stem_sizes"]
+        if values["se"] and isinstance(values["se"], (bool, int)):  # if se=1 or se=True
+            values["se"] = SEModule
+        if values["sa"] and isinstance(values["sa"], (bool, int)):  # if sa=1 or sa=True
+            values["sa"] = SimpleSelfAttention  # default: ks=1, sym=sym
+        if values["se_module"] or values["se_reduction"]:  # pragma: no cover
             print(
                 "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
             )  # add deprecation warning.
+        return values
 
     @property
     def stem(self):
@@ -309,11 +307,12 @@ def __call__(self):
         return model
 
     def __repr__(self):
+        se_repr = self.se.__name__ if self.se else "False"
         return (
             f"{self.name} constructor\n"
             f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
             f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
-            f" sa: {self.sa}, se: {self.se}\n"
+            f" act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {se_repr}\n"
             f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
             f" body sizes {self.block_sizes}\n"
             f" layers: {self.layers}"

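Taken together, the config now carries activation classes and pool factories, and the root_validator fills in the builder callbacks that were previously set in __init__. A minimal sketch of constructing a model with the new-style arguments, not part of the commit; the keyword names follow the ModelCfg fields shown above and the import path is the module in this diff:

from functools import partial

import torch.nn as nn

from model_constructor.model_constructor import ModelConstructor

# act_fn is a class and pool a zero-argument factory returning a module;
# make_stem/make_layer/make_body/make_head default via the root_validator when left as None.
mc = ModelConstructor(
    act_fn=nn.Mish,
    pool=partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True),
    layers=[2, 2, 2, 2],
)
model = mc()  # __call__ assembles and returns the model
print(mc)     # __repr__ now reports the activation by class name, e.g. "act_fn: Mish"
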
src/model_constructor/yaresnet.py

Lines changed: 6 additions & 6 deletions
@@ -2,7 +2,7 @@
 # Yet another ResNet.
 
 from collections import OrderedDict
-from typing import List, Type, Union
+from typing import Any, Callable, List, Type, Union
 
 import torch.nn as nn
 from torch.nn import Mish
@@ -15,7 +15,7 @@
 ]
 
 
-act_fn = nn.ReLU(inplace=True)
+# act_fn = nn.ReLU(inplace=True)
 
 
 class YaResBlock(nn.Module):
@@ -28,13 +28,13 @@ def __init__(
         mid_channels: int,
         stride: int = 1,
         conv_layer=ConvBnAct,
-        act_fn: nn.Module = act_fn,
+        act_fn: Type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
         dw: bool = False,
         div_groups: Union[None, int] = None,
-        pool: Union[nn.Module, None] = None,
+        pool: Union[Callable[[Any], nn.Module], None] = None,
         se: Union[nn.Module, None] = None,
         sa: Union[nn.Module, None] = None,
     ):
@@ -49,7 +49,7 @@ def __init__(
                 self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
                 # warnings.warn("pool not passed")  # need to warn?
             else:
-                self.reduce = pool
+                self.reduce = pool()
         else:
             self.reduce = None
         if expansion == 1:
@@ -115,7 +115,7 @@ def __init__(
             )
         else:
             self.id_conv = None
-        self.merge = act_fn
+        self.merge = act_fn()
 
     def forward(self, x):
         if self.reduce:
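
As in ResBlock, the block now instantiates both the pool and the merge activation itself. A minimal sketch, not part of the commit, assuming YaResBlock shares ResBlock's leading positional arguments (expansion, in_channels, mid_channels), which are not visible in these hunks:

from functools import partial

import torch
import torch.nn as nn

from model_constructor.yaresnet import YaResBlock

block = YaResBlock(
    1, 64, 64,   # assumed order: expansion, in_channels, mid_channels
    stride=2,
    act_fn=nn.Mish,                                              # class, instantiated as act_fn()
    pool=partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True),   # factory, called as pool()
)
out = block(torch.randn(8, 64, 32, 32))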

tests/test_block.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 # import pytest
+from functools import partial
 import torch
 import torch.nn as nn
 from model_constructor.layers import SEModule, SimpleSelfAttention
@@ -16,7 +17,7 @@
     mid_channels=[8, 16],
     stride=[1, 2],
     div_groups=[None, 2],
-    pool=[None, nn.AvgPool2d(2, ceil_mode=True)],
+    pool=[None, partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)],
     se=[None, SEModule],
    sa=[None, SimpleSelfAttention],
 )
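
For reference, a small sketch (not part of the test file) of what the parametrized partial expands to: the block now calls pool() itself, so every test case gets a fresh pooling module instead of sharing one instance:

from functools import partial

import torch.nn as nn

pool = partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)
layer = pool()   # equivalent to nn.AvgPool2d(kernel_size=2, ceil_mode=True)
assert isinstance(layer, nn.AvgPool2d)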
