Skip to content

Commit aae725a

Browse files
authored
[Performance] Remove redundant clone() calls in cutlass_mla (vllm-project#24891)
1 parent 73df49e commit aae725a

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

vllm/v1/attention/backends/mla/cutlass_mla.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -210,9 +210,14 @@ def _sm100_cutlass_mla_decode(
210210
sm_scale,
211211
num_kv_splits,
212212
)
213-
returned_lse = lse[:, :H].contiguous(
214-
) if self.need_to_return_lse_for_decode else lse
215-
return out[:, :H].contiguous(), returned_lse
213+
214+
if H < MAX_HEADS:
215+
# Extract the subsets of the outputs
216+
returned_lse = lse[:, :H].contiguous(
217+
) if self.need_to_return_lse_for_decode else lse
218+
out = out[:, :H]
219+
220+
return out, returned_lse
216221

217222
def _sm100_forward_decode(
218223
self,
@@ -228,11 +233,6 @@ def _sm100_forward_decode(
228233
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
229234

230235
# Run MLA
231-
# Clone q_nope and q_pe to make sure strides computation is correct.
232-
# TODO: Check if we really need it
233-
q_nope = q_nope.clone()
234-
q_pe = q_pe.clone()
235-
236236
o, lse = self._sm100_cutlass_mla_decode(
237237
q_nope,
238238
q_pe,

0 commit comments

Comments
 (0)