Commit 9726e64

bugfix: correct attn output with base 2 or e (#28840)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
1 parent 3fd1fb0 commit 9726e64

File tree

3 files changed (+47, -13 lines)


vllm/attention/ops/common.py

Lines changed: 32 additions & 10 deletions
@@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
     lse_idx,
     HEAD_DIM: tl.constexpr,
     N_ROUNDED: tl.constexpr,
+    IS_BASE_E: tl.constexpr,
 ):
     """
     Apply the all-gathered lses to correct each local rank's attention
@@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
     lse_max = tl.max(lse, axis=0)
     lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
     lse -= lse_max
-    lse_exp = tl.exp(lse)
-    lse_acc = tl.sum(lse_exp, axis=0)
-    lse = tl.log(lse_acc)
+    if IS_BASE_E:
+        lse_exp = tl.exp(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log(lse_acc)
+    else:
+        lse_exp = tl.exp2(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log2(lse_acc)
     lse += lse_max

     lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
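For reference, a minimal PyTorch sketch of what the two branches compute (this is not the Triton kernel; tensor shapes and names are assumed for illustration): the per-rank LSEs must be reduced with exp/log in the same base they were produced in, and a base-2 LSE is just the natural-log LSE divided by ln 2.

```python
import torch


def combine_lses(lses: torch.Tensor, base_e: bool = True) -> torch.Tensor:
    """Combine per-rank LSEs ([cp_world_size, B, H]) that all share one base."""
    lse_max = lses.max(dim=0).values
    lse_max = torch.where(lse_max == -float("inf"), torch.zeros_like(lse_max), lse_max)
    shifted = lses - lse_max
    if base_e:  # natural-log LSEs: exp/log, as in the IS_BASE_E branch
        combined = torch.log(torch.exp(shifted).sum(dim=0))
    else:  # base-2 LSEs: exp2/log2, as in the new else branch
        combined = torch.log2(torch.exp2(shifted).sum(dim=0))
    return combined + lse_max


# A base-2 LSE equals the base-e LSE divided by ln 2, so the two paths agree
# once converted back to the same base.
lse_e = torch.randn(2, 4, 8)
lse_2 = lse_e / torch.log(torch.tensor(2.0))
assert torch.allclose(
    combine_lses(lse_e, base_e=True),
    combine_lses(lse_2, base_e=False) * torch.log(torch.tensor(2.0)),
    atol=1e-5,
)
```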
@@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
         -float("inf"),
         lse_finally,
     )
-    factor = tl.exp(lse_finally)
+    factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
     output = tl.load(outputs_ptr + output_offsets)
     output = output * factor
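A small worked check of why the factor's base matters, assuming (as elsewhere in this kernel, outside the lines shown) that `lse_finally` is this rank's LSE minus the combined LSE:

```python
import math

lse_local_e, lse_global_e = 1.2, 2.0            # natural-log LSEs
lse_local_2 = lse_local_e / math.log(2.0)       # the same values expressed in base 2
lse_global_2 = lse_global_e / math.log(2.0)

f_e = math.exp(lse_local_e - lse_global_e)      # exp() on base-e LSEs  -> ~0.449
f_2 = 2.0 ** (lse_local_2 - lse_global_2)       # exp2() on base-2 LSEs -> ~0.449
f_wrong = math.exp(lse_local_2 - lse_global_2)  # exp() on base-2 LSEs  -> ~0.315

assert abs(f_e - f_2) < 1e-12                   # matching base: identical rescale factor
assert abs(f_e - f_wrong) > 0.1                 # mixed base: visibly wrong rescale
```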

@@ -102,7 +108,11 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args):


 def correct_attn_out(
-    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
+    out: torch.Tensor,
+    lses: torch.Tensor,
+    cp_rank: int,
+    ctx: CPTritonContext,
+    is_lse_base_on_e: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Correct the attention output using the all-gathered lses.
@@ -163,8 +173,7 @@ def correct_attn_out(
         l_sH,
         cp_rank,
     )
-    const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
-
+    const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
     ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
     return out, lse
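A hedged call sketch for the extended `correct_attn_out` signature; the shapes, the `lses` layout, and the bare `CPTritonContext()` construction are assumptions for illustration, not taken from this diff:

```python
import torch
from vllm.attention.ops.common import CPTritonContext, correct_attn_out

B, H, D, cp_world_size, cp_rank = 4, 8, 128, 2, 0
out = torch.randn(B, H, D, device="cuda")                # this rank's partial attention output
lses = torch.randn(cp_world_size, B, H, device="cuda")   # all-gathered per-rank LSEs

# Assumed no-arg construction; pass is_lse_base_on_e=False for log2-based LSEs.
ctx = CPTritonContext()
out, lse = correct_attn_out(out, lses, cp_rank, ctx, is_lse_base_on_e=False)
```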

@@ -174,6 +183,7 @@ def _cp_lse_common(
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -193,7 +203,13 @@ def _cp_lse_common(

     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(
+        cp_attn_out,
+        lses,
+        cp_group.rank_in_group,
+        ctx,
+        is_lse_base_on_e=is_lse_base_on_e,
+    )
     return out, lse
@@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
     return_lse: bool = False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out, lse = _cp_lse_common(
+        cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
+    )
     out = cp_group.reduce_scatter(out, dim=1)

     if return_lse:
@@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
     return_lse: bool = False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out, lse = _cp_lse_common(
+        cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
+    )
     out = cp_group.all_reduce(out)

     if return_lse:
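Both public wrappers now accept the flag; a usage sketch with placeholder tensors and group, mirroring the call sites updated later in this commit:

```python
# Reduce-scatter variant: corrected output is scattered back across the group.
out = cp_lse_ag_out_rs(cp_attn_out, cp_attn_lse, cp_group, is_lse_base_on_e=False)

# All-reduce variant: every rank keeps the full corrected output; base e by default.
out, lse = cp_lse_ag_out_ar(cp_attn_out, cp_attn_lse, cp_group, return_lse=True)
```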

vllm/v1/attention/backends/flashinfer.py

Lines changed: 9 additions & 2 deletions
@@ -249,7 +249,11 @@ def run(
             return_lse=True,
         )
         output_context, lse_context = cp_lse_ag_out_rs(
-            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
+            output_context_tmp,
+            lse_context_tmp,
+            get_dcp_group(),
+            return_lse=True,
+            is_lse_base_on_e=False,
         )
         lse_context = lse_context.transpose(0, 1).contiguous()

@@ -1335,7 +1339,10 @@ def forward(
                 return_lse=True,
             )
             output[:num_decode_tokens] = cp_lse_ag_out_rs(
-                output_tmp, lse, get_dcp_group()
+                output_tmp,
+                lse,
+                get_dcp_group(),
+                is_lse_base_on_e=False,
             )
         else:
             decode_wrapper.run(
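Both FlashInfer call sites pass `is_lse_base_on_e=False`, i.e. this change treats the LSE returned by those kernels as log2-based. If a caller preferred to normalize the LSE instead of branching inside the Triton kernel, the conversion is a single scale (sketch, with `lse_base_2` as a placeholder tensor):

```python
import math

# Convert a log2-based LSE to a natural-log LSE before a base-e consumer.
lse_base_e = lse_base_2 * math.log(2.0)
```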

vllm/v1/attention/backends/mla/common.py

Lines changed: 6 additions & 1 deletion
@@ -2057,7 +2057,12 @@ def forward(

         # correct dcp attn_out with lse.
         if self.dcp_world_size > 1:
-            attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+            attn_out = cp_lse_ag_out_rs(
+                attn_out,
+                lse,
+                get_dcp_group(),
+                is_lse_base_on_e=not self._use_fi_prefill,
+            )

         # v_up projection
         self._v_up_proj(attn_out, out=output[:num_decode_tokens])
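Here the base is selected per backend: base 2 when the FlashInfer path produced the LSE (`self._use_fi_prefill`), otherwise the default base e. A sketch of the same dispatch for a hypothetical caller (the flag name `use_flashinfer_backend` is illustrative):

```python
# Treat the LSE as base 2 only when a FlashInfer-backed kernel produced it.
attn_out = cp_lse_ag_out_rs(
    attn_out,
    lse,
    get_dcp_group(),
    is_lse_base_on_e=not use_flashinfer_backend,
)
```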
