3939from vllm .distributed .parallel_state import get_world_group
4040from vllm .v1 .core .kv_cache_utils import hash_request_tokens
4141from vllm .v1 .core .sched .output import SchedulerOutput
42+ from vllm .v1 .request import Request , RequestStatus
4243
4344from ucm .logger import init_logger
4445from ucm .store .base import Task
4849 from vllm .attention .backends .abstract import AttentionMetadata
4950 from vllm .forward_context import ForwardContext
5051 from vllm .v1 .core .kv_cache_manager import KVCacheBlocks
51- from vllm .v1 .request import Request
5252
5353logger = init_logger (__name__ )
5454
@@ -533,12 +533,11 @@ def get_num_new_matched_tokens(
533533 the number of tokens that can be loaded from the
534534 external KV cache beyond what is already computed.
535535 """
536- # When the request has been preempted, commit its successfully dumped blocks
537- # first to avoid invoking the create/commit functions twice. Only preempted
538- # requests whose succeed_dumped_blocks is non-empty need this check.
539- if hasattr (request , "succeed_dumped_blocks" ) and request .succeed_dumped_blocks :
540- self .connector .commit (request .succeed_dumped_blocks , True )
541- request .succeed_dumped_blocks .clear ()
536+ logger .info (f"get_num_new_matched_tokens request { request .request_id } ." )
537+
538+ if request .status == RequestStatus .PREEMPTED :
539+ logger .info (f"Handle preempted request { request .request_id } ." )
540+ self .request_finished (request , [])
542541
543542 def md5 (input ) -> int :
544543 input_bytes = pickle .dumps (input , protocol = pickle .HIGHEST_PROTOCOL )
@@ -598,17 +597,6 @@ def md5(input) -> int:
598597 self ._need_load_reqs [request .request_id ] = []
599598 return num_lookup_hits * self .block_size , True
600599
601- # Create blocks for the remaining (unmatched) blocks
602- if num_lookup_hits < len (remain_hashes ):
603- remaining_hashes = remain_hashes [num_lookup_hits :]
604- create_results = self .connector .create (remaining_hashes )
605- logger .info (f"\n create_results on storage: { create_results } \n " )
606- for j , ret in enumerate (create_results ):
607- idx = num_lookup_hits + j
608- block_operations [start_position + idx ] = (
609- BlockOperation .DUMP if ret == 0 else BlockOperation .NONE
610- )
611-
612600 # When all the tokens are cached in SSD or HBM,
613601 # we need to recompute the last token. This `if` condition will be removed
614602 # once vLLM's scheduler provides a better solution in the future.
@@ -638,6 +626,23 @@ def update_state_after_alloc(
638626 )
639627 self ._need_load_reqs [request .request_id ] = local_block_ids
640628
629+ request_block_info = self .request_block_infos .get (request .request_id , None )
630+ if request_block_info :
631+ start_position = request_block_info .start_position
632+ block_operations = request_block_info .block_operations
633+ block_hashes = request_block_info .block_hashes
634+ start_create_pos = start_position + num_external_tokens // self .block_size
635+ remaining_hashes = block_hashes [start_create_pos :]
636+ if remaining_hashes :
637+ create_results = self .connector .create (remaining_hashes )
638+ if any (ret != 0 for ret in create_results ):
639+ logger .warning (f"\n create_results on storage: { create_results } \n " )
640+ for j , ret in enumerate (create_results ):
641+ idx = start_create_pos + j
642+ block_operations [idx ] = (
643+ BlockOperation .DUMP if ret == 0 else BlockOperation .NONE
644+ )
645+
641646 def build_connector_meta (
642647 self , scheduler_output : SchedulerOutput
643648 ) -> KVConnectorMetadata :
@@ -733,7 +738,6 @@ def request_finished(
733738 ) -> tuple [bool , Optional [dict [str , Any ]]]:
734739 block_info = self .request_block_infos .pop (request .request_id , None )
735740 if hasattr (request , "succeed_dumped_blocks" ) and request .succeed_dumped_blocks :
736- logger .debug (f"commit { request .succeed_dumped_blocks } to True." )
737741 self .connector .commit (request .succeed_dumped_blocks , True )
738742 if block_info is not None :
739743 cancel_blocks = [
@@ -744,8 +748,8 @@ def request_finished(
744748 and block_info .block_hashes [i ] not in request .succeed_dumped_blocks
745749 ]
746750 if cancel_blocks :
747- logger .warning (f"commit { cancel_blocks } to False." )
748751 self .connector .commit (cancel_blocks , False )
752+ request .succeed_dumped_blocks .clear ()
749753 return False , None
750754
751755 def _extract_blocks (
0 commit comments