
Commit ea44d85

mc pre_validate instead of __init__
1 parent de2bb6a commit ea44d85

File tree

2 files changed: +26, -27 lines


src/model_constructor/model_constructor.py

Lines changed: 24 additions & 26 deletions

@@ -3,23 +3,19 @@
 from typing import Any, Callable, List, Optional, Type, Union
 
 import torch.nn as nn
-from pydantic import BaseModel
+from pydantic import BaseModel, root_validator
 
 from .layers import ConvBnAct, SEModule, SimpleSelfAttention
 
 __all__ = [
     "init_cnn",
-    # "act_fn",
     "ResBlock",
     "ModelConstructor",
     "XResNet34",
     "XResNet50",
 ]
 
 
-# act_fn = nn.ReLU
-
-
 class ResBlock(nn.Module):
     """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
 
@@ -261,29 +257,30 @@ def make_head(cfg: ModelCfg) -> nn.Sequential:
 class ModelConstructor(ModelCfg):
     """Model constructor. As default - xresnet18"""
 
-    def __init__(self, **data):
-        super().__init__(**data)
-        if self.init_cnn is None:
-            self.init_cnn = init_cnn
-        if self.make_stem is None:
-            self.make_stem = make_stem
-        if self.make_layer is None:
-            self.make_layer = make_layer
-        if self.make_body is None:
-            self.make_body = make_body
-        if self.make_head is None:
-            self.make_head = make_head
-
-        if self.stem_sizes[0] != self.in_chans:
-            self.stem_sizes = [self.in_chans] + self.stem_sizes
-        if self.se and isinstance(self.se, (bool, int)):  # if se=1 or se=True
-            self.se = SEModule
-        if self.sa and isinstance(self.sa, (bool, int)):  # if sa=1 or sa=True
-            self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
-        if self.se_module or self.se_reduction:  # pragma: no cover
+    @root_validator
+    def post_init(cls, values):
+        if values["init_cnn"] is None:
+            values["init_cnn"] = init_cnn
+        if values["make_stem"] is None:
+            values["make_stem"] = make_stem
+        if values["make_layer"] is None:
+            values["make_layer"] = make_layer
+        if values["make_body"] is None:
+            values["make_body"] = make_body
+        if values["make_head"] is None:
+            values["make_head"] = make_head
+
+        if values["stem_sizes"][0] != values["in_chans"]:
+            values["stem_sizes"] = [values["in_chans"]] + values["stem_sizes"]
+        if values["se"] and isinstance(values["se"], (bool, int)):  # if se=1 or se=True
+            values["se"] = SEModule
+        if values["sa"] and isinstance(values["sa"], (bool, int)):  # if sa=1 or sa=True
+            values["sa"] = SimpleSelfAttention  # default: ks=1, sym=sym
+        if values["se_module"] or values["se_reduction"]:  # pragma: no cover
             print(
                 "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
             )  # add deprecation warning.
+        return values
 
     @property
     def stem(self):
@@ -310,11 +307,12 @@ def __call__(self):
         return model
 
     def __repr__(self):
+        se_repr = self.se.__name__ if self.se else "False"
         return (
            f"{self.name} constructor\n"
            f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
            f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
-           f" act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {self.se}\n"
+           f" act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {se_repr}\n"
            f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
            f" body sizes {self.block_sizes}\n"
            f" layers: {self.layers}"

tests/test_block.py

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
 # import pytest
+from functools import partial
 import torch
 import torch.nn as nn
 from model_constructor.layers import SEModule, SimpleSelfAttention
@@ -16,7 +17,7 @@
     mid_channels=[8, 16],
     stride=[1, 2],
     div_groups=[None, 2],
-    pool=[None, nn.AvgPool2d(2, ceil_mode=True)],
+    pool=[None, partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)],
     se=[None, SEModule],
     sa=[None, SimpleSelfAttention],
 )
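
The pool parameter now passes a constructor (functools.partial around nn.AvgPool2d) instead of a pre-built module instance, presumably so the code under test can build a fresh pooling layer wherever one is needed rather than reuse a single shared module. A short illustrative sketch of the difference; the shapes and values are arbitrary and not taken from the test suite.

from functools import partial

import torch
import torch.nn as nn

# A pooling-layer factory: calling it constructs a new module each time.
pool = partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)

layer_a = pool()  # a new AvgPool2d(kernel_size=2, ceil_mode=True)
layer_b = pool()  # another independent instance
x = torch.randn(1, 8, 7, 7)
print(layer_a(x).shape)  # torch.Size([1, 8, 4, 4]) because ceil_mode=True rounds 3.5 up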
