Skip to content

Commit b1a6f4a

Browse files
committed
Some missed reset_classifier() type annotations
1 parent 71101eb commit b1a6f4a

16 files changed

+33
-23
lines changed

timm/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def set_grad_checkpointing(self, enable=True):
156156
def get_classifier(self) -> nn.Module:
157157
return self.classifier
158158

159-
def reset_classifier(self, num_classes, global_pool='avg'):
159+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
160160
self.num_classes = num_classes
161161
self.global_pool, self.classifier = create_classifier(
162162
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/ghostnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def set_grad_checkpointing(self, enable=True):
273273
def get_classifier(self) -> nn.Module:
274274
return self.classifier
275275

276-
def reset_classifier(self, num_classes, global_pool='avg'):
276+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
277277
self.num_classes = num_classes
278278
# cannot meaningfully change pooling of efficient head after creation
279279
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)

timm/models/hrnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def set_grad_checkpointing(self, enable=True):
739739
def get_classifier(self) -> nn.Module:
740740
return self.classifier
741741

742-
def reset_classifier(self, num_classes, global_pool='avg'):
742+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
743743
self.num_classes = num_classes
744744
self.global_pool, self.classifier = create_classifier(
745745
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/inception_v4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def set_grad_checkpointing(self, enable=True):
280280
def get_classifier(self) -> nn.Module:
281281
return self.last_linear
282282

283-
def reset_classifier(self, num_classes, global_pool='avg'):
283+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
284284
self.num_classes = num_classes
285285
self.global_pool, self.last_linear = create_classifier(
286286
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/metaformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
# See the License for the specific language governing permissions and
2727
# limitations under the License.
2828

29-
3029
from collections import OrderedDict
3130
from functools import partial
31+
from typing import Optional
3232

3333
import torch
3434
import torch.nn as nn
@@ -548,7 +548,7 @@ def __init__(
548548
# if using MlpHead, dropout is handled by MlpHead
549549
if num_classes > 0:
550550
if self.use_mlp_head:
551-
# FIXME hidden size
551+
# FIXME not actually returning mlp hidden state right now as pre-logits.
552552
final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
553553
self.head_hidden_size = self.num_features
554554
else:
@@ -583,7 +583,7 @@ def set_grad_checkpointing(self, enable=True):
583583
def get_classifier(self) -> nn.Module:
584584
return self.head.fc
585585

586-
def reset_classifier(self, num_classes=0, global_pool=None):
586+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
587587
if global_pool is not None:
588588
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
589589
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()

timm/models/nasnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def set_grad_checkpointing(self, enable=True):
518518
def get_classifier(self) -> nn.Module:
519519
return self.last_linear
520520

521-
def reset_classifier(self, num_classes, global_pool='avg'):
521+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
522522
self.num_classes = num_classes
523523
self.global_pool, self.last_linear = create_classifier(
524524
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/pnasnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def set_grad_checkpointing(self, enable=True):
307307
def get_classifier(self) -> nn.Module:
308308
return self.last_linear
309309

310-
def reset_classifier(self, num_classes, global_pool='avg'):
310+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
311311
self.num_classes = num_classes
312312
self.global_pool, self.last_linear = create_classifier(
313313
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/regnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def set_grad_checkpointing(self, enable=True):
514514
def get_classifier(self) -> nn.Module:
515515
return self.head.fc
516516

517-
def reset_classifier(self, num_classes, global_pool='avg'):
517+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
518518
self.head.reset(num_classes, pool_type=global_pool)
519519

520520
def forward_intermediates(

timm/models/rexnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from functools import partial
1414
from math import ceil
15+
from typing import Optional
1516

1617
import torch
1718
import torch.nn as nn
@@ -229,7 +230,7 @@ def set_grad_checkpointing(self, enable=True):
229230
def get_classifier(self) -> nn.Module:
230231
return self.head.fc
231232

232-
def reset_classifier(self, num_classes, global_pool='avg'):
233+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
233234
self.num_classes = num_classes
234235
self.head.reset(num_classes, global_pool)
235236

timm/models/selecsls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def set_grad_checkpointing(self, enable=True):
161161
def get_classifier(self) -> nn.Module:
162162
return self.fc
163163

164-
def reset_classifier(self, num_classes, global_pool='avg'):
164+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
165165
self.num_classes = num_classes
166166
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
167167

0 commit comments

Comments
 (0)