@@ -41,9 +41,10 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
+            drop: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
+        self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
         ref_feat_size = to_2tuple(ref_feat_size)
@@ -82,7 +83,7 @@ def init_weights(self, zero_init_last: bool = False):
         trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
         nn.init.zeros_(self.qkv.bias)
 
-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
@@ -107,8 +108,12 @@ def forward(self, x):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = x[:, 0]
+        x = self.drop(x)
+        if pre_logits:
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        return x
 
 
 class AttentionPool2d(nn.Module):
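Both pool classes now select the pooled token and apply dropout before the output projection, and `pre_logits=True` returns the embedding without projecting it. Below is a minimal standalone sketch of that tail-end flow; `TinyTokenPool` and its shapes are hypothetical illustrations, not part of this PR:

```python
import torch
import torch.nn as nn

class TinyTokenPool(nn.Module):
    """Sketch of the new head ordering: token select -> dropout -> optional early return -> proj."""
    def __init__(self, embed_dim: int, out_features: int, drop: float = 0.):
        super().__init__()
        self.drop = nn.Dropout(drop)
        self.proj = nn.Linear(embed_dim, out_features)

    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        # x: (B, N + 1, embed_dim) with the pooled/class token at index 0
        x = x[:, 0]
        x = self.drop(x)
        if pre_logits:
            return x          # embed_dim-wide features, projection skipped
        return self.proj(x)   # out_features-wide output

tokens = torch.randn(2, 50, 64)
pool = TinyTokenPool(embed_dim=64, out_features=128, drop=0.1)
assert pool(tokens).shape == (2, 128)
assert pool(tokens, pre_logits=True).shape == (2, 64)
```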
@@ -132,9 +137,10 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
+            drop: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
+        self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
         if num_heads is not None:
@@ -158,6 +164,7 @@ def __init__(
         else:
             self.q = self.k = self.v = None
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
 
@@ -178,15 +185,12 @@ def init_weights(self, zero_init_last: bool = False):
         nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
 
-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        if self.seq_len != N:
-            pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
-        else:
-            pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
+        pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed
 
         if self.qkv is None:
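With the `seq_len != N` branch removed, `resample_abs_pos_embed` is now called on every forward, which presumably returns its input unchanged when the feature grid already matches the stored embedding. As a rough standalone illustration of that behavior (not the timm implementation; the function name, single prefix token, and bilinear interpolation are assumptions):

```python
import math
import torch
import torch.nn.functional as F

def resample_pos_embed_sketch(pos_embed: torch.Tensor, new_size, num_prefix_tokens: int = 1):
    # pos_embed: (1, num_prefix_tokens + H*W, C); new_size: (H', W')
    prefix, grid = pos_embed[:, :num_prefix_tokens], pos_embed[:, num_prefix_tokens:]
    old_hw = int(math.sqrt(grid.shape[1]))
    if (old_hw, old_hw) == tuple(new_size):
        return pos_embed  # grid already matches: unconditional calls are harmless
    grid = grid.reshape(1, old_hw, old_hw, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=tuple(new_size), mode='bilinear', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_size[0] * new_size[1], -1)
    return torch.cat([prefix, grid], dim=1)

pe = torch.randn(1, 1 + 7 * 7, 32)
assert resample_pos_embed_sketch(pe, (7, 7)).shape == (1, 50, 32)
assert resample_pos_embed_sketch(pe, (14, 14)).shape == (1, 1 + 14 * 14, 32)
```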
@@ -205,5 +209,9 @@ def forward(self, x):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = x[:, 0]
+        x = self.drop(x)
+        if pre_logits:
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        return x
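Storing `self.embed_dim` lets downstream code size a head to match what it consumes: `embed_dim`-wide pre-logits features or the projected `out_features`-wide output. A hypothetical consumer sketch; `PooledClassifier` and its arguments are illustrative, only `embed_dim`, `out_features`, `drop`, and `pre_logits` come from this diff:

```python
import torch.nn as nn

class PooledClassifier(nn.Module):
    """Hypothetical wrapper around one of the updated pool modules."""
    def __init__(self, pool: nn.Module, num_classes: int, use_pre_logits: bool = True):
        super().__init__()
        self.pool = pool
        self.use_pre_logits = use_pre_logits
        # Pick the feature width from the attributes the pool now exposes.
        feat_dim = pool.embed_dim if use_pre_logits else pool.out_features
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        # pre_logits=True -> embed_dim-wide features before the pool's proj,
        # pre_logits=False -> out_features-wide projected features.
        feats = self.pool(x, pre_logits=self.use_pre_logits)
        return self.fc(feats)
```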