Skip to content

Commit 606b00e

Browse files
authored
[bugfix][DCP] fix block_size of hash in DCP prefix caching (vllm-project#26296)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent 720d3cd commit 606b00e

File tree

5 files changed: 12 additions (+12) and 10 deletions (-10).

tests/v1/core/test_scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1411,6 +1411,7 @@ def create_scheduler_with_priority(
14111411
kv_cache_config=kv_cache_config,
14121412
log_stats=True,
14131413
structured_output_manager=StructuredOutputManager(vllm_config),
1414+
block_size=block_size,
14141415
)
14151416

14161417

tests/v1/core/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,7 @@ def create_scheduler(
129129
return scheduler_cls(
130130
vllm_config=vllm_config,
131131
kv_cache_config=kv_cache_config,
132+
block_size=block_size,
132133
log_stats=True,
133134
structured_output_manager=StructuredOutputManager(vllm_config),
134135
)

tests/v1/kv_connector/unit/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,7 @@ def create_scheduler(
138138
kv_cache_config=kv_cache_config,
139139
log_stats=True,
140140
structured_output_manager=StructuredOutputManager(vllm_config),
141+
block_size=block_size,
141142
)
142143

143144

vllm/v1/core/sched/scheduler.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ def __init__(
4545
vllm_config: VllmConfig,
4646
kv_cache_config: KVCacheConfig,
4747
structured_output_manager: StructuredOutputManager,
48+
block_size: int,
4849
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
4950
include_finished_set: bool = False,
5051
log_stats: bool = False,
@@ -101,15 +102,8 @@ def __init__(
101102
num_gpu_blocks = self.cache_config.num_gpu_blocks
102103
assert num_gpu_blocks is not None and num_gpu_blocks > 0
103104

104-
self.block_size = self.cache_config.block_size
105-
105+
self.block_size = block_size
106106
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
107-
# Note(hc): The scheduler’s block_size must be multiplied
108-
# by dcp_world_size, since block hashes are computed on the
109-
# original full token sequence at a granularity of
110-
# original_block_size × dcp_world_size.
111-
if self.dcp_world_size > 1:
112-
self.block_size *= self.dcp_world_size
113107

114108
# req_id -> Request
115109
self.requests: dict[str, Request] = {}

vllm/v1/engine/core.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -142,12 +142,18 @@ def __init__(
142142
logger.info("Disabling chunked prefill for model without KVCache")
143143
vllm_config.scheduler_config.chunked_prefill_enabled = False
144144

145+
scheduler_block_size = (
146+
vllm_config.cache_config.block_size
147+
* vllm_config.parallel_config.decode_context_parallel_size
148+
)
149+
145150
self.scheduler: SchedulerInterface = Scheduler(
146151
vllm_config=vllm_config,
147152
kv_cache_config=kv_cache_config,
148153
structured_output_manager=self.structured_output_manager,
149154
include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
150155
log_stats=self.log_stats,
156+
block_size=scheduler_block_size,
151157
)
152158
self.use_spec_decode = vllm_config.speculative_config is not None
153159
if self.scheduler.connector is not None: # type: ignore
@@ -177,14 +183,13 @@ def __init__(
177183
self.vllm_config.cache_config.enable_prefix_caching
178184
or self.scheduler.get_kv_connector() is not None
179185
):
180-
block_size = vllm_config.cache_config.block_size
181186
caching_hash_fn = get_hash_fn_by_name(
182187
vllm_config.cache_config.prefix_caching_hash_algo
183188
)
184189
init_none_hash(caching_hash_fn)
185190

186191
self.request_block_hasher = get_request_block_hasher(
187-
block_size, caching_hash_fn
192+
scheduler_block_size, caching_hash_fn
188193
)
189194

190195
self.step_fn = (

0 commit comments

Comments
 (0)