|
| 1 | +from collections import OrderedDict |
| 2 | +from functools import partial |
| 3 | + |
| 4 | +import torch.nn as nn |
| 5 | + |
| 6 | +from .layers import ConvLayer, Flatten, SEBlock, SimpleSelfAttention, noop |
| 7 | + |
| 8 | + |
# Names exported by `from <module> import *`.
__all__ = ['init_cnn', 'act_fn', 'ResBlock', 'ModelConstructor', 'xresnet34', 'xresnet50']


# Default activation. This is a single module instance created at import time,
# so every caller that relies on the default shares the same object.
act_fn = nn.ReLU(inplace=True)
| 13 | + |
| 14 | + |
def init_cnn(module: nn.Module):
    "Walk `module` and its descendants: set non-None biases to 0, kaiming_normal Conv2d/Linear weights."
    # Explicit stack instead of recursion; children are pushed reversed so the
    # visiting order stays parent first, then children in registration order.
    pending = [module]
    while pending:
        current = pending.pop()
        bias = getattr(current, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias, 0)
        if isinstance(current, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(current.weight)
        pending.extend(reversed(list(current.children())))
| 23 | + |
| 24 | + |
class ResBlock(nn.Module):
    '''Residual block: a stack of conv layers added to a shortcut branch, then activated.

    With `expansion == 1` the stack is two 3x3 `conv_layer`s; otherwise it is
    1x1 -> 3x3 -> 1x1. Optional `se` / `sa` modules are appended to the stack.
    The shortcut applies `pool` when `stride != 1` and a 1x1 `conv_layer` when
    the input and output channel counts differ.
    '''
    # Class used for the optional 'se' layer; override on a subclass to swap it.
    se_block = SEBlock

    def __init__(self, expansion, ni, nh, stride=1,
                 conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False, se_reduction=16,
                 groups=1, dw=False, div_groups=None):
        super().__init__()
        # Channel counts passed in are "base" widths; scale them by the expansion factor.
        ni = ni * expansion
        nf = nh * expansion
        # When div_groups is given it overrides `groups` (groups != 1 is not checked here).
        if div_groups is not None:
            groups = int(nh / div_groups)
        # Depthwise (`dw`) uses one group per hidden channel for the strided 3x3 conv.
        conv_groups = nh if dw else groups

        stack = OrderedDict()
        if expansion == 1:
            stack["conv_0"] = conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
                                         groups=conv_groups)
            stack["conv_1"] = conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st)
        else:
            stack["conv_0"] = conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)
            stack["conv_1"] = conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
                                         groups=conv_groups)
            stack["conv_2"] = conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st)
        if se:
            stack['se'] = self.se_block(nf, se_reduction)
        if sa:
            stack['sa'] = SimpleSelfAttention(nf, ks=1, sym=sym)
        self.convs = nn.Sequential(stack)

        # Shortcut branch: pooling only when the main branch is strided,
        # and a 1x1 conv only when the channel count has to change.
        if stride == 1:
            self.pool = noop
        else:
            self.pool = pool
        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv_layer(ni, nf, 1, act=False)
        self.act_fn = act_fn

    def forward(self, x):
        '''Return act_fn(convs(x) + idconv(pool(x))).'''
        residual = self.convs(x)
        shortcut = self.idconv(self.pool(x))
        return self.act_fn(residual + shortcut)
| 59 | + |
| 60 | + |
def _make_stem(self):
    "Build the stem: one `conv_layer` per consecutive pair in `stem_sizes`, then `stem_pool` (and `norm` if `stem_bn_end`)."
    sizes = self.stem_sizes
    last_conv = len(sizes) - 2
    stem = OrderedDict()
    for idx, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
        # Only the last conv can lose its bn layer, and only when a norm is appended at the end.
        use_bn = (not self.stem_bn_end) if idx == last_conv else True
        stem[f"conv_{idx}"] = self.conv_layer(n_in, n_out,
                                              stride=2 if idx == self.stem_stride_on else 1,
                                              bn_layer=use_bn,
                                              act_fn=self.act_fn, bn_1st=self.bn_1st)
    stem['stem_pool'] = self.stem_pool
    if self.stem_bn_end:
        stem['norm'] = self.norm(sizes[-1])
    return nn.Sequential(stem)
| 71 | + |
| 72 | + |
def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
    """Build one body layer: `blocks` instances of `self.block` named bl_0, bl_1, ...

    Only the first block receives `ni` and `stride`; the rest go nf -> nf with
    stride 1. Only the last block receives `sa`; the others get `sa=False`.
    All remaining block options are read from the constructor (`self`).

    `se_reduction` is forwarded as well. It was previously stored on the
    constructor but never passed on, so changing it had no effect on the model.
    """
    last = blocks - 1
    layers = [(f"bl_{i}", self.block(expansion, ni if i == 0 else nf, nf,
                                     stride if i == 0 else 1, sa=sa if i == last else False,
                                     conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
                                     zero_bn=self.zero_bn, bn_1st=self.bn_1st,
                                     groups=self.groups, div_groups=self.div_groups,
                                     dw=self.dw, se=self.se, se_reduction=self.se_reduction))
              for i in range(blocks)]
    return nn.Sequential(OrderedDict(layers))
