@@ -400,7 +400,7 @@ def __init__(
400400 patch_size : Union [int , Tuple [int , int ]] = 16 ,
401401 in_chans : int = 3 ,
402402 num_classes : int = 1000 ,
403- global_pool : Literal ['' , 'avg' , 'token' , 'map' ] = 'token' ,
403+ global_pool : Literal ['' , 'avg' , 'max' , 'token' , 'map' ] = 'token' ,
404404 embed_dim : int = 768 ,
405405 depth : int = 12 ,
406406 num_heads : int = 12 ,
@@ -459,10 +459,10 @@ def __init__(
459459 block_fn: Transformer block layer.
460460 """
461461 super ().__init__ ()
462- assert global_pool in ('' , 'avg' , 'token' , 'map' )
462+ assert global_pool in ('' , 'avg' , 'max' , 'token' , 'map' )
463463 assert class_token or global_pool != 'token'
464464 assert pos_embed in ('' , 'none' , 'learn' )
465- use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
465+ use_fc_norm = global_pool in [ 'avg' , 'max' ] if fc_norm is None else fc_norm
466466 norm_layer = get_norm_layer (norm_layer ) or partial (nn .LayerNorm , eps = 1e-6 )
467467 act_layer = get_act_layer (act_layer ) or nn .GELU
468468
@@ -761,6 +761,8 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso
761761 x = self .attn_pool (x )
762762 elif self .global_pool == 'avg' :
763763 x = x [:, self .num_prefix_tokens :].mean (dim = 1 )
764+ elif self .global_pool == 'max' :
765+ x , _ = torch .max (x [:, self .num_prefix_tokens :], dim = 1 )
764766 elif self .global_pool :
765767 x = x [:, 0 ] # class token
766768 x = self .fc_norm (x )
0 commit comments