@@ -59,24 +59,24 @@ def _reshape_input(self, t):

     def forward(self, x, m: Optional[torch.Tensor] = None):
         """Run layer computation."""
-        s = x.shape
-        m = m or x
+        b, _, h, w = x.shape
+        m = m if m is not None else x

         reshaped_x = self._reshape_input(x)
         reshaped_m = self._reshape_input(m)

         q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
         k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)

-        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
+        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)

         v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
         o = torch.einsum('bnhm,bmv->bnhv', attn, v)
-        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
+        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
         result = self.proj_drop(result)
-        return result.reshape(s)
+        return result.reshape(b, -1, h, w)


 class MultiQueryAttention2d(nn.Module):
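
For context on the first hunk: `m = m or x` fails whenever a memory tensor `m` is actually passed, because `or` invokes `Tensor.__bool__`, which is ambiguous for tensors with more than one element; the explicit `is not None` check is the correct way to default to self-attention. A minimal sketch reproducing the failure (tensor shapes are made up for illustration):

import torch

x = torch.ones(2, 3)
m = torch.zeros(2, 3)

try:
    _ = m or x  # Tensor.__bool__ raises for multi-element tensors
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous"

m = m if m is not None else x  # explicit None check, as in the fix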
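The last two hunks work together: the old code produced a channel-last `(b, n, d)` result and reshaped it straight into the input shape, which interleaves tokens and channels; emitting `bdn` puts channels first so `reshape(b, -1, h, w)` recovers a valid NCHW feature map. (The `* self.scale` hunk presumably applies the usual 1/sqrt(key_dim) scaling before the softmax; the attribute itself is defined outside this hunk.) A sketch of the layout argument, with hypothetical sizes:

import torch

b, h, w, heads, v, d = 2, 2, 3, 3, 4, 5  # hypothetical sizes
o = torch.randn(b, h * w, heads, v)      # per-token, per-head attention output
out_proj = torch.randn(d, heads, v)      # output projection weights

# Channel-first 'bdn' output reshapes directly into an NCHW feature map.
fixed = torch.einsum('bnhv,dhv->bdn', o, out_proj).reshape(b, -1, h, w)

# It matches the channel-last 'bnd' result only after a transpose; reshaping
# the 'bnd' tensor without one scrambles channels across spatial positions.
ref = torch.einsum('bnhv,dhv->bnd', o, out_proj).transpose(1, 2).reshape(b, d, h, w)
assert torch.allclose(fixed, ref)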