| 82 | + |
| 83 | + |
def _make_body(self):
    "Build the body: one layer per entry of `self.layers`, named l_0, l_1, ..."
    body = OrderedDict()
    for idx, n_blocks in enumerate(self.layers):
        is_first = idx == 0
        sizes = self.block_sizes
        # The first layer keeps resolution (stride 1) and is the only one offered `sa`.
        body[f"l_{idx}"] = self._make_layer(self, self.expansion,
                                            ni=sizes[idx], nf=sizes[idx + 1],
                                            blocks=n_blocks,
                                            stride=1 if is_first else 2,
                                            sa=self.sa if is_first else False)
    return nn.Sequential(body)
| 91 | + |
| 92 | + |
def _make_head(self):
    "Build the head: adaptive average pool to 1x1, flatten, then a linear layer to `c_out`."
    head = OrderedDict()
    head['pool'] = nn.AdaptiveAvgPool2d(1)
    head['flat'] = Flatten()
    # Input width of the classifier is the last body size times the expansion factor.
    n_features = self.block_sizes[-1] * self.expansion
    head['fc'] = nn.Linear(n_features, self.c_out)
    return nn.Sequential(head)
| 98 | + |
| 99 | + |
class ModelConstructor():
    """Model constructor. As default - xresnet18.

    Every `__init__` argument is stored as an instance attribute of the same
    name, so a configuration can be adjusted after creation and before the
    model is built. Calling the instance builds a new `nn.Sequential` with
    `stem`, `body` and `head` sub-modules and runs `_init_cnn` on it.

    The `_init_cnn` / `_make_*` arguments are plain functions stored on the
    instance (not bound methods), which is why they are called with `self`
    passed explicitly.
    """
    def __init__(self, name='MC', c_in=3, c_out=1000,
                 block=ResBlock, conv_layer=ConvLayer,
                 block_sizes=[64, 128, 256, 512], layers=[2, 2, 2, 2],
                 norm=nn.BatchNorm2d,
                 act_fn=nn.ReLU(inplace=True),
                 pool=nn.AvgPool2d(2, ceil_mode=True),
                 expansion=1, groups=1, dw=False, div_groups=None,
                 sa=False, se=False, se_reduction=16,
                 bn_1st=True,
                 zero_bn=True,
                 stem_stride_on=0,
                 stem_sizes=[32, 32, 64],
                 stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                 stem_bn_end=False,
                 _init_cnn=init_cnn,
                 _make_stem=_make_stem,
                 _make_layer=_make_layer,
                 _make_body=_make_body,
                 _make_head=_make_head,
                 ):
        super().__init__()

        # Snapshot the call arguments; work on a copy so the frame's own
        # locals mapping never becomes the instance dict.
        params = dict(locals())
        del params['self']
        # Zero-argument super() adds a `__class__` cell to locals(); it is not a setting.
        params.pop('__class__', None)
        # The list defaults are shared between calls. Copy them so that editing
        # one instance (e.g. `mc.layers[0] = 3`) cannot change the defaults or
        # a list owned by the caller.
        for key in ('block_sizes', 'layers', 'stem_sizes'):
            params[key] = list(params[key])
        self.__dict__.update(params)
        # `block_sizes` is a read-only property on the class, so the raw
        # argument is kept under a private name for the property to use.
        self._block_sizes = params['block_sizes']
        # Prepend the input channel count unless the first stem size already equals it.
        # NOTE(review): a stem whose first width happens to equal c_in is treated
        # as already prefixed - confirm this is intended.
        if self.stem_sizes[0] != self.c_in:
            self.stem_sizes = [self.c_in] + self.stem_sizes

    @property
    def block_sizes(self):
        """Channel sizes for the body: stem output (divided by expansion), the given sizes, padded with 256s when there are more than 4 layers."""
        return [self.stem_sizes[-1] // self.expansion] + self._block_sizes + [256] * (len(self.layers) - 4)

    @property
    def stem(self):
        """A newly built stem module (rebuilt on every access)."""
        return self._make_stem(self)

    @property
    def head(self):
        """A newly built head module (rebuilt on every access)."""
        return self._make_head(self)

    @property
    def body(self):
        """A newly built body module (rebuilt on every access)."""
        return self._make_body(self)

    def __call__(self):
        """Build, initialize and return the model described by the current settings."""
        model = nn.Sequential(OrderedDict([
            ('stem', self.stem),
            ('body', self.body),
            ('head', self.head)]))
        self._init_cnn(model)
        # Shown inside the module's repr; reads `self.name` at repr time.
        model.extra_repr = lambda: f"model {self.name}"
        return model

    def __repr__(self):
        return (f"{self.name} constructor\n"
                f" c_in: {self.c_in}, c_out: {self.c_out}\n"
                f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
                f" sa: {self.sa}, se: {self.se}\n"
                f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
                f" body sizes {self._block_sizes}\n"
                f" layers: {self.layers}")
| 164 | + |
| 165 | + |
# Preconfigured constructors; both use the [3, 4, 6, 3] layer layout and differ
# in expansion (1 vs 4). Each call returns a ModelConstructor, not a model.
xresnet34 = partial(ModelConstructor, name='xresnet34', expansion=1, layers=[3, 4, 6, 3])
# Fixed: the name was copy-pasted as 'xresnet34', mislabelling the repr of xresnet50 models.
xresnet50 = partial(ModelConstructor, name='xresnet50', expansion=4, layers=[3, 4, 6, 3])
0 commit comments