
Commit 10a2eec

Adapt ESA to support DeepSeek. (#335)
Adapt to DeepSeek.
1 parent 51ba639 commit 10a2eec

3 files changed: +280 -88 lines changed


examples/offline_inference_esa.py

Lines changed: 3 additions & 3 deletions
@@ -93,6 +93,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
         enforce_eager=True,
         distributed_executor_backend="mp",
         tensor_parallel_size=1,
+        trust_remote_code=True,
     )

     llm = LLM(**asdict(llm_args))
@@ -153,9 +154,8 @@ def get_prompt(prompt):
 for i in range(batch_size):
     line = lines[i]
     data = json.loads(line)
-    context = data["context"]
-    question = data["input"]
-    prompts.append(get_prompt(f"{context}\n\n{question}"))
+    prompt = f"""阅读以下文字并用中文简短回答:\n\n{data["context"]}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{data["input"]}\n回答:"""
+    prompts.append(get_prompt(prompt))

 sampling_params = SamplingParams(
     temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False
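
Note: the new prompt template is in Chinese; it roughly translates to "Read the following text and answer briefly in Chinese: {context} Now answer the question below based on the article above; give only the answer, do not output any other words. Question: {input} Answer:". Below is a minimal standalone sketch of the same construction, assuming a LongBench-style JSONL input whose records carry "context" and "input" fields; the helper name build_deepseek_prompt is illustrative and not part of the commit.

import json

def build_deepseek_prompt(line: str) -> str:
    # One JSONL record with "context" and "input" fields, as in the example script.
    data = json.loads(line)
    # Chinese instruction template taken from the diff above; English rendering in the note above.
    return (
        f"阅读以下文字并用中文简短回答:\n\n{data['context']}\n\n"
        f"现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n"
        f"问题:{data['input']}\n回答:"
    )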

ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch

Lines changed: 72 additions & 15 deletions
@@ -1,15 +1,15 @@
-From efb56ce711e3a2be60981bb5fe01d14f07dcb870 Mon Sep 17 00:00:00 2001
-From: flesher0813 <1208954694@qq.com>
-Date: Fri, 17 Oct 2025 21:01:17 +0800
-Subject: [PATCH] support aggregate and load failure
+From 67bb33e6d97dc5f55013ecfb4fb419f51e8b3c36 Mon Sep 17 00:00:00 2001
+From: wenxinwang <wangwenxin21@huawei.com>
+Date: Tue, 4 Nov 2025 17:41:40 +0800
+Subject: [PATCH] adapt to deepseek patch
 
-simplify sparse kv cache manager interface
 ---
- vllm/attention/layer.py | 45 +++-
+ vllm/attention/layer.py | 49 +++-
  .../kv_transfer/kv_connector/utils.py | 113 +++++++++
  .../kv_transfer/kv_connector/v1/base.py | 9 +
  .../kv_connector/v1/multi_connector.py | 6 +
  .../v1/shared_storage_connector.py | 7 +-
+ vllm/v1/attention/backends/mla/common.py | 10 +-
  vllm/v1/core/block_pool.py | 2 +-
  vllm/v1/core/kv_cache_manager.py | 7 +-
  vllm/v1/core/sched/output.py | 5 +
@@ -22,10 +22,10 @@ simplify sparse kv cache manager interface
  vllm/v1/worker/gpu_input_batch.py | 14 ++
  vllm/v1/worker/gpu_model_runner.py | 104 +++++++--
  vllm/v1/worker/gpu_worker.py | 25 +-
- 17 files changed, 560 insertions(+), 49 deletions(-)
+ 18 files changed, 571 insertions(+), 52 deletions(-)
 
 diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
-index f0ad68b16..26cdf0445 100644
+index f0ad68b16..728ab99fd 100644
 --- a/vllm/attention/layer.py
 +++ b/vllm/attention/layer.py
 @@ -2,7 +2,6 @@
@@ -56,24 +56,26 @@ index f0ad68b16..26cdf0445 100644
  maybe_save_kv_layer_to_connector(layer_name, kv_cache)
  return output
 
-@@ -449,6 +450,7 @@ def unified_attention_with_output(
+@@ -449,6 +450,8 @@ def unified_attention_with_output(
  attn_metadata = attn_metadata[layer_name]
  self = forward_context.no_compile_layers[layer_name]
  kv_cache = self.kv_cache[forward_context.virtual_engine]
-+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
++ if not self.use_mla:
++     maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
  self.impl.forward(self,
  query,
  key,
-@@ -457,7 +459,7 @@ def unified_attention_with_output(
+@@ -457,7 +460,8 @@ def unified_attention_with_output(
  attn_metadata,
  output=output,
  output_scale=output_scale)
 -
-+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
++ if not self.use_mla:
++     maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
  maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 
 
-@@ -479,3 +481,40 @@ direct_register_custom_op(
+@@ -479,3 +483,42 @@ direct_register_custom_op(
  fake_impl=unified_attention_with_output_fake,
  dispatch_key=current_platform.dispatch_key,
 )
@@ -84,6 +86,7 @@ index f0ad68b16..26cdf0445 100644
 + value: torch.Tensor,
 + layer_name: str,
 + forward_context: ForwardContext,
++ phase: Optional[str] = None,
 +):
 + if not has_ucm_sparse():
 + return
@@ -94,7 +97,7 @@ index f0ad68b16..26cdf0445 100644
 + if attn_metadata is None:
 + return
 +
-+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context)
++ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, phase)
 +
 +def maybe_execute_sparse_attention_finished(
 + query: torch.Tensor,
@@ -103,6 +106,7 @@ index f0ad68b16..26cdf0445 100644
 + attn_output: torch.Tensor,
 + layer_name: str,
 + forward_context: ForwardContext,
++ phase: Optional[str] = None,
 +):
 + if not has_ucm_sparse():
 + return
@@ -113,7 +117,7 @@ index f0ad68b16..26cdf0445 100644
 + if attn_metadata is None:
 + return
 +
-+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context)
++ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase)
 diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
 index 5cbc8ca31..8556a979e 100644
 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -310,6 +314,59 @@ index 3c574d065..223106def 100644
 
 def add_request(
 self,
+diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
+index f2aaf59a4..b56f62b39 100644
+--- a/vllm/v1/attention/backends/mla/common.py
++++ b/vllm/v1/attention/backends/mla/common.py
+@@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+ MLAAttentionImpl)
+ from vllm.attention.backends.utils import get_mla_dims
+ from vllm.attention.ops.merge_attn_states import merge_attn_states
++from vllm.forward_context import ForwardContext, get_forward_context
+ from vllm.attention.utils.fa_utils import get_flash_attn_version
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+@@ -211,6 +212,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+ CommonAttentionMetadata)
+ from vllm.v1.kv_cache_interface import AttentionSpec
+ from vllm.v1.worker.block_table import BlockTable
++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished)
+
+ try:
+ from vllm.vllm_flash_attn import flash_attn_varlen_func
+@@ -908,7 +910,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+ output: Optional[torch.Tensor] = None,
+ output_scale: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+-
++ forward_context: ForwardContext = get_forward_context()
+ assert output is not None, "Output tensor must be provided."
+
+ if output_scale is not None:
+@@ -957,10 +959,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+ )
+
+ if has_prefill:
++ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill")
+ output[num_decode_tokens:] = self._forward_prefill(
+ prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+ attn_metadata)
+-
++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill")
+ if has_decode:
+ assert attn_metadata.decode is not None
+ decode_q_nope, decode_q_pe = decode_q.split(
+@@ -971,8 +974,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+ decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+ # Convert from (N, B, L) to (B, N, L)
+ decode_ql_nope = decode_ql_nope.transpose(0, 1)
+-
++ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode")
+ output[:num_decode_tokens] = self._forward_decode(
+ decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode")
+
+ return output_padded
 diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
 index d21f94727..1800665c7 100644
 --- a/vllm/v1/core/block_pool.py
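
Taken together, the patch gives the UCM sparse-attention hooks an optional phase argument and wires them into vLLM's MLA backend for DeepSeek-style models: the generic unified_attention_with_output path now skips the hooks when self.use_mla, while MLACommonImpl.forward invokes them around the prefill and decode branches with phase="prefill" / phase="decode", passing the MLA-specific tensors (prefill_q, prefill_k_c_normed, decode_ql_nope, decode_q_pe) instead of full per-head keys and values. The sketch below shows what a sparse method might look like on the receiving side of these hooks; the class name and method bodies are illustrative assumptions, and only the two hook signatures follow the patch.

from typing import Optional

import torch


class MySparseAttentionMethod:
    """Hypothetical UCM sparse method; only the hook signatures mirror the patch."""

    def attention_begin(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        layer_name: str,
        forward_context,
        phase: Optional[str] = None,
    ) -> None:
        # phase is None on the regular (non-MLA) attention path, and
        # "prefill" or "decode" when called from the MLA backend, where the
        # two branches operate on differently shaped query/key tensors.
        if phase == "decode":
            pass  # e.g. select the sparse KV blocks to read for decode steps
        elif phase == "prefill":
            pass  # e.g. collect per-layer statistics while the prompt is processed

    def attention_finished(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_output: torch.Tensor,
        layer_name: str,
        forward_context,
        phase: Optional[str] = None,
    ) -> None:
        # attn_output is the slice written by the branch that just ran
        # (output[num_decode_tokens:] for prefill, output[:num_decode_tokens] for decode).
        pass

Keeping phase optional preserves the existing non-MLA call sites, which continue to invoke the hooks without a phase.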
