@@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
2121 lse_idx ,
2222 HEAD_DIM : tl .constexpr ,
2323 N_ROUNDED : tl .constexpr ,
24+ IS_BASE_E : tl .constexpr ,
2425):
2526 """
2627 Apply the all-gathered lses to correct each local rank's attention
@@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
5556 lse_max = tl .max (lse , axis = 0 )
5657 lse_max = tl .where (lse_max == - float ("inf" ), 0 , lse_max )
5758 lse -= lse_max
58- lse_exp = tl .exp (lse )
59- lse_acc = tl .sum (lse_exp , axis = 0 )
60- lse = tl .log (lse_acc )
59+ if IS_BASE_E :
60+ lse_exp = tl .exp (lse )
61+ lse_acc = tl .sum (lse_exp , axis = 0 )
62+ lse = tl .log (lse_acc )
63+ else :
64+ lse_exp = tl .exp2 (lse )
65+ lse_acc = tl .sum (lse_exp , axis = 0 )
66+ lse = tl .log2 (lse_acc )
6167 lse += lse_max
6268
6369 lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
@@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
8187 - float ("inf" ),
8288 lse_finally ,
8389 )
84- factor = tl .exp (lse_finally )
90+ factor = tl .exp (lse_finally ) if IS_BASE_E else tl . exp2 ( lse_finally )
8591 output = tl .load (outputs_ptr + output_offsets )
8692 output = output * factor
8793
@@ -102,7 +108,11 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args):
102108
103109
104110def correct_attn_out (
105- out : torch .Tensor , lses : torch .Tensor , cp_rank : int , ctx : CPTritonContext
111+ out : torch .Tensor ,
112+ lses : torch .Tensor ,
113+ cp_rank : int ,
114+ ctx : CPTritonContext ,
115+ is_lse_base_on_e : bool = True ,
106116) -> tuple [torch .Tensor , torch .Tensor ]:
107117 """Correct the attention output using the all-gathered lses.
108118
@@ -163,8 +173,7 @@ def correct_attn_out(
163173 l_sH ,
164174 cp_rank ,
165175 )
166- const_args = {"HEAD_DIM" : D , "N_ROUNDED" : N }
167-
176+ const_args = {"HEAD_DIM" : D , "N_ROUNDED" : N , "IS_BASE_E" : is_lse_base_on_e }
168177 ctx .call_kernel (_correct_attn_cp_out_kernel , grid , * regular_args , ** const_args )
169178 return out , lse
170179
@@ -174,6 +183,7 @@ def _cp_lse_common(
174183 cp_attn_lse : torch .Tensor ,
175184 cp_group : GroupCoordinator ,
176185 ctx : CPTritonContext | None = None ,
186+ is_lse_base_on_e = True ,
177187):
178188 """
179189 cp_attn_out: [ B, H, D ]
@@ -193,7 +203,13 @@ def _cp_lse_common(
193203
194204 cp_attn_lse = cp_attn_lse .contiguous ()
195205 lses = cp_group .all_gather (cp_attn_lse , dim = 0 ).view_as (lses )
196- out , lse = correct_attn_out (cp_attn_out , lses , cp_group .rank_in_group , ctx )
206+ out , lse = correct_attn_out (
207+ cp_attn_out ,
208+ lses ,
209+ cp_group .rank_in_group ,
210+ ctx ,
211+ is_lse_base_on_e = is_lse_base_on_e ,
212+ )
197213 return out , lse
198214
199215
@@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
203219 cp_group : GroupCoordinator ,
204220 ctx : CPTritonContext | None = None ,
205221 return_lse : bool = False ,
222+ is_lse_base_on_e = True ,
206223):
207224 """
208225 cp_attn_out: [ B, H, D ]
209226 cp_attn_lse: [ B, H ]
210227 """
211- out , lse = _cp_lse_common (cp_attn_out , cp_attn_lse , cp_group , ctx = ctx )
228+ out , lse = _cp_lse_common (
229+ cp_attn_out , cp_attn_lse , cp_group , ctx = ctx , is_lse_base_on_e = is_lse_base_on_e
230+ )
212231 out = cp_group .reduce_scatter (out , dim = 1 )
213232
214233 if return_lse :
@@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
225244 cp_group : GroupCoordinator ,
226245 ctx : CPTritonContext | None = None ,
227246 return_lse : bool = False ,
247+ is_lse_base_on_e = True ,
228248):
229249 """
230250 cp_attn_out: [ B, H, D ]
231251 cp_attn_lse: [ B, H ]
232252 """
233- out , lse = _cp_lse_common (cp_attn_out , cp_attn_lse , cp_group , ctx = ctx )
253+ out , lse = _cp_lse_common (
254+ cp_attn_out , cp_attn_lse , cp_group , ctx = ctx , is_lse_base_on_e = is_lse_base_on_e
255+ )
234256 out = cp_group .all_reduce (out )
235257
236258 if return_lse :
0 commit comments