Skip to content

Commit b71d60c

Browse files
committed
Two small fixes, num_classes in base class, add model tag
1 parent 3318e76 commit b71d60c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

timm/models/repvit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ def __init__(self, dim, num_classes, distillation=False):
178178
self.distillation = distillation
179179
if distillation:
180180
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
181-
self.num_classes = num_classes
182181

183182
def forward(self, x):
184183
if self.distillation:
@@ -248,6 +247,7 @@ def __init__(
248247
self.grad_checkpointing = False
249248
self.global_pool = global_pool
250249
self.embed_dim = embed_dim
250+
self.num_classes = num_classes
251251

252252
in_dim = embed_dim[0]
253253
self.stem = RepViTStem(in_chans, in_dim, act_layer)
@@ -356,13 +356,13 @@ def _cfg(url='', **kwargs):
356356

357357
default_cfgs = generate_default_cfgs(
358358
{
359-
'repvit_m1': _cfg(
359+
'repvit_m1.dist_in1k': _cfg(
360360
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
361361
),
362-
'repvit_m2': _cfg(
362+
'repvit_m2.dist_in1k': _cfg(
363363
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
364364
),
365-
'repvit_m3': _cfg(
365+
'repvit_m3.dist_in1k': _cfg(
366366
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
367367
),
368368
}

0 commit comments

Comments
 (0)