File tree Expand file tree Collapse file tree 3 files changed +6
-6
lines changed Expand file tree Collapse file tree 3 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -180,10 +180,10 @@ def __init__(
180180 self .drop = nn .Dropout (drop_rate )
181181 self .fc = linear_layer (self .num_features , num_classes ) if num_classes > 0 else nn .Identity ()
182182
183- def reset (self , num_classes , global_pool = None ):
184- if global_pool is not None :
185- self .global_pool = SelectAdaptivePool2d (pool_type = global_pool )
186- self .flatten = nn .Flatten (1 ) if global_pool else nn .Identity ()
183+ def reset (self , num_classes , pool_type = None ):
184+ if pool_type is not None :
185+ self .global_pool = SelectAdaptivePool2d (pool_type = pool_type )
186+ self .flatten = nn .Flatten (1 ) if pool_type else nn .Identity ()
187187 self .use_conv = self .global_pool .is_identity ()
188188 linear_layer = partial (nn .Conv2d , kernel_size = 1 ) if self .use_conv else nn .Linear
189189 if self .hidden_size :
Original file line number Diff line number Diff line change @@ -569,7 +569,7 @@ def get_classifier(self):
569569 return self .head .fc
570570
571571 def reset_classifier (self , num_classes , global_pool = None ):
572- self .head .reset (num_classes , global_pool = global_pool )
572+ self .head .reset (num_classes , global_pool )
573573
574574 def forward_features (self , x ):
575575 x = self .stem (x )
Original file line number Diff line number Diff line change @@ -535,7 +535,7 @@ def get_classifier(self):
535535
536536 def reset_classifier (self , num_classes , global_pool = None ):
537537 self .num_classes = num_classes
538- self .head .reset (num_classes , global_pool = global_pool )
538+ self .head .reset (num_classes , pool_type = global_pool )
539539
540540 def forward_features (self , x ):
541541 x = self .patch_embed (x )
You can’t perform that action at this time.
0 commit comments