@@ -423,15 +423,21 @@ def __init__(
423423 self .feature_info += [dict (num_chs = self .num_features , reduction = current_stride , module = f'stages.{ i } ' )]
424424 self .stages = nn .Sequential (* stages )
425425
426- self .head = ClassifierHead (
427- num_features = self .num_features ,
428- num_classes = num_classes ,
429- pool_type = global_pool ,
430- drop_rate = drop_rate ,
431- use_last_conv = use_last_conv ,
432- class_expand = class_expand ,
433- use_lab = use_lab
434- )
426+ if num_classes > 0 :
427+ self .head = ClassifierHead (
428+ num_features = self .num_features ,
429+ num_classes = num_classes ,
430+ pool_type = global_pool ,
431+ drop_rate = drop_rate ,
432+ use_last_conv = use_last_conv ,
433+ class_expand = class_expand ,
434+ use_lab = use_lab
435+ )
436+ else :
437+ if global_pool == 'avg' :
438+ self .head = SelectAdaptivePool2d (pool_type = global_pool , flatten = True )
439+ else :
440+ self .head = nn .Identity ()
435441
436442 for n , m in self .named_modules ():
437443 if isinstance (m , nn .Conv2d ):
@@ -608,65 +614,65 @@ def _cfg(url='', **kwargs):
608614 }
609615
610616
611- default_cfgs = generate_default_cfgs ({
612- 'hgnet_tiny.paddle_in1k' : _cfg (
613- first_conv = 'stem.0.conv' ,
614- hf_hub_id = 'timm/' ),
615- 'hgnet_tiny.ssld_in1k' : _cfg (
616- first_conv = 'stem.0.conv' ,
617- hf_hub_id = 'timm/' ),
618- 'hgnet_small.paddle_in1k' : _cfg (
619- first_conv = 'stem.0.conv' ,
620- hf_hub_id = 'timm/' ),
621- 'hgnet_small.ssld_in1k' : _cfg (
622- first_conv = 'stem.0.conv' ,
623- hf_hub_id = 'timm/' ),
624- 'hgnet_base.ssld_in1k' : _cfg (
625- first_conv = 'stem.0.conv' ,
626- hf_hub_id = 'timm/' ),
627- 'hgnetv2_b0.ssld_in1k' : _cfg (
628- first_conv = 'stem.stem1.conv' ,
629- hf_hub_id = 'timm/' ),
630- 'hgnetv2_b0.ssld_stage1' : _cfg (
631- first_conv = 'stem.stem1.conv' ,
632- hf_hub_id = 'timm/' ),
633- 'hgnetv2_b1.ssld_in1k' : _cfg (
634- first_conv = 'stem.stem1.conv' ,
635- hf_hub_id = 'timm/' ),
636- 'hgnetv2_b1.ssld_stage1' : _cfg (
637- first_conv = 'stem.stem1.conv' ,
638- hf_hub_id = 'timm/' ),
639- 'hgnetv2_b2.ssld_in1k' : _cfg (
640- first_conv = 'stem.stem1.conv' ,
641- hf_hub_id = 'timm/' ),
642- 'hgnetv2_b2.ssld_stage1' : _cfg (
643- first_conv = 'stem.stem1.conv' ,
644- hf_hub_id = 'timm/' ),
645- 'hgnetv2_b3.ssld_in1k' : _cfg (
646- first_conv = 'stem.stem1.conv' ,
647- hf_hub_id = 'timm/' ),
648- 'hgnetv2_b3.ssld_stage1' : _cfg (
649- first_conv = 'stem.stem1.conv' ,
650- hf_hub_id = 'timm/' ),
651- 'hgnetv2_b4.ssld_in1k' : _cfg (
652- first_conv = 'stem.stem1.conv' ,
653- hf_hub_id = 'timm/' ),
654- 'hgnetv2_b4.ssld_stage1' : _cfg (
655- first_conv = 'stem.stem1.conv' ,
656- hf_hub_id = 'timm/' ),
657- 'hgnetv2_b5.ssld_in1k' : _cfg (
658- first_conv = 'stem.stem1.conv' ,
659- hf_hub_id = 'timm/' ),
660- 'hgnetv2_b5.ssld_stage1' : _cfg (
661- first_conv = 'stem.stem1.conv' ,
662- hf_hub_id = 'timm/' ),
663- 'hgnetv2_b6.ssld_in1k' : _cfg (
664- first_conv = 'stem.stem1.conv' ,
665- hf_hub_id = 'timm/' ),
666- 'hgnetv2_b6.ssld_stage1' : _cfg (
667- first_conv = 'stem.stem1.conv' ,
668- hf_hub_id = 'timm/' ),
669- })
617+ # default_cfgs = generate_default_cfgs({
618+ # 'hgnet_tiny.paddle_in1k': _cfg(
619+ # first_conv='stem.0.conv',
620+ # hf_hub_id='timm/'),
621+ # 'hgnet_tiny.ssld_in1k': _cfg(
622+ # first_conv='stem.0.conv',
623+ # hf_hub_id='timm/'),
624+ # 'hgnet_small.paddle_in1k': _cfg(
625+ # first_conv='stem.0.conv',
626+ # hf_hub_id='timm/'),
627+ # 'hgnet_small.ssld_in1k': _cfg(
628+ # first_conv='stem.0.conv',
629+ # hf_hub_id='timm/'),
630+ # 'hgnet_base.ssld_in1k': _cfg(
631+ # first_conv='stem.0.conv',
632+ # hf_hub_id='timm/'),
633+ # 'hgnetv2_b0.ssld_in1k': _cfg(
634+ # first_conv='stem.stem1.conv',
635+ # hf_hub_id='timm/'),
636+ # 'hgnetv2_b0.ssld_stage1': _cfg(
637+ # first_conv='stem.stem1.conv',
638+ # hf_hub_id='timm/'),
639+ # 'hgnetv2_b1.ssld_in1k': _cfg(
640+ # first_conv='stem.stem1.conv',
641+ # hf_hub_id='timm/'),
642+ # 'hgnetv2_b1.ssld_stage1': _cfg(
643+ # first_conv='stem.stem1.conv',
644+ # hf_hub_id='timm/'),
645+ # 'hgnetv2_b2.ssld_in1k': _cfg(
646+ # first_conv='stem.stem1.conv',
647+ # hf_hub_id='timm/'),
648+ # 'hgnetv2_b2.ssld_stage1': _cfg(
649+ # first_conv='stem.stem1.conv',
650+ # hf_hub_id='timm/'),
651+ # 'hgnetv2_b3.ssld_in1k': _cfg(
652+ # first_conv='stem.stem1.conv',
653+ # hf_hub_id='timm/'),
654+ # 'hgnetv2_b3.ssld_stage1': _cfg(
655+ # first_conv='stem.stem1.conv',
656+ # hf_hub_id='timm/'),
657+ # 'hgnetv2_b4.ssld_in1k': _cfg(
658+ # first_conv='stem.stem1.conv',
659+ # hf_hub_id='timm/'),
660+ # 'hgnetv2_b4.ssld_stage1': _cfg(
661+ # first_conv='stem.stem1.conv',
662+ # hf_hub_id='timm/'),
663+ # 'hgnetv2_b5.ssld_in1k': _cfg(
664+ # first_conv='stem.stem1.conv',
665+ # hf_hub_id='timm/'),
666+ # 'hgnetv2_b5.ssld_stage1': _cfg(
667+ # first_conv='stem.stem1.conv',
668+ # hf_hub_id='timm/'),
669+ # 'hgnetv2_b6.ssld_in1k': _cfg(
670+ # first_conv='stem.stem1.conv',
671+ # hf_hub_id='timm/'),
672+ # 'hgnetv2_b6.ssld_stage1': _cfg(
673+ # first_conv='stem.stem1.conv',
674+ # hf_hub_id='timm/'),
675+ # })
670676
671677
672678@register_model
0 commit comments