Skip to content

Commit 83a5650

Browse files
committed
[Fix] Adapt to vllm 0.11.0, remove finish dumping
1 parent 99b09be commit 83a5650

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

ucm/integration/vllm/uc_connector.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from vllm.distributed.parallel_state import get_world_group
4040
from vllm.v1.core.sched.output import SchedulerOutput
41+
from vllm.v1.outputs import KVConnectorOutput
4142
from vllm.v1.request import Request, RequestStatus
4243

4344
from ucm.logger import init_logger
@@ -112,6 +113,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
112113
vllm_config.parallel_config
113114
)
114115
self.head_size = vllm_config.model_config.get_head_size()
116+
self.current_layer = 0
117+
# ids of blocks whose dump succeeded (a flat set, not keyed by request id)
118+
self.succeed_dumped_blocks: set[str] = set()
115119
if (
116120
self._vllm_config.kv_transfer_config is not None
117121
and "ucm_connector_name"
@@ -434,23 +438,20 @@ def wait_for_save(self) -> Optional[dict[str, list[str]]]:
434438
"""
435439
if hasattr(self, "kv_role") and self.kv_role == "kv_consumer":
436440
return
437-
# request id -> succeed dumped blocks
438-
success_dumped_blocks: dict[str, list[str]] = {}
439441

440442
def wait_for_tasks():
441443
for request_id, block_dump_tasks in self.dump_tasks.items():
442444
for block_id, dump_tasks in block_dump_tasks.items():
443445
if any(self.connector.wait(task) != 0 for task in dump_tasks):
444446
continue
445-
success_dumped_blocks.setdefault(request_id, []).append(block_id)
447+
self.succeed_dumped_blocks.add(block_id)
446448

447449
metadata = self._get_connector_metadata()
448450
assert isinstance(metadata, UCConnectorV1Metadata)
449451
if self.use_layerwise:
450452
wait_for_tasks()
451453
# clear dump_tasks for all request
452454
self.dump_tasks.clear()
453-
return success_dumped_blocks if success_dumped_blocks else None
454455

455456
for request in metadata.requests:
456457
if not request.dump_blocks:
@@ -482,7 +483,6 @@ def wait_for_tasks():
482483
).append(task)
483484
wait_for_tasks()
484485
self.dump_tasks.clear()
485-
return success_dumped_blocks if success_dumped_blocks else None
486486

487487
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
488488
"""Get the finished recving and sending requests."""
@@ -507,8 +507,8 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
507507
# remove the finished requests
508508
for req_id in list(done_recving):
509509
self._need_load_reqs.pop(req_id, None)
510-
511-
return None, done_recving
510+
done_sending, self.succeed_dumped_blocks = self.succeed_dumped_blocks, set()
511+
return done_sending, done_recving
512512

513513
# ==============================
514514
# Scheduler-side methods
@@ -744,7 +744,7 @@ def get_requests():
744744
# When prompt tokens > max_num_batched_tokens, request of running requests may need to save
745745
for req_id, new_block_ids in get_requests():
746746
block_info = self.request_block_infos.get(req_id)
747-
if block_info:
747+
if block_info and new_block_ids:
748748
load_blocks, dump_blocks = self._extract_blocks(
749749
new_block_ids[0], block_info
750750
)
@@ -759,6 +759,20 @@ def get_requests():
759759

760760
return meta
761761

762+
def update_connector_output(self, connector_output: KVConnectorOutput) -> None:
    """
    Update KVConnector state from worker-side connectors output.

    Every entry of ``connector_output.finished_sending`` is passed to
    ``self.connector.commit(..., True)`` and added to
    ``self.succeed_dumped_blocks``. ``finished_sending`` is then replaced
    with an empty set.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output. Its ``finished_sending`` field may be
            ``None`` or empty, in which case nothing is committed.
    """
    done_sending = list(connector_output.finished_sending or ())
    # Nothing reported (None or empty): skip the commit call rather than
    # invoking the connector with an empty list.
    if done_sending:
        self.connector.commit(done_sending, True)
        self.succeed_dumped_blocks.update(done_sending)
    # Reset unconditionally, including when the field was None.
    connector_output.finished_sending = set()
775+
762776
def request_finished(
763777
self,
764778
request: "Request",
@@ -770,13 +784,12 @@ def request_finished(
770784
block_info.block_hashes[i]
771785
for i, op in enumerate(block_info.block_operations)
772786
if op == BlockOperation.DUMP
773-
and hasattr(request, "succeed_dumped_blocks")
774-
and block_info.block_hashes[i] not in request.succeed_dumped_blocks
787+
and block_info.block_hashes[i] not in self.succeed_dumped_blocks
775788
]
776789
if cancel_blocks:
777790
logger.debug(f"commit {cancel_blocks} to False.")
778791
self.connector.commit(cancel_blocks, False)
779-
request.succeed_dumped_blocks.clear()
792+
self.succeed_dumped_blocks.clear()
780793
return False, None
781794

782795
def _extract_blocks(

0 commit comments

Comments
 (0)