
Commit 6cc11a8

Merge pull request #141 from Animatory/fix_HRNet
Fixed HRNet modules
2 parents: 3b72ebf + f0eb021

File tree

1 file changed: +25 / -29 lines


timm/models/hrnet.py

Lines changed: 25 additions & 29 deletions
@@ -13,20 +13,16 @@
 from __future__ import division
 from __future__ import print_function
 
-import os
 import logging
-import functools
 
-import torch
 import torch.nn as nn
-import torch._utils
 import torch.nn.functional as F
 
-from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE
-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .registry import register_model
+from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE
 
 _BN_MOMENTUM = 0.1
 logger = logging.getLogger(__name__)
@@ -101,7 +97,7 @@ def _cfg(url='', **kwargs):
         ),
     ),
 
-    hrnet_w18_small_v2 = dict(
+    hrnet_w18_small_v2=dict(
         STEM_WIDTH=64,
         STAGE1=dict(
             NUM_MODULES=1,
@@ -137,7 +133,7 @@ def _cfg(url='', **kwargs):
         ),
     ),
 
-    hrnet_w18 = dict(
+    hrnet_w18=dict(
         STEM_WIDTH=64,
         STAGE1=dict(
             NUM_MODULES=1,
@@ -173,7 +169,7 @@ def _cfg(url='', **kwargs):
         ),
     ),
 
-    hrnet_w30 = dict(
+    hrnet_w30=dict(
         STEM_WIDTH=64,
         STAGE1=dict(
             NUM_MODULES=1,
@@ -209,7 +205,7 @@ def _cfg(url='', **kwargs):
         ),
     ),
 
-    hrnet_w32 = dict(
+    hrnet_w32=dict(
         STEM_WIDTH=64,
         STAGE1=dict(
             NUM_MODULES=1,
@@ -245,7 +241,7 @@ def _cfg(url='', **kwargs):
         ),
     ),
 
-    hrnet_w40 = dict(
+    hrnet_w40=dict(
         STEM_WIDTH=64,
         STAGE1=dict(
             NUM_MODULES=1,
@@ -281,7 +277,7 @@ def _cfg(url='', **kwargs):
         ),
     ),
 
-    hrnet_w44 = dict(
+    hrnet_w44=dict(
         STEM_WIDTH=64,
         STAGE1=dict(
             NUM_MODULES=1,
@@ -317,7 +313,7 @@ def _cfg(url='', **kwargs):
         ),
     ),
 
-    hrnet_w48 = dict(
+    hrnet_w48=dict(
         STEM_WIDTH=64,
         STAGE1=dict(
             NUM_MODULES=1,
@@ -353,7 +349,7 @@ def _cfg(url='', **kwargs):
         ),
     ),
 
-    hrnet_w64 = dict(
+    hrnet_w64=dict(
         STEM_WIDTH=64,
         STAGE1=dict(
             NUM_MODULES=1,
@@ -456,7 +452,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels):
 
     def _make_fuse_layers(self):
         if self.num_branches == 1:
-            return None
+            return nn.Identity()
 
         num_branches = self.num_branches
         num_inchannels = self.num_inchannels
@@ -470,7 +466,7 @@ def _make_fuse_layers(self):
                         nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM),
                         nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
                 elif j == i:
-                    fuse_layer.append(None)
+                    fuse_layer.append(nn.Identity())
                 else:
                     conv3x3s = []
                     for k in range(i - j):
@@ -619,7 +615,7 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer)
                        nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                else:
-                    transition_layers.append(None)
+                    transition_layers.append(nn.Identity())
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
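
The three hunks above replace None placeholders in the fuse and transition layer lists with nn.Identity(). nn.Identity is a real nn.Module that returns its input unchanged, so every slot in the resulting ModuleList can be stored, called, or type-checked like any other module instead of being guarded with "is not None". A minimal standalone sketch of that behaviour (not the timm code, channel counts are made up):

import torch
import torch.nn as nn

# nn.Identity is a callable no-op module: it returns its input unchanged,
# so it can stand in for "no transition / no fusion needed" slots.
transition = nn.ModuleList([
    nn.Identity(),                    # channels already match: pass through
    nn.Conv2d(32, 64, 3, padding=1),  # hypothetical real transition
])

x = torch.randn(1, 32, 56, 56)
assert torch.equal(transition[0](x), x)  # identity slot is a no-op
y = transition[1](x)                     # real slot transforms the tensor
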
@@ -686,8 +682,11 @@ def get_classifier(self):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.classifier = nn.Linear(
-            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        num_features = self.num_features * self.global_pool.feat_mult()
+        if num_classes:
+            self.classifier = nn.Linear(num_features, num_classes)
+        else:
+            self.classifier = nn.Identity()
 
     def forward_features(self, x):
         x = self.conv1(x)
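
With the classifier change above, reset_classifier(0) installs nn.Identity() as the head instead of leaving it as None, so the model can still be called end to end and simply hands back the pooled features. A hypothetical usage sketch, assuming a timm build that includes this commit (model name and input size are only illustrative):

import torch
import timm  # assumed to include this HRNet fix

model = timm.create_model('hrnet_w18_small_v2', pretrained=False)
model.reset_classifier(0)                 # head becomes nn.Identity()
out = model(torch.randn(1, 3, 224, 224))  # Identity head passes features through
print(out.shape)                          # pooled feature vector, no class logits
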
@@ -699,24 +698,21 @@ def forward_features(self, x):
         x = self.layer1(x)
 
         x_list = []
-        for i in range(self.stage2_cfg['NUM_BRANCHES']):
-            if self.transition1[i] is not None:
-                x_list.append(self.transition1[i](x))
-            else:
-                x_list.append(x)
+        for i in range(len(self.transition1)):
+            x_list.append(self.transition1[i](x))
         y_list = self.stage2(x_list)
 
         x_list = []
-        for i in range(self.stage3_cfg['NUM_BRANCHES']):
-            if self.transition2[i] is not None:
+        for i in range(len(self.transition2)):
+            if not isinstance(self.transition2[i], nn.Identity):
                 x_list.append(self.transition2[i](y_list[-1]))
             else:
                 x_list.append(y_list[i])
         y_list = self.stage3(x_list)
 
         x_list = []
-        for i in range(self.stage4_cfg['NUM_BRANCHES']):
-            if self.transition3[i] is not None:
+        for i in range(len(self.transition3)):
+            if not isinstance(self.transition3[i], nn.Identity):
                 x_list.append(self.transition3[i](y_list[-1]))
             else:
                 x_list.append(y_list[i])
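
In forward_features the loops now iterate over the transition ModuleLists directly and test isinstance(..., nn.Identity) in place of the old None checks: an Identity slot means the branch keeps its own feature map, while a real transition module spawns a new, lower-resolution branch from the last output. A standalone sketch of that dispatch pattern (channel counts and shapes are made up):

import torch
import torch.nn as nn

# Two-slot transition: slot 0 keeps its branch, slot 1 creates a new branch
# by downsampling the highest-resolution feature map.
transition = nn.ModuleList([
    nn.Identity(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
])
y_list = [torch.randn(1, 32, 56, 56)]  # outputs of the previous stage

x_list = []
for i, t in enumerate(transition):
    if not isinstance(t, nn.Identity):
        x_list.append(t(y_list[-1]))   # new branch from the last output
    else:
        x_list.append(y_list[i])       # existing branch passes through

print([tuple(x.shape) for x in x_list])  # [(1, 32, 56, 56), (1, 64, 28, 28)]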
