Skip to content

Commit 9dbea3b

Browse files
committed
fix cls head in hgnet
1 parent 56ae8b9 commit 9dbea3b

File tree

1 file changed

+74
-68
lines changed

1 file changed

+74
-68
lines changed

timm/models/hgnet.py

Lines changed: 74 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -423,15 +423,21 @@ def __init__(
423423
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
424424
self.stages = nn.Sequential(*stages)
425425

426-
self.head = ClassifierHead(
427-
num_features=self.num_features,
428-
num_classes=num_classes,
429-
pool_type=global_pool,
430-
drop_rate=drop_rate,
431-
use_last_conv=use_last_conv,
432-
class_expand=class_expand,
433-
use_lab=use_lab
434-
)
426+
if num_classes > 0:
427+
self.head = ClassifierHead(
428+
num_features=self.num_features,
429+
num_classes=num_classes,
430+
pool_type=global_pool,
431+
drop_rate=drop_rate,
432+
use_last_conv=use_last_conv,
433+
class_expand=class_expand,
434+
use_lab=use_lab
435+
)
436+
else:
437+
if global_pool == 'avg':
438+
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
439+
else:
440+
self.head = nn.Identity()
435441

436442
for n, m in self.named_modules():
437443
if isinstance(m, nn.Conv2d):
@@ -608,65 +614,65 @@ def _cfg(url='', **kwargs):
608614
}
609615

610616

