Skip to content

Commit 90991a0

Browse files
committed
rename Net to ModelConstructor
1 parent 0bf76ae commit 90991a0

File tree

3 files changed

+179
-5
lines changed

3 files changed

+179
-5
lines changed

model_constructor/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
__version__ = "0.1.7"
1+
from model_constructor.model_constructor import ModelConstructor # noqa F401
2+
3+
4+
__version__ = "0.1.8"
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from collections import OrderedDict
2+
from functools import partial
3+
4+
import torch.nn as nn
5+
6+
from .layers import ConvLayer, Flatten, SEBlock, SimpleSelfAttention, noop
7+
8+
9+
__all__ = ['init_cnn', 'act_fn', 'ResBlock', 'ModelConstructor', 'xresnet34', 'xresnet50']
10+
11+
12+
act_fn = nn.ReLU(inplace=True)
13+
14+
15+
def init_cnn(module: nn.Module):
16+
"Init module - kaiming_normal for Conv2d and 0 for biases."
17+
if getattr(module, 'bias', None) is not None:
18+
nn.init.constant_(module.bias, 0)
19+
if isinstance(module, (nn.Conv2d, nn.Linear)):
20+
nn.init.kaiming_normal_(module.weight)
21+
for layer in module.children():
22+
init_cnn(layer)
23+
24+
25+
class ResBlock(nn.Module):
26+
'''Resnet block'''
27+
se_block = SEBlock
28+
29+
def __init__(self, expansion, ni, nh, stride=1,
30+
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
31+
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False, se_reduction=16,
32+
groups=1, dw=False, div_groups=None):
33+
super().__init__()
34+
nf, ni = nh * expansion, ni * expansion
35+
if div_groups is not None: # check if grops != 1 and div_groups
36+
groups = int(nh / div_groups)
37+
if expansion == 1:
38+
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
39+
groups=nh if dw else groups)),
40+
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
41+
]
42+
else:
43+
layers = [("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
44+
("conv_1", conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
45+
groups=nh if dw else groups)),
46+
("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
47+
]
48+
if se:
49+
layers.append(('se', self.se_block(nf, se_reduction)))
50+
if sa:
51+
layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
52+
self.convs = nn.Sequential(OrderedDict(layers))
53+
self.pool = noop if stride == 1 else pool
54+
self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)
55+
self.act_fn = act_fn
56+
57+
def forward(self, x):
58+
return self.act_fn(self.convs(x) + self.idconv(self.pool(x)))
59+
60+
61+
def _make_stem(self):
62+
stem = [(f"conv_{i}", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i + 1],
63+
stride=2 if i == self.stem_stride_on else 1,
64+
bn_layer=(not self.stem_bn_end) if i == (len(self.stem_sizes) - 2) else True,
65+
act_fn=self.act_fn, bn_1st=self.bn_1st))
66+
for i in range(len(self.stem_sizes) - 1)]
67+
stem.append(('stem_pool', self.stem_pool))
68+
if self.stem_bn_end:
69+
stem.append(('norm', self.norm(self.stem_sizes[-1])))
70+
return nn.Sequential(OrderedDict(stem))
71+
72+
73+
def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
74+
layers = [(f"bl_{i}", self.block(expansion, ni if i == 0 else nf, nf,
75+
stride if i == 0 else 1, sa=sa if i == blocks - 1 else False,
76+
conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
77+
zero_bn=self.zero_bn, bn_1st=self.bn_1st,
78+
groups=self.groups, div_groups=self.div_groups,
79+
dw=self.dw, se=self.se))
80+
for i in range(blocks)]
81+
return nn.Sequential(OrderedDict(layers))
82+
83+
84+
def _make_body(self):
85+
blocks = [(f"l_{i}", self._make_layer(self, self.expansion,
86+
ni=self.block_sizes[i], nf=self.block_sizes[i + 1],
87+
blocks=l, stride=1 if i == 0 else 2,
88+
sa=self.sa if i == 0 else False))
89+
for i, l in enumerate(self.layers)]
90+
return nn.Sequential(OrderedDict(blocks))
91+
92+
93+
def _make_head(self):
94+
head = [('pool', nn.AdaptiveAvgPool2d(1)),
95+
('flat', Flatten()),
96+
('fc', nn.Linear(self.block_sizes[-1] * self.expansion, self.c_out))]
97+
return nn.Sequential(OrderedDict(head))
98+
99+
100+
class ModelConstructor():
101+
"""Model constructor. As default - xresnet18"""
102+
def __init__(self, name='MC', c_in=3, c_out=1000,
103+
block=ResBlock, conv_layer=ConvLayer,
104+
block_sizes=[64, 128, 256, 512], layers=[2, 2, 2, 2],
105+
norm=nn.BatchNorm2d,
106+
act_fn=nn.ReLU(inplace=True),
107+
pool=nn.AvgPool2d(2, ceil_mode=True),
108+
expansion=1, groups=1, dw=False, div_groups=None,
109+
sa=False, se=False, se_reduction=16,
110+
bn_1st=True,
111+
zero_bn=True,
112+
stem_stride_on=0,
113+
stem_sizes=[32, 32, 64],
114+
stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
115+
stem_bn_end=False,
116+
_init_cnn=init_cnn,
117+
_make_stem=_make_stem,
118+
_make_layer=_make_layer,
119+
_make_body=_make_body,
120+
_make_head=_make_head,
121+
):
122+
super().__init__()
123+
124+
params = locals()
125+
del params['self']
126+
self.__dict__ = params
127+
self._block_sizes = params['block_sizes']
128+
if self.stem_sizes[0] != self.c_in:
129+
self.stem_sizes = [self.c_in] + self.stem_sizes
130+
131+
@property
132+
def block_sizes(self):
133+
return [self.stem_sizes[-1] // self.expansion] + self._block_sizes + [256] * (len(self.layers) - 4)
134+
135+
@property
136+
def stem(self):
137+
return self._make_stem(self)
138+
139+
@property
140+
def head(self):
141+
return self._make_head(self)
142+
143+
@property
144+
def body(self):
145+
return self._make_body(self)
146+
147+
def __call__(self):
148+
model = nn.Sequential(OrderedDict([
149+
('stem', self.stem),
150+
('body', self.body),
151+
('head', self.head)]))
152+
self._init_cnn(model)
153+
model.extra_repr = lambda: f"model {self.name}"
154+
return model
155+
156+
def __repr__(self):
157+
return (f"{self.name} constructor\n"
158+
f" c_in: {self.c_in}, c_out: {self.c_out}\n"
159+
f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
160+
f" sa: {self.sa}, se: {self.se}\n"
161+
f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
162+
f" body sizes {self._block_sizes}\n"
163+
f" layers: {self.layers}")
164+
165+
166+
xresnet34 = partial(ModelConstructor, name='xresnet34', expansion=1, layers=[3, 4, 6, 3])
167+
xresnet50 = partial(ModelConstructor, name='xresnet34', expansion=4, layers=[3, 4, 6, 3])

model_constructor/net.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ def forward(self, x):
6161
# NewResBlock now is YaResBlock - Yet Another ResNet Block! It is now at model_constructor.yaresnet.
6262

6363

64-
class NewResBlock(nn.Module):
65-
'''YaResnet block'''
64+
class NewResBlock(nn.Module): # todo: deprecation worning.
65+
'''YaResnet block.
66+
This is first impl, deprecated, use yaresnet module.
67+
'''
6668
se_block = SEBlock
6769

6870
def __init__(self, expansion, ni, nh, stride=1,
@@ -137,8 +139,10 @@ def _make_head(self):
137139
return nn.Sequential(OrderedDict(head))
138140

139141

140-
class Net():
141-
"""Model constructor. As default - xresnet18"""
142+
class Net(): # todo: deprecation worning.
143+
"""Model constructor. As default - xresnet18.
144+
First version, still here for compatibility. Use ModelConstructor instead.
145+
"""
142146
def __init__(self, name='Net', c_in=3, c_out=1000,
143147
block=ResBlock, conv_layer=ConvLayer,
144148
block_sizes=[64, 128, 256, 512], layers=[2, 2, 2, 2],

0 commit comments

Comments
 (0)