3838)
3939from vllm .distributed .parallel_state import get_world_group
4040from vllm .v1 .core .sched .output import SchedulerOutput
41+ from vllm .v1 .outputs import KVConnectorOutput
4142from vllm .v1 .request import Request , RequestStatus
4243
4344from ucm .logger import init_logger
@@ -112,6 +113,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
112113 vllm_config .parallel_config
113114 )
114115 self .head_size = vllm_config .model_config .get_head_size ()
116+ self .current_layer = 0
117+ # ids (hashes) of blocks whose dump tasks all succeeded
118+ self .succeed_dumped_blocks : set [str ] = set ()
115119 if (
116120 self ._vllm_config .kv_transfer_config is not None
117121 and "ucm_connector_name"
@@ -434,23 +438,20 @@ def wait_for_save(self) -> Optional[dict[str, list[str]]]:
434438 """
435439 if hasattr (self , "kv_role" ) and self .kv_role == "kv_consumer" :
436440 return
437- # request id -> succeed dumped blocks
438- success_dumped_blocks : dict [str , list [str ]] = {}
439441
440442 def wait_for_tasks ():
441443 for request_id , block_dump_tasks in self .dump_tasks .items ():
442444 for block_id , dump_tasks in block_dump_tasks .items ():
443445 if any (self .connector .wait (task ) != 0 for task in dump_tasks ):
444446 continue
445- success_dumped_blocks . setdefault ( request_id , []). append (block_id )
447+ self . succeed_dumped_blocks . add (block_id )
446448
447449 metadata = self ._get_connector_metadata ()
448450 assert isinstance (metadata , UCConnectorV1Metadata )
449451 if self .use_layerwise :
450452 wait_for_tasks ()
451453 # clear dump_tasks for all request
452454 self .dump_tasks .clear ()
453- return success_dumped_blocks if success_dumped_blocks else None
454455
455456 for request in metadata .requests :
456457 if not request .dump_blocks :
@@ -482,7 +483,6 @@ def wait_for_tasks():
482483 ).append (task )
483484 wait_for_tasks ()
484485 self .dump_tasks .clear ()
485- return success_dumped_blocks if success_dumped_blocks else None
486486
487487 def get_finished (self , finished_req_ids : set [str ]) -> tuple [set [str ], set [str ]]:
488488 """Get the finished recving and sending requests."""
@@ -507,8 +507,8 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
507507 # remove the finished requests
508508 for req_id in list (done_recving ):
509509 self ._need_load_reqs .pop (req_id , None )
510-
511- return None , done_recving
510+ done_sending , self . succeed_dumped_blocks = self . succeed_dumped_blocks , set ()
511+ return done_sending , done_recving
512512
513513 # ==============================
514514 # Scheduler-side methods
@@ -744,7 +744,7 @@ def get_requests():
744744 # When prompt tokens > max_num_batched_tokens, request of running requests may need to save
745745 for req_id , new_block_ids in get_requests ():
746746 block_info = self .request_block_infos .get (req_id )
747- if block_info :
747+ if block_info and new_block_ids :
748748 load_blocks , dump_blocks = self ._extract_blocks (
749749 new_block_ids [0 ], block_info
750750 )
@@ -759,6 +759,20 @@ def get_requests():
759759
760760 return meta
761761
def update_connector_output(self, connector_output: "KVConnectorOutput") -> None:
    """
    Update KVConnector state from worker-side connectors output.

    Every id found in ``connector_output.finished_sending`` is committed as
    successful via ``self.connector.commit(ids, True)`` and recorded in
    ``self.succeed_dumped_blocks``. Afterwards ``finished_sending`` on the
    passed-in object is replaced with an empty set.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output. It is mutated in place:
            ``finished_sending`` is reset to ``set()``.
    """
    # ``finished_sending`` may be None; treat that the same as "nothing finished".
    done_sending = list(connector_output.finished_sending or ())
    if done_sending:
        # Only touch the store when there is something to commit; an empty
        # report would otherwise trigger a pointless commit call.
        self.connector.commit(done_sending, True)
        self.succeed_dumped_blocks.update(done_sending)
    # NOTE(review): presumably cleared so that later consumers of this output
    # do not interpret these ids as finished request ids -- confirm with caller.
    connector_output.finished_sending = set()
762776 def request_finished (
763777 self ,
764778 request : "Request" ,
@@ -770,13 +784,12 @@ def request_finished(
770784 block_info .block_hashes [i ]
771785 for i , op in enumerate (block_info .block_operations )
772786 if op == BlockOperation .DUMP
773- and hasattr (request , "succeed_dumped_blocks" )
774- and block_info .block_hashes [i ] not in request .succeed_dumped_blocks
787+ and block_info .block_hashes [i ] not in self .succeed_dumped_blocks
775788 ]
776789 if cancel_blocks :
777790 logger .debug (f"commit { cancel_blocks } to False." )
778791 self .connector .commit (cancel_blocks , False )
779- request .succeed_dumped_blocks .clear ()
792+ self .succeed_dumped_blocks .clear ()
780793 return False , None
781794
782795 def _extract_blocks (
0 commit comments