@@ -160,6 +160,7 @@ def __init__(
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
 
         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@@ -175,15 +176,31 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
         else:
             q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
 
-        attn = (q.transpose(-2, -1) @ k) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q.transpose(-1, -2),
+                k.transpose(-1, -2),
+                v.transpose(-1, -2),
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            ).transpose(-1, -2).reshape(B, -1, H, W)
+        else:
+            q = q * self.scale
+            attn = q.transpose(-2, -1) @ k
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
 
-        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
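
In this NCHW path, q, k and v come out of the conv QKV as (B, num_heads, dim_head, N), while scaled_dot_product_attention expects (..., seq_len, dim_head), hence the transposes going in and back out. Below is a minimal sketch of that equivalence with assumed toy shapes and a random stand-in for the relative-position bias (dropout disabled); it is not the library code itself.

import torch
import torch.nn.functional as F

B, num_heads, dim_head, H, W = 2, 4, 32, 7, 7  # illustrative sizes only
N = H * W
q, k, v = (torch.randn(B, num_heads, dim_head, N) for _ in range(3))
attn_bias = torch.randn(1, num_heads, N, N)  # stand-in for rel_pos.get_bias()

# Fused path: move the sequence axis next to last, let SDPA apply the
# 1/sqrt(dim_head) scaling and the additive bias, then restore NCHW layout.
fused = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(-1, -2), k.transpose(-1, -2), v.transpose(-1, -2),
    attn_mask=attn_bias,
).transpose(-1, -2).reshape(B, -1, H, W)

# Explicit path, mirroring the else branch of the diff.
scale = dim_head ** -0.5
attn = ((q * scale).transpose(-2, -1) @ k + attn_bias).softmax(dim=-1)
manual = (v @ attn.transpose(-2, -1)).reshape(B, -1, H, W)

print(torch.allclose(fused, manual, atol=1e-4))  # expected to print True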
@@ -211,6 +228,7 @@ def __init__(
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
 
         self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@@ -227,15 +245,30 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
         else:
             q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,))
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(restore_shape + (-1,))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
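
In this channel-last variant, the linear QKV already yields q, k and v in the (B, num_heads, N, dim_head) layout that scaled_dot_product_attention expects, so no transposes are needed around the fused call. A minimal sketch with assumed shapes, again not the library code itself:

import torch
import torch.nn as nn
import torch.nn.functional as F

B, N, num_heads, dim_head = 2, 49, 4, 32  # illustrative sizes only
dim = num_heads * dim_head
x = torch.randn(B, N, dim)
qkv = nn.Linear(dim, dim * 3)(x)  # (B, N, 3 * dim)
q, k, v = qkv.reshape(B, -1, 3, num_heads, dim_head).transpose(1, 3).unbind(2)

if hasattr(F, 'scaled_dot_product_attention'):
    # Fused kernel (PyTorch 2.0+ / recent nightlies); the 1/sqrt(dim_head)
    # scaling is applied internally.
    out = F.scaled_dot_product_attention(q, k, v)
else:
    # Explicit fallback matching the else branch of the diff.
    attn = ((q * dim_head ** -0.5) @ k.transpose(-2, -1)).softmax(dim=-1)
    out = attn @ v

out = out.transpose(1, 2).reshape(B, N, -1)  # back to (B, N, dim)

One caveat worth noting: unlike nn.Dropout, the dropout_p argument of scaled_dot_product_attention is applied regardless of the module's train/eval mode, so in practice it is usually gated on self.training (passing the dropout probability only during training and 0. otherwise).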