@@ -41,9 +41,12 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
-            drop: float = 0.,
+            pool_type: str = 'token',
+            avg_token: bool = True,
+            drop_rate: float = 0.,
     ):
         super().__init__()
+        assert pool_type in ('', 'token')
         self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
@@ -56,6 +59,7 @@ def __init__(
             num_heads = embed_dim // head_dim
         self.num_heads = num_heads
         self.head_dim = head_dim
+        self.pool_type = pool_type.lower()
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()

@@ -66,6 +70,7 @@ def __init__(
             self.qkv = None
         else:
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)

@@ -83,6 +88,23 @@ def init_weights(self, zero_init_last: bool = False):
             trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
             nn.init.zeros_(self.qkv.bias)

+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
     def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
@@ -111,8 +133,9 @@ def forward(self, x, pre_logits: bool = False):
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
+            x = self._pool(x, H, W)
             return x
         x = self.proj(x)
+        x = self._pool(x, H, W)
         return x


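To make the new pooling behaviour concrete, here is a minimal standalone sketch of the `_pool` logic both heads now share. `pool_output` is a hypothetical name for illustration only; in the diff this lives as the private `_pool` method, operating on a sequence of one prefix token plus H * W spatial tokens.

import torch

def pool_output(x: torch.Tensor, H: int, W: int, pool_type: str = 'token') -> torch.Tensor:
    # x is (B, 1 + H * W, C): a prefix token followed by the flattened spatial tokens
    if pool_type == 'token':
        return x[:, 0]  # (B, C): keep only the pooled prefix token
    # pool_type == '': drop the prefix token and restore an NCHW spatial map
    return x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)

x = torch.randn(2, 1 + 7 * 7, 256)
print(pool_output(x, 7, 7, 'token').shape)  # torch.Size([2, 256])
print(pool_output(x, 7, 7, '').shape)       # torch.Size([2, 256, 7, 7])
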
@@ -137,9 +160,12 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
-            drop: float = 0.,
+            pool_type: str = 'token',
+            learned_token: bool = False,
+            drop_rate: float = 0.,
     ):
         super().__init__()
+        assert pool_type in ('', 'token')
         self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
@@ -153,9 +179,15 @@ def __init__(
         self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
         self.head_dim = head_dim
+        self.pool_type = pool_type
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()

+        if learned_token:
+            self.token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.token = None
+
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
             self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -164,7 +196,7 @@ def __init__(
         else:
             self.q = self.k = self.v = None
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.drop = nn.Dropout(drop)
+        self.drop = nn.Dropout(drop_rate)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))

@@ -185,11 +217,31 @@ def init_weights(self, zero_init_last: bool = False):
             nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)

+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
     def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.token is not None:
+            x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
+        else:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
         pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed

@@ -209,9 +261,10 @@ def forward(self, x, pre_logits: bool = False):
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
+            x = self._pool(x, H, W)
             return x
         x = self.proj(x)
+        x = self._pool(x, H, W)
         return x
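Since both modules now double as classifier heads, `reset()` follows timm's `reset_classifier()` convention. A hedged usage sketch, assuming a timm build that includes this change and the constructor arguments shown in the diff above:

import torch
from timm.layers import AttentionPool2d  # assumes a timm version with this change

head = AttentionPool2d(in_features=256, feat_size=7, pool_type='token')
feats = torch.randn(2, 256, 7, 7)
print(head(feats).shape)   # torch.Size([2, 256]): pooled token after projection

# num_classes=0 swaps the projection for nn.Identity; pool_type='' switches the
# head to return the spatial tokens as an NCHW map instead of the pooled token.
head.reset(num_classes=0, pool_type='')
print(head(feats).shape)   # torch.Size([2, 256, 7, 7])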