1313from __future__ import division
1414from __future__ import print_function
1515
16- import os
1716import logging
18- import functools
1917
20- import torch
2118import torch .nn as nn
22- import torch ._utils
2319import torch .nn .functional as F
2420
25- from .resnet import BasicBlock , Bottleneck # leveraging ResNet blocks w/ additional features like SE
26- from .registry import register_model
21+ from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
2722from .helpers import load_pretrained
2823from .layers import SelectAdaptivePool2d
29- from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
24+ from .registry import register_model
25+ from .resnet import BasicBlock , Bottleneck # leveraging ResNet blocks w/ additional features like SE
3026
3127_BN_MOMENTUM = 0.1
3228logger = logging .getLogger (__name__ )
@@ -101,7 +97,7 @@ def _cfg(url='', **kwargs):
10197 ),
10298 ),
10399
104- hrnet_w18_small_v2 = dict (
100+ hrnet_w18_small_v2 = dict (
105101 STEM_WIDTH = 64 ,
106102 STAGE1 = dict (
107103 NUM_MODULES = 1 ,
@@ -137,7 +133,7 @@ def _cfg(url='', **kwargs):
137133 ),
138134 ),
139135
140- hrnet_w18 = dict (
136+ hrnet_w18 = dict (
141137 STEM_WIDTH = 64 ,
142138 STAGE1 = dict (
143139 NUM_MODULES = 1 ,
@@ -173,7 +169,7 @@ def _cfg(url='', **kwargs):
173169 ),
174170 ),
175171
176- hrnet_w30 = dict (
172+ hrnet_w30 = dict (
177173 STEM_WIDTH = 64 ,
178174 STAGE1 = dict (
179175 NUM_MODULES = 1 ,
@@ -209,7 +205,7 @@ def _cfg(url='', **kwargs):
209205 ),
210206 ),
211207
212- hrnet_w32 = dict (
208+ hrnet_w32 = dict (
213209 STEM_WIDTH = 64 ,
214210 STAGE1 = dict (
215211 NUM_MODULES = 1 ,
@@ -245,7 +241,7 @@ def _cfg(url='', **kwargs):
245241 ),
246242 ),
247243
248- hrnet_w40 = dict (
244+ hrnet_w40 = dict (
249245 STEM_WIDTH = 64 ,
250246 STAGE1 = dict (
251247 NUM_MODULES = 1 ,
@@ -281,7 +277,7 @@ def _cfg(url='', **kwargs):
281277 ),
282278 ),
283279
284- hrnet_w44 = dict (
280+ hrnet_w44 = dict (
285281 STEM_WIDTH = 64 ,
286282 STAGE1 = dict (
287283 NUM_MODULES = 1 ,
@@ -317,7 +313,7 @@ def _cfg(url='', **kwargs):
317313 ),
318314 ),
319315
320- hrnet_w48 = dict (
316+ hrnet_w48 = dict (
321317 STEM_WIDTH = 64 ,
322318 STAGE1 = dict (
323319 NUM_MODULES = 1 ,
@@ -353,7 +349,7 @@ def _cfg(url='', **kwargs):
353349 ),
354350 ),
355351
356- hrnet_w64 = dict (
352+ hrnet_w64 = dict (
357353 STEM_WIDTH = 64 ,
358354 STAGE1 = dict (
359355 NUM_MODULES = 1 ,
@@ -456,7 +452,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels):
456452
457453 def _make_fuse_layers (self ):
458454 if self .num_branches == 1 :
459- return None
455+ return nn . Identity ()
460456
461457 num_branches = self .num_branches
462458 num_inchannels = self .num_inchannels
@@ -470,7 +466,7 @@ def _make_fuse_layers(self):
470466 nn .BatchNorm2d (num_inchannels [i ], momentum = _BN_MOMENTUM ),
471467 nn .Upsample (scale_factor = 2 ** (j - i ), mode = 'nearest' )))
472468 elif j == i :
473- fuse_layer .append (None )
469+ fuse_layer .append (nn . Identity () )
474470 else :
475471 conv3x3s = []
476472 for k in range (i - j ):
@@ -619,7 +615,7 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer)
619615 nn .BatchNorm2d (num_channels_cur_layer [i ], momentum = _BN_MOMENTUM ),
620616 nn .ReLU (inplace = True )))
621617 else :
622- transition_layers .append (None )
618+ transition_layers .append (nn . Identity () )
623619 else :
624620 conv3x3s = []
625621 for j in range (i + 1 - num_branches_pre ):
@@ -686,8 +682,11 @@ def get_classifier(self):
686682 def reset_classifier (self , num_classes , global_pool = 'avg' ):
687683 self .num_classes = num_classes
688684 self .global_pool = SelectAdaptivePool2d (pool_type = global_pool )
689- self .classifier = nn .Linear (
690- self .num_features * self .global_pool .feat_mult (), num_classes ) if num_classes else None
685+ num_features = self .num_features * self .global_pool .feat_mult ()
686+ if num_classes :
687+ self .classifier = nn .Linear (num_features , num_classes )
688+ else :
689+ self .classifier = nn .Identity ()
691690
692691 def forward_features (self , x ):
693692 x = self .conv1 (x )
@@ -699,24 +698,21 @@ def forward_features(self, x):
699698 x = self .layer1 (x )
700699
701700 x_list = []
702- for i in range (self .stage2_cfg ['NUM_BRANCHES' ]):
703- if self .transition1 [i ] is not None :
704- x_list .append (self .transition1 [i ](x ))
705- else :
706- x_list .append (x )
701+ for i in range (len (self .transition1 )):
702+ x_list .append (self .transition1 [i ](x ))
707703 y_list = self .stage2 (x_list )
708704
709705 x_list = []
710- for i in range (self .stage3_cfg [ 'NUM_BRANCHES' ] ):
711- if self .transition2 [i ] is not None :
706+ for i in range (len ( self .transition2 ) ):
707+ if not isinstance ( self .transition2 [i ], nn . Identity ) :
712708 x_list .append (self .transition2 [i ](y_list [- 1 ]))
713709 else :
714710 x_list .append (y_list [i ])
715711 y_list = self .stage3 (x_list )
716712
717713 x_list = []
718- for i in range (self .stage4_cfg [ 'NUM_BRANCHES' ] ):
719- if self .transition3 [i ] is not None :
714+ for i in range (len ( self .transition3 ) ):
715+ if not isinstance ( self .transition3 [i ], nn . Identity ) :
720716 x_list .append (self .transition3 [i ](y_list [- 1 ]))
721717 else :
722718 x_list .append (y_list [i ])
0 commit comments