611-
default_cfgs = generate_default_cfgs({
612-
'hgnet_tiny.paddle_in1k': _cfg(
613-
first_conv='stem.0.conv',
614-
hf_hub_id='timm/'),
615-
'hgnet_tiny.ssld_in1k': _cfg(
616-
first_conv='stem.0.conv',
617-
hf_hub_id='timm/'),
618-
'hgnet_small.paddle_in1k': _cfg(
619-
first_conv='stem.0.conv',
620-
hf_hub_id='timm/'),
621-
'hgnet_small.ssld_in1k': _cfg(
622-
first_conv='stem.0.conv',
623-
hf_hub_id='timm/'),
624-
'hgnet_base.ssld_in1k': _cfg(
625-
first_conv='stem.0.conv',
626-
hf_hub_id='timm/'),
627-
'hgnetv2_b0.ssld_in1k': _cfg(
628-
first_conv='stem.stem1.conv',
629-
hf_hub_id='timm/'),
630-
'hgnetv2_b0.ssld_stage1': _cfg(
631-
first_conv='stem.stem1.conv',
632-
hf_hub_id='timm/'),
633-
'hgnetv2_b1.ssld_in1k': _cfg(
634-
first_conv='stem.stem1.conv',
635-
hf_hub_id='timm/'),
636-
'hgnetv2_b1.ssld_stage1': _cfg(
637-
first_conv='stem.stem1.conv',
638-
hf_hub_id='timm/'),
639-
'hgnetv2_b2.ssld_in1k': _cfg(
640-
first_conv='stem.stem1.conv',
641-
hf_hub_id='timm/'),
642-
'hgnetv2_b2.ssld_stage1': _cfg(
643-
first_conv='stem.stem1.conv',
644-
hf_hub_id='timm/'),
645-
'hgnetv2_b3.ssld_in1k': _cfg(
646-
first_conv='stem.stem1.conv',
647-
hf_hub_id='timm/'),
648-
'hgnetv2_b3.ssld_stage1': _cfg(
649-
first_conv='stem.stem1.conv',
650-
hf_hub_id='timm/'),
651-
'hgnetv2_b4.ssld_in1k': _cfg(
652-
first_conv='stem.stem1.conv',
653-
hf_hub_id='timm/'),
654-
'hgnetv2_b4.ssld_stage1': _cfg(
655-
first_conv='stem.stem1.conv',
656-
hf_hub_id='timm/'),
657-
'hgnetv2_b5.ssld_in1k': _cfg(
658-
first_conv='stem.stem1.conv',
659-
hf_hub_id='timm/'),
660-
'hgnetv2_b5.ssld_stage1': _cfg(
661-
first_conv='stem.stem1.conv',
662-
hf_hub_id='timm/'),
663-
'hgnetv2_b6.ssld_in1k': _cfg(
664-
first_conv='stem.stem1.conv',
665-
hf_hub_id='timm/'),
666-
'hgnetv2_b6.ssld_stage1': _cfg(
667-
first_conv='stem.stem1.conv',
668-
hf_hub_id='timm/'),
669-
})
617+
# default_cfgs = generate_default_cfgs({
618+
# 'hgnet_tiny.paddle_in1k': _cfg(
619+
# first_conv='stem.0.conv',
620+
# hf_hub_id='timm/'),
621+
# 'hgnet_tiny.ssld_in1k': _cfg(
622+
# first_conv='stem.0.conv',
623+
# hf_hub_id='timm/'),
624+
# 'hgnet_small.paddle_in1k': _cfg(
625+
# first_conv='stem.0.conv',
626+
# hf_hub_id='timm/'),
627+
# 'hgnet_small.ssld_in1k': _cfg(
628+
# first_conv='stem.0.conv',
629+
# hf_hub_id='timm/'),
630+
# 'hgnet_base.ssld_in1k': _cfg(
631+
# first_conv='stem.0.conv',
632+
# hf_hub_id='timm/'),
633+
# 'hgnetv2_b0.ssld_in1k': _cfg(
634+
# first_conv='stem.stem1.conv',
635+
# hf_hub_id='timm/'),
636+
# 'hgnetv2_b0.ssld_stage1': _cfg(
637+
# first_conv='stem.stem1.conv',
638+
# hf_hub_id='timm/'),
639+
# 'hgnetv2_b1.ssld_in1k': _cfg(
640+
# first_conv='stem.stem1.conv',
641+
# hf_hub_id='timm/'),
642+
# 'hgnetv2_b1.ssld_stage1': _cfg(
643+
# first_conv='stem.stem1.conv',
644+
# hf_hub_id='timm/'),
645+
# 'hgnetv2_b2.ssld_in1k': _cfg(
646+
# first_conv='stem.stem1.conv',
647+
# hf_hub_id='timm/'),
648+
# 'hgnetv2_b2.ssld_stage1': _cfg(
649+
# first_conv='stem.stem1.conv',
650+
# hf_hub_id='timm/'),
651+
# 'hgnetv2_b3.ssld_in1k': _cfg(
652+
# first_conv='stem.stem1.conv',
653+
# hf_hub_id='timm/'),
654+
# 'hgnetv2_b3.ssld_stage1': _cfg(
655+
# first_conv='stem.stem1.conv',
656+
# hf_hub_id='timm/'),
657+
# 'hgnetv2_b4.ssld_in1k': _cfg(
658+
# first_conv='stem.stem1.conv',
659+
# hf_hub_id='timm/'),
660+
# 'hgnetv2_b4.ssld_stage1': _cfg(
661+
# first_conv='stem.stem1.conv',
662+
# hf_hub_id='timm/'),
663+
# 'hgnetv2_b5.ssld_in1k': _cfg(
664+
# first_conv='stem.stem1.conv',
665+
# hf_hub_id='timm/'),
666+
# 'hgnetv2_b5.ssld_stage1': _cfg(
667+
# first_conv='stem.stem1.conv',
668+
# hf_hub_id='timm/'),
669+
# 'hgnetv2_b6.ssld_in1k': _cfg(
670+
# first_conv='stem.stem1.conv',
671+
# hf_hub_id='timm/'),
672+
# 'hgnetv2_b6.ssld_stage1': _cfg(
673+
# first_conv='stem.stem1.conv',
674+
# hf_hub_id='timm/'),
675+
# })
670676

671677

672678
@register_model

0 commit comments

Comments
 (0)