@@ -427,6 +427,7 @@ def __init__(
         feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
         self.global_pool = global_pool
         self.num_classes = num_classes
+        self.drop_rate = drop_rate
         num_stages = len(depths)
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))

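As a quick worked example of the feature-width formula in the context lines above (the `embed_dim` and `depths` values below are illustrative, not taken from this model's configs):

```python
embed_dim, depths = 96, (1, 1, 3, 1)  # illustrative config values
num_stages = len(depths)              # 4 stages
num_features = int(embed_dim * 2 ** (num_stages - 1))  # 96 * 2**3 = 768
```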
@@ -491,7 +492,7 @@ def no_weight_decay(self):
491492 def group_matcher (self , coarse = False ):
492493 matcher = dict (
493494 stem = r'^stem' , # stem and embed
494- blocks = ( r'^stages\.(\d+)' , None )
495+ blocks = r'^stages\.(\d+)'
495496 )
496497 return matcher
497498
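The matcher change above drops the redundant `(pattern, None)` tuple in favor of a plain regex string; either way, the stage index captured by `(\d+)` is what drives parameter grouping (e.g. for layer-wise LR decay). A minimal sketch of that matching with plain `re`, using hypothetical parameter names:

```python
import re

# Hypothetical parameter names; real names depend on the model definition.
param_names = [
    'stem.conv.weight',
    'stages.0.blocks.0.attn.qkv.weight',
    'stages.2.blocks.1.mlp.fc1.weight',
]

stage_re = re.compile(r'^stages\.(\d+)')
for name in param_names:
    m = stage_re.match(name)
    # Parameters that don't match (e.g. the stem) fall into a separate group.
    print(name, '->', int(m.group(1)) if m else 'stem/other')
```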
@@ -500,6 +501,16 @@ def set_grad_checkpointing(self, enable=True):
         for s in self.stages:
             s.grad_checkpointing = enable

+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is None:
+            global_pool = self.head.global_pool.pool_type
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
         x = self.stages(x)
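With `drop_rate` now stored on the module, `reset_classifier` can rebuild the head for fine-tuning without losing the configured dropout, and it preserves the existing pool type when `global_pool` is not given. A minimal usage sketch, assuming the model is registered with `timm.create_model` (the model name below is a placeholder):

```python
import torch
import timm

model = timm.create_model('model_name_here', pretrained=False)  # placeholder model name
model.reset_classifier(num_classes=10)         # new head; pool type and drop_rate preserved
assert model.get_classifier().out_features == 10

logits = model(torch.randn(1, 3, 224, 224))    # -> shape (1, 10)
```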