Skip to content

Commit 2cdd6a0

Browse files
committed
move old models
1 parent 57f9fc2 commit 2cdd6a0

File tree

5 files changed

+102
-83
lines changed

5 files changed

+102
-83
lines changed

src/model_constructor/base_constructor.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
1-
import torch.nn as nn
1+
"""First version of constructor.
2+
"""
3+
# Used in examples.
4+
# first implementation of xresnet - inspired by fastai version.
25
from collections import OrderedDict
3-
from .layers import ConvLayer, Noop, Flatten
6+
from functools import partial
47

8+
import torch.nn as nn
59

6-
__all__ = ['act_fn', 'Stem', 'DownsampleBlock', 'BasicBlock', 'Bottleneck', 'BasicLayer', 'Body', 'Head', 'init_model',
7-
'Net']
10+
from .layers import ConvLayer, Flatten, Noop
11+
12+
__all__ = [
13+
"act_fn",
14+
"Stem",
15+
"DownsampleBlock",
16+
"BasicBlock",
17+
"Bottleneck",
18+
"BasicLayer",
19+
"Body",
20+
"Head",
21+
"init_model",
22+
"Net",
23+
"DownsampleLayer",
24+
"XResBlock",
25+
"xresnet18",
26+
"xresnet34",
27+
"xresnet50",
28+
]
829

930

1031
act_fn = nn.ReLU(inplace=True)
@@ -162,3 +183,58 @@ def __init__(self, stem=Stem,
162183
('head', head(body_out * expansion, num_classes, **kwargs))
163184
]))
164185
self.init_model(self)
186+
187+
188+
# xresnet from fastai
189+
190+
191+
class DownsampleLayer(nn.Sequential):
192+
"""Downsample layer for Xresnet Resblock"""
193+
194+
def __init__(self, conv_layer, ni, nf, stride, act,
195+
pool=nn.AvgPool2d(2, ceil_mode=True), pool_1st=True,
196+
**kwargs):
197+
layers = [] if stride == 1 else [('pool', pool)]
198+
layers += [] if ni == nf else [('idconv', conv_layer(ni, nf, 1, act=act, **kwargs))]
199+
if not pool_1st:
200+
layers.reverse()
201+
super().__init__(OrderedDict(layers))
202+
203+
204+
class XResBlock(nn.Module):
205+
'''XResnet block'''
206+
207+
def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
208+
conv_layer=ConvLayer, act_fn=act_fn, **kwargs):
209+
super().__init__()
210+
nf, ni = nh * expansion, ni * expansion
211+
layers = [('conv_0', conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
212+
('conv_1', conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs))
213+
] if expansion == 1 else [
214+
('conv_0', conv_layer(ni, nh, 1, act_fn=act_fn, **kwargs)),
215+
('conv_1', conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
216+
('conv_2', conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs))
217+
]
218+
self.convs = nn.Sequential(OrderedDict(layers))
219+
self.identity = DownsampleLayer(conv_layer, ni, nf, stride,
220+
act=False, act_fn=act_fn, **kwargs) if ni != nf or stride == 2 else Noop()
221+
self.merge = Noop()
222+
self.act_fn = act_fn
223+
224+
def forward(self, x):
225+
return self.act_fn(self.merge(self.convs(x) + self.identity(x)))
226+
227+
228+
def xresnet18(**kwargs):
229+
"""Constructs xresnet18 model. """
230+
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[2, 2, 2, 2], expansion=1, **kwargs)
231+
232+
233+
def xresnet34(**kwargs):
234+
"""Constructs xresnet34 model. """
235+
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=1, **kwargs)
236+
237+
238+
def xresnet50(**kwargs):
239+
"""Constructs xresnet50 model. """
240+
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=4, **kwargs)

src/model_constructor/mxresnet.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +0,0 @@
1-
from functools import partial
2-
3-
from .activations import Mish
4-
from .net import Net
5-
6-
7-
__all__ = ["mxresnet_parameters", "mxresnet34", "mxresnet50"]
8-
9-
10-
mxresnet_parameters = {"stem_sizes": [3, 32, 64, 64], "act_fn": Mish()}
11-
mxresnet34 = partial(
12-
Net, name="MXResnet32", expansion=1, layers=[3, 4, 6, 3], **mxresnet_parameters
13-
)
14-
mxresnet50 = partial(
15-
Net, name="MXResnet50", expansion=4, layers=[3, 4, 6, 3], **mxresnet_parameters
16-
)

