|
1 | | -import torch.nn as nn |
| 1 | +"""First version of constructor. |
| 2 | +""" |
| 3 | +# Used in examples. |
| 4 | +# First implementation of XResNet, inspired by the fastai version. |
2 | 5 | from collections import OrderedDict |
3 | | -from .layers import ConvLayer, Noop, Flatten |
| 6 | +from functools import partial |
4 | 7 |
|
| 8 | +import torch.nn as nn |
5 | 9 |
|
6 | | -__all__ = ['act_fn', 'Stem', 'DownsampleBlock', 'BasicBlock', 'Bottleneck', 'BasicLayer', 'Body', 'Head', 'init_model', |
7 | | - 'Net'] |
| 10 | +from .layers import ConvLayer, Flatten, Noop |
| 11 | + |
| 12 | +__all__ = [ |
| 13 | + "act_fn", |
| 14 | + "Stem", |
| 15 | + "DownsampleBlock", |
| 16 | + "BasicBlock", |
| 17 | + "Bottleneck", |
| 18 | + "BasicLayer", |
| 19 | + "Body", |
| 20 | + "Head", |
| 21 | + "init_model", |
| 22 | + "Net", |
| 23 | + "DownsampleLayer", |
| 24 | + "XResBlock", |
| 25 | + "xresnet18", |
| 26 | + "xresnet34", |
| 27 | + "xresnet50", |
| 28 | +] |
8 | 29 |
|
9 | 30 |
|
10 | 31 | act_fn = nn.ReLU(inplace=True) |
@@ -162,3 +183,58 @@ def __init__(self, stem=Stem, |
162 | 183 | ('head', head(body_out * expansion, num_classes, **kwargs)) |
163 | 184 | ])) |
164 | 185 | self.init_model(self) |
| 186 | + |
| 187 | + |
| 188 | +# xresnet from fastai |
| 189 | + |
| 190 | + |
class DownsampleLayer(nn.Sequential):
    """Downsample layer for the XResNet ResBlock identity path.

    Builds a (possibly empty) sequence of:
      * 'pool'   -- spatial downsampling, only when stride != 1
      * 'idconv' -- 1x1 conv to change channel count, only when ni != nf

    Args:
        conv_layer: conv-block factory (e.g. ConvLayer); only called when ni != nf.
        ni: number of input channels.
        nf: number of output channels.
        stride: stride of the residual branch; pooling is added when != 1.
        act: activation flag forwarded to conv_layer for the idconv.
        pool: pooling module to use; defaults to a fresh AvgPool2d(2, ceil_mode=True).
        pool_1st: if True, pool before the idconv; otherwise after.
    """

    def __init__(self, conv_layer, ni, nf, stride, act,
                 pool=None, pool_1st=True,
                 **kwargs):
        # Create the default pool per instance: a module instance used as a
        # default argument would be a single shared object across every model
        # built from this class (mutable-default-argument pitfall).
        if pool is None:
            pool = nn.AvgPool2d(2, ceil_mode=True)
        layers = [] if stride == 1 else [('pool', pool)]
        layers += [] if ni == nf else [('idconv', conv_layer(ni, nf, 1, act=act, **kwargs))]
        if not pool_1st:
            layers.reverse()
        super().__init__(OrderedDict(layers))
| 202 | + |
| 203 | + |
class XResBlock(nn.Module):
    """XResNet residual block.

    A convolutional path (two 3x3 convs for expansion == 1, otherwise a
    1x1 / 3x3 / 1x1 bottleneck) is summed with an identity path (a
    DownsampleLayer when channels or resolution change, otherwise Noop),
    then passed through the activation.
    """

    def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
                 conv_layer=ConvLayer, act_fn=act_fn, **kwargs):
        super().__init__()
        nf, ni = nh * expansion, ni * expansion
        if expansion == 1:
            # Basic block: two 3x3 convs; last conv has no activation and
            # an optionally zero-initialised BN.
            conv_specs = [
                ('conv_0', conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
                ('conv_1', conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs)),
            ]
        else:
            # Bottleneck: 1x1 reduce, 3x3 (carries the stride), 1x1 expand.
            conv_specs = [
                ('conv_0', conv_layer(ni, nh, 1, act_fn=act_fn, **kwargs)),
                ('conv_1', conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
                ('conv_2', conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs)),
            ]
        self.convs = nn.Sequential(OrderedDict(conv_specs))
        # Identity path needs reshaping only when channels or resolution change.
        if ni != nf or stride == 2:
            self.identity = DownsampleLayer(conv_layer, ni, nf, stride,
                                            act=False, act_fn=act_fn, **kwargs)
        else:
            self.identity = Noop()
        self.merge = Noop()  # hook point for merging, identity by default
        self.act_fn = act_fn

    def forward(self, x):
        summed = self.convs(x) + self.identity(x)
        return self.act_fn(self.merge(summed))
| 226 | + |
| 227 | + |
def xresnet18(**kwargs):
    """Construct an xresnet18 model (basic blocks, layout [2, 2, 2, 2])."""
    cfg = dict(stem_sizes=[32, 32], block=XResBlock, blocks=[2, 2, 2, 2], expansion=1)
    return Net(**cfg, **kwargs)
| 231 | + |
| 232 | + |
def xresnet34(**kwargs):
    """Construct an xresnet34 model (basic blocks, layout [3, 4, 6, 3])."""
    cfg = dict(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=1)
    return Net(**cfg, **kwargs)
| 236 | + |
| 237 | + |
def xresnet50(**kwargs):
    """Construct an xresnet50 model (bottleneck blocks, layout [3, 4, 6, 3])."""
    cfg = dict(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=4)
    return Net(**cfg, **kwargs)
0 commit comments