File tree: 1 file changed (+8 -8 lines changed)
vllm/v1/attention/backends/mla: 1 file changed (+8 -8 lines changed)
Columns: original file line number | diff line number | diff line change
@@ -210,9 +210,14 @@ def _sm100_cutlass_mla_decode(
210210 sm_scale ,
211211 num_kv_splits ,
212212 )
213- returned_lse = lse [:, :H ].contiguous (
214- ) if self .need_to_return_lse_for_decode else lse
215- return out [:, :H ].contiguous (), returned_lse
213+
214+ if H < MAX_HEADS :
215+ # Extract the subsets of the outputs
216+ returned_lse = lse [:, :H ].contiguous (
217+ ) if self .need_to_return_lse_for_decode else lse
218+ out = out [:, :H ]
219+
220+ return out , returned_lse
216221
217222 def _sm100_forward_decode (
218223 self ,
@@ -228,11 +233,6 @@ def _sm100_forward_decode(
228233 self ._workspace .ensure_size (attn_metadata , self ._num_kv_splits )
229234
230235 # Run MLA
231- # Clone q_nope and q_pe to make sure strides computation is correct.
232- # TODO: Check if we really need it
233- q_nope = q_nope .clone ()
234- q_pe = q_pe .clone ()
235-
236236 o , lse = self ._sm100_cutlass_mla_decode (
237237 q_nope ,
238238 q_pe ,
You can’t perform that action at this time.
0 commit comments