src/model_constructor/net.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@
66
from .layers import ConvLayer, Flatten, SEBlock, SimpleSelfAttention, noop
77

88

9-
__all__ = ['init_cnn', 'act_fn', 'ResBlock', 'NewResBlock', 'Net', 'xresnet34', 'xresnet50']
9+
__all__ = [
10+
"init_cnn",
11+
"act_fn",
12+
"ResBlock",
13+
"NewResBlock",
14+
"Net",
15+
"xresnet34",
16+
"xresnet50",
17+
"mxresnet_parameters",
18+
"mxresnet34",
19+
"mxresnet50",
20+
]
1021

1122

1223
act_fn = nn.ReLU(inplace=True)
@@ -209,3 +220,11 @@ def __repr__(self):
209220

210221
xresnet34 = partial(Net, name='xresnet34', expansion=1, layers=[3, 4, 6, 3])
211222
xresnet50 = partial(Net, name='xresnet34', expansion=4, layers=[3, 4, 6, 3])
223+
224+
mxresnet_parameters = {"stem_sizes": [3, 32, 64, 64], "act_fn": nn.Mish()}
225+
mxresnet34 = partial(
226+
Net, name="MXResnet32", expansion=1, layers=[3, 4, 6, 3], **mxresnet_parameters
227+
)
228+
mxresnet50 = partial(
229+
Net, name="MXResnet50", expansion=4, layers=[3, 4, 6, 3], **mxresnet_parameters
230+
)

src/model_constructor/xresnet.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +0,0 @@
1-
from collections import OrderedDict
2-
3-
import torch.nn as nn
4-
5-
from .base_constructor import Net
6-
from .layers import ConvLayer, Noop, act
7-
8-
__all__ = ['DownsampleLayer', 'XResBlock', 'xresnet18', 'xresnet34', 'xresnet50']
9-
10-
11-
class DownsampleLayer(nn.Sequential):
12-
"""Downsample layer for Xresnet Resblock"""
13-
14-
def __init__(self, conv_layer, ni, nf, stride, act,
15-
pool=nn.AvgPool2d(2, ceil_mode=True), pool_1st=True,
16-
**kwargs):
17-
layers = [] if stride == 1 else [('pool', pool)]
18-
layers += [] if ni == nf else [('idconv', conv_layer(ni, nf, 1, act=act, **kwargs))]
19-
if not pool_1st:
20-
layers.reverse()
21-
super().__init__(OrderedDict(layers))
22-
23-
24-
class XResBlock(nn.Module):
25-
'''XResnet block'''
26-
27-
def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
28-
conv_layer=ConvLayer, act_fn=act, **kwargs):
29-
super().__init__()
30-
nf, ni = nh * expansion, ni * expansion
31-
layers = [('conv_0', conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
32-
('conv_1', conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs))
33-
] if expansion == 1 else [
34-
('conv_0', conv_layer(ni, nh, 1, act_fn=act_fn, **kwargs)),
35-
('conv_1', conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
36-
('conv_2', conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs))
37-
]
38-
self.convs = nn.Sequential(OrderedDict(layers))
39-
self.identity = DownsampleLayer(conv_layer, ni, nf, stride,
40-
act=False, act_fn=act_fn, **kwargs) if ni != nf or stride == 2 else Noop()
41-
self.merge = Noop()
42-
self.act_fn = act_fn
43-
44-
def forward(self, x):
45-
return self.act_fn(self.merge(self.convs(x) + self.identity(x)))
46-
47-
48-
def xresnet18(**kwargs):
49-
"""Constructs xresnet18 model. """
50-
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[2, 2, 2, 2], expansion=1, **kwargs)
51-
52-
53-
def xresnet34(**kwargs):
54-
"""Constructs xresnet34 model. """
55-
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=1, **kwargs)
56-
57-
58-
def xresnet50(**kwargs):
59-
"""Constructs xresnet50 model. """
60-
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=4, **kwargs)

tests/test_models_old.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import torch
33
from torch import nn
44

5-
from model_constructor.mxresnet import mxresnet34, mxresnet50
6-
from model_constructor.xresnet import xresnet18, xresnet34, xresnet50
5+
from model_constructor.base_constructor import xresnet18, xresnet34, xresnet50
6+
from model_constructor.net import mxresnet34, mxresnet50
77

88
bs_test = 2
99
img_size = 16

0 commit comments

Comments
 (0)