@@ -42,7 +42,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            avg_token: bool = True,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -63,6 +63,11 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()

+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
+
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
             self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
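
The added branch gives the pool an optional learned prefix token in place of the mean of the spatial tokens. A minimal standalone sketch of the two modes (B, N, C and the tensors below are illustrative, not part of the diff):

    import torch

    B, N, C = 2, 49, 768
    x = torch.randn(B, N, C)               # flattened H*W x C spatial tokens
    cls_token = torch.zeros(1, C)          # as created by nn.Parameter above

    # class_token=False: prefix is the mean of the spatial tokens
    mean_prefix = torch.cat([x.mean(1, keepdim=True), x], dim=1)          # (B, N+1, C)
    # class_token=True: prefix is a learned token broadcast over the batch
    learned_prefix = torch.cat([cls_token.expand(B, -1, -1), x], dim=1)   # (B, N+1, C)
    assert mean_prefix.shape == learned_prefix.shape == (B, N + 1, C)
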
@@ -109,7 +114,10 @@ def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         if self.qkv is None:
             q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
             k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -130,7 +138,6 @@ def forward(self, x, pre_logits: bool = False):
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
             x = self._pool(x, H, W)
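
With `x = x[:, 0]` removed, forward now keeps all N + 1 tokens and defers the reduction to `self._pool`. The `_pool` implementation is not shown in this diff; a hedged sketch of what the 'token' pooling path presumably does, given the pool_type default above:

    # Assumption only: _pool is not part of this diff, so both branches here
    # are inferred from pool_type handling, not confirmed by the source.
    def _pool(self, x, H: int, W: int):
        if self.pool_type == 'token':
            return x[:, 0]            # select the (mean or class) prefix token
        # otherwise restore the spatial layout, dropping the prefix token
        return x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
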
@@ -162,7 +169,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            learned_token: bool = False,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -184,10 +191,10 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()

-        if learned_token:
-            self.token = nn.Parameter(torch.zeros(1, embed_dim))
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
         else:
-            self.token = None
+            self.cls_token = None

         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -239,10 +246,10 @@ def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        if self.token is not None:
-            x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
-        else:
+        if self.cls_token is None:
             x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed

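A usage sketch for the new flag, assuming these changes land in timm's AttentionPool2d (timm/layers/attention_pool2d.py); the class name, import path, and constructor arguments are inferred from context, not shown in this diff:

    import torch
    from timm.layers import AttentionPool2d  # assumed import path

    # class_token=True selects the new learned, zero-initialized prefix token;
    # the default (False) keeps the original mean-token behavior.
    pool = AttentionPool2d(in_features=768, feat_size=7, class_token=True)
    x = torch.randn(2, 768, 7, 7)   # NCHW feature map
    out = pool(x)                    # pooled output, (2, 768) with pool_type='token'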