Skip to content

Commit 56ae8b9

Browse files
authored
fix reset head in hgnet
1 parent 6862c98 commit 56ae8b9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

timm/models/hgnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,8 @@ def reset_classifier(self, num_classes, global_pool='avg'):
470470
class_expand=self.class_expand,
471471
use_lab=self.use_lab)
472472
else:
473-
if self.global_pool == 'avg':
474-
self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
473+
if global_pool == 'avg':
474+
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
475475
else:
476476
self.head = nn.Identity()
477477

@@ -480,7 +480,7 @@ def forward_features(self, x):
480480
return self.stages(x)
481481

482482
def forward_head(self, x, pre_logits: bool = False):
483-
return self.head(x, pre_logits=pre_logits)
483+
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
484484

485485
def forward(self, x):
486486
x = self.forward_features(x)

0 commit comments

Comments
 (0)