
Commit ca52108

Fix some model support functions
1 parent f332fc2 commit ca52108

4 files changed: 17 additions, 7 deletions

timm/models/efficientformer.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -449,13 +449,12 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self):
         return self.head, self.head_dist

-    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+    def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-        if self.dist:
-            self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

     @torch.jit.ignore
     def set_distilled_training(self, enable=True):
```
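In practice this means `reset_classifier` on EfficientFormer no longer accepts a `distillation` argument and always rebuilds the distillation head alongside the main head. A minimal usage sketch (assuming a timm build that includes this commit; `efficientformer_l1` is one of the registered variants):

```python
import timm
import torch

# Rebuild both heads for a new class count; the removed `distillation`
# argument is no longer accepted by reset_classifier().
model = timm.create_model('efficientformer_l1', pretrained=False)
model.reset_classifier(num_classes=10)

model.eval()
x = torch.randn(1, 3, 224, 224)
out = model(x)    # in eval mode the head and head_dist outputs are averaged
print(out.shape)  # torch.Size([1, 10])
```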

timm/models/gcvit.py

Lines changed: 12 additions & 1 deletion

```diff
@@ -427,6 +427,7 @@ def __init__(
         feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
         self.global_pool = global_pool
         self.num_classes = num_classes
+        self.drop_rate = drop_rate
         num_stages = len(depths)
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))

@@ -491,7 +492,7 @@ def no_weight_decay(self):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=(r'^stages\.(\d+)', None)
+            blocks=r'^stages\.(\d+)'
         )
         return matcher

@@ -500,6 +501,16 @@ def set_grad_checkpointing(self, enable=True):
         for s in self.stages:
             s.grad_checkpointing = enable

+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is None:
+            global_pool = self.head.global_pool.pool_type
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
         x = self.stages(x)
```
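The new accessors give GC-ViT the same classifier interface as other timm models. A short sketch of how they behave (assuming a build with this commit; `gcvit_tiny` is one of the registered variants):

```python
import timm
import torch

model = timm.create_model('gcvit_tiny', pretrained=False)

# get_classifier() returns the inner nn.Linear of the ClassifierHead.
print(model.get_classifier())

# reset_classifier() rebuilds the head; when global_pool is omitted,
# the existing head's pool_type is carried over.
model.reset_classifier(10)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])
```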

timm/models/mvitv2.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -850,7 +850,7 @@ def no_weight_decay(self):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^stem',  # stem and embed
+            stem=r'^patch_embed',  # stem and embed
             blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
         )
         return matcher

@@ -862,7 +862,7 @@ def set_grad_checkpointing(self, enable=True):

     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.fc

     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
```
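With this change, `get_classifier` on MViT-V2 returns the head's inner linear layer rather than the whole head module, and the stem pattern matches the module's real name. A quick check (assuming a build with this commit; `mvitv2_tiny` is one of the registered variants):

```python
import timm
import torch.nn as nn

model = timm.create_model('mvitv2_tiny', pretrained=False)

# The classifier accessor now resolves to the inner nn.Linear (head.fc).
fc = model.get_classifier()
print(isinstance(fc, nn.Linear))  # True for num_classes > 0

# The corrected stem pattern matches the actual parameter prefix.
print(any(n.startswith('patch_embed') for n, _ in model.named_parameters()))  # True
```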

timm/models/pvt_v2.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -351,7 +351,7 @@ def no_weight_decay(self):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^stages\.(\d+)'
         )
         return matcher
```
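To illustrate what the simplified `blocks` pattern does: parameter names are bucketed by the captured stage index, which timm's optimizer helpers can use for per-stage settings such as layer-wise learning-rate decay. A standalone sketch (the bucketing helper and parameter names below are illustrative, not timm's actual implementation):

```python
import re

# Illustrative only: group parameter names by the stage index captured
# by r'^stages\.(\d+)'; anything that doesn't match falls back to 'stem'.
pattern = re.compile(r'^stages\.(\d+)')

def bucket_by_stage(param_names):
    groups = {}
    for name in param_names:
        m = pattern.match(name)
        key = f'stage_{m.group(1)}' if m else 'stem'
        groups.setdefault(key, []).append(name)
    return groups

names = [
    'patch_embed.proj.weight',          # matched by the stem pattern
    'stages.0.blocks.0.attn.q.weight',  # illustrative block parameter
    'stages.1.downsample.proj.weight',  # illustrative downsample parameter
]
print(bucket_by_stage(names))
# {'stem': [...], 'stage_0': [...], 'stage_1': [...]}
```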
