@@ -37,8 +37,6 @@ def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
         o = torch.zeros_like(q)
         all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
 
-        q = q * scale
-
         if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
@@ -63,7 +61,7 @@ def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
             for k_ind, (kc, vc) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size
 
-                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)
+                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
 
                 if exists(row_mask):
                     attn_weights.masked_fill_(~row_mask, max_neg_value)
@@ -129,14 +127,13 @@ def backward(ctx, do):
             for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size
 
-                qc_scaled = qc * scale
-                attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc)
+                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
 
                 if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                     causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                     attn_weights.masked_fill_(causal_mask, max_neg_value)
 
-                exp_attn_weights = torch.exp(attn_weights)
+                exp_attn_weights = torch.exp(attn_weights - scale)
 
                 if exists(row_mask):
                     exp_attn_weights.masked_fill_(~row_mask, 0.)