|
14 | 14 | import torch.nn.functional as F |
15 | 15 |
|
16 | 16 | from .helpers import build_model_with_cfg |
17 | | -from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead |
| 17 | +from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule |
18 | 18 | from .registry import register_model |
19 | 19 |
|
20 | 20 | __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] |
@@ -49,40 +49,6 @@ def _cfg(url='', **kwargs): |
49 | 49 | } |
50 | 50 |
|
51 | 51 |
|
52 | | -class FastGlobalAvgPool2d(nn.Module): |
53 | | - def __init__(self, flatten=False): |
54 | | - super(FastGlobalAvgPool2d, self).__init__() |
55 | | - self.flatten = flatten |
56 | | - |
57 | | - def forward(self, x): |
58 | | - if self.flatten: |
59 | | - return x.mean((2, 3)) |
60 | | - else: |
61 | | - return x.mean((2, 3), keepdim=True) |
62 | | - |
63 | | - def feat_mult(self): |
64 | | - return 1 |
65 | | - |
66 | | - |
67 | | -class FastSEModule(nn.Module): |
68 | | - |
69 | | - def __init__(self, channels, reduction_channels, inplace=True): |
70 | | - super(FastSEModule, self).__init__() |
71 | | - self.avg_pool = FastGlobalAvgPool2d() |
72 | | - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True) |
73 | | - self.relu = nn.ReLU(inplace=inplace) |
74 | | - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True) |
75 | | - self.activation = nn.Sigmoid() |
76 | | - |
77 | | - def forward(self, x): |
78 | | - x_se = self.avg_pool(x) |
79 | | - x_se2 = self.fc1(x_se) |
80 | | - x_se2 = self.relu(x_se2) |
81 | | - x_se = self.fc2(x_se2) |
82 | | - x_se = self.activation(x_se) |
83 | | - return x * x_se |
84 | | - |
85 | | - |
86 | 52 | def IABN2Float(module: nn.Module) -> nn.Module: |
87 | 53 | """If `module` is IABN don't use half precision.""" |
88 | 54 | if isinstance(module, InplaceAbn): |
@@ -119,8 +85,8 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_ |
119 | 85 | self.relu = nn.ReLU(inplace=True) |
120 | 86 | self.downsample = downsample |
121 | 87 | self.stride = stride |
122 | | - reduce_layer_planes = max(planes * self.expansion // 4, 64) |
123 | | - self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None |
| 88 | + reduction_chs = max(planes * self.expansion // 4, 64) |
| 89 | + self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None |
124 | 90 |
|
125 | 91 | def forward(self, x): |
126 | 92 | if self.downsample is not None: |
@@ -159,8 +125,8 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, |
159 | 125 | conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3), |
160 | 126 | aa_layer(channels=planes, filt_size=3, stride=2)) |
161 | 127 |
|
162 | | - reduce_layer_planes = max(planes * self.expansion // 8, 64) |
163 | | - self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None |
| 128 | + reduction_chs = max(planes * self.expansion // 8, 64) |
| 129 | + self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None |
164 | 130 |
|
165 | 131 | self.conv3 = conv2d_iabn( |
166 | 132 | planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") |
@@ -189,7 +155,7 @@ def forward(self, x): |
189 | 155 |
|
190 | 156 | class TResNet(nn.Module): |
191 | 157 | def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False, |
192 | | - global_pool='avg', drop_rate=0.): |
| 158 | + global_pool='fast', drop_rate=0.): |
193 | 159 | self.num_classes = num_classes |
194 | 160 | self.drop_rate = drop_rate |
195 | 161 | super(TResNet, self).__init__() |
@@ -272,7 +238,7 @@ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=Non |
272 | 238 | def get_classifier(self): |
273 | 239 | return self.head.fc |
274 | 240 |
|
275 | | - def reset_classifier(self, num_classes, global_pool='avg'): |
| 241 | + def reset_classifier(self, num_classes, global_pool='fast'): |
276 | 242 | self.head = ClassifierHead( |
277 | 243 | self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) |
278 | 244 |
|
|
0 commit comments