Skip to content

Commit d6c2cc9

Browse files
committed
Make NormMlpClassifier head reset args consistent with ClassifierHead
1 parent 87fec3d commit d6c2cc9

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

timm/layers/classifier.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,10 @@ def __init__(
180180
self.drop = nn.Dropout(drop_rate)
181181
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
182182

183-
def reset(self, num_classes, global_pool=None):
184-
if global_pool is not None:
185-
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
186-
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
183+
def reset(self, num_classes, pool_type=None):
184+
if pool_type is not None:
185+
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
186+
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
187187
self.use_conv = self.global_pool.is_identity()
188188
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
189189
if self.hidden_size:

timm/models/davit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def get_classifier(self):
569569
return self.head.fc
570570

571571
def reset_classifier(self, num_classes, global_pool=None):
572-
self.head.reset(num_classes, global_pool=global_pool)
572+
self.head.reset(num_classes, global_pool)
573573

574574
def forward_features(self, x):
575575
x = self.stem(x)

timm/models/tiny_vit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def get_classifier(self):
535535

536536
def reset_classifier(self, num_classes, global_pool=None):
537537
self.num_classes = num_classes
538-
self.head.reset(num_classes, global_pool=global_pool)
538+
self.head.reset(num_classes, pool_type=global_pool)
539539

540540
def forward_features(self, x):
541541
x = self.patch_embed(x)

0 commit comments

Comments (0)