@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
318318 prompts_txt_file_path : Optional [str ] = None ,
319319 device_id : Optional [List [int ]] = None ,
320320 generation_len : Optional [int ] = None ,
321+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
322+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
321323 enable_debug_logs : bool = False ,
322324 stream : bool = True ,
323325 write_io_dir : Optional [str ] = None ,
@@ -384,6 +386,8 @@ def cloud_ai_100_exec_kv(
384386 qpc_path = qpc_path ,
385387 device_id = device_id ,
386388 ctx_len = ctx_len ,
389+ comp_ctx_lengths_prefill = comp_ctx_lengths_prefill ,
390+ comp_ctx_lengths_decode = comp_ctx_lengths_decode ,
387391 enable_debug_logs = enable_debug_logs ,
388392 write_io_dir = write_io_dir ,
389393 full_batch_size = full_batch_size ,
@@ -430,6 +434,8 @@ def __init__(
430434 qpc_path : str ,
431435 full_batch_size : Optional [int ] = None ,
432436 ctx_len : Optional [int ] = None ,
437+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
438+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
433439 device_id : Optional [List [int ]] = None ,
434440 enable_debug_logs : bool = False ,
435441 write_io_dir : Optional [str ] = None ,
@@ -439,6 +445,8 @@ def __init__(
439445 sampling_params : Optional [Dict [str , Any ]] = None ,
440446 ) -> None :
441447 self ._ctx_len = ctx_len
448+ self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
449+ self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
442450 self ._write_io_dir = write_io_dir
443451 self .is_tlm = is_tlm
444452 self .return_pdfs = return_pdfs
@@ -797,7 +805,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
797805 batch_lora_ids = [self ._prompt_to_lora_id_mapping_prefill .popleft () for i in range (self .batch_size )]
798806 inputs ["lora_ids" ] = np .array (batch_lora_ids , dtype = np .int64 ).reshape (self .batch_size , 1 )
799807
808+ if self .comp_ctx_lengths_prefill is not None :
809+ self .list_of_comp_ctx_lengths_prefill = [np .zeros (length ) for length in self .comp_ctx_lengths_prefill ]
810+ prefill_ccl_id = 0
811+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_prefill [prefill_ccl_id ]
812+
800813 for i in range (num_chunks ):
814+ if self .comp_ctx_lengths_prefill is not None :
815+ if (i + 1 ) * self ._prefill_seq_len > self .comp_ctx_lengths_prefill [prefill_ccl_id ]:
816+ prefill_ccl_id = min (prefill_ccl_id + 1 , len (self .comp_ctx_lengths_prefill ) - 1 )
817+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_prefill [prefill_ccl_id ]
818+
801819 chunk_inputs = inputs .copy ()
802820 chunk_inputs ["input_ids" ] = inputs ["input_ids" ][
803821 :, i * self ._prefill_seq_len : (i + 1 ) * self ._prefill_seq_len
@@ -816,6 +834,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
816834 generation_len ,
817835 )
818836
def initialize_ccl(self, decode_inputs):
    """
    Build the decode-time ``comp_ctx_lengths`` buffers and pick the starting bucket.

    One zero-filled array is created per entry of ``self.comp_ctx_lengths_decode``
    (array ``k`` has ``self.comp_ctx_lengths_decode[k]`` elements) and the list is
    stored on ``self.list_of_comp_ctx_lengths_decode``. The starting bucket is the
    first index whose length is strictly greater than the largest value found in
    ``decode_inputs["position_ids"]``.

    Args:
        decode_inputs: Mapping that contains a ``"position_ids"`` array.

    Returns:
        tuple: ``(ccl_id, max_ccl_id)`` where ``ccl_id`` is the selected index and
        ``max_ccl_id`` is ``len(self.comp_ctx_lengths_decode) - 1``.
    """
    ccl_lengths = self.comp_ctx_lengths_decode
    self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in ccl_lengths]

    max_ccl_id = len(ccl_lengths) - 1
    max_position_id = np.max(decode_inputs["position_ids"])

    # If no configured length exceeds the current position, saturate at the
    # largest bucket (same clamping the decode loops apply with
    # ``min(ccl_id + 1, max_ccl_id)``) instead of silently dropping back to the
    # smallest one. ``max(..., 0)`` keeps the result at 0 for an empty list.
    ccl_id = next(
        (idx for idx, length in enumerate(ccl_lengths) if max_position_id < length),
        max(max_ccl_id, 0),
    )

    return ccl_id, max_ccl_id
849+
819850 def run_continuous_batching_decode (self , prompt_queue , generation_len ):
820851 """
821852 Runs continuous batching decode for the given prompt queue and generation length.
@@ -847,6 +878,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
847878 # Prepare decode inputs inputs.
848879 decode_inputs = self .prepare_decode_inputs ()
849880
881+ if self .comp_ctx_lengths_decode is not None :
882+ ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
883+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
884+
850885 while prompt_queue or current_decode_ongoing .any ():
851886 outputs = self ._session .run (decode_inputs )
852887
@@ -884,6 +919,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
884919 batch_id_map [decode_batch_id ]
885920 ]
886921
922+ if self .comp_ctx_lengths_decode is not None :
923+ ###Recalculate ccl_id based on position ids###
924+ # Determine the maximum value of position_ids across all batch elements
925+ max_position_id = np .max (decode_inputs ["position_ids" ])
926+
927+ # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
928+ ccl_id_initial = 0
929+ ccl_id = ccl_id_initial
930+ for i in range (ccl_id_initial , len (self .comp_ctx_lengths_decode )):
931+ if max_position_id < self .comp_ctx_lengths_decode [i ]:
932+ ccl_id = i
933+ break
934+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
935+
887936 else :
888937 current_decode_ongoing [decode_batch_id ] = False
889938 else :
@@ -896,6 +945,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
896945 if self .include_sampler :
897946 decode_inputs ["last_accepted_output_tokens" ] = decode_inputs ["input_ids" ]
898947
948+ if self .comp_ctx_lengths_decode is not None :
949+ # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
950+ if (
951+ decode_inputs ["position_ids" ][decode_batch_id , - 1 ]
952+ >= self .comp_ctx_lengths_decode [ccl_id ] - 1
953+ ):
954+ ccl_id = min (ccl_id + 1 , max_ccl_id )
955+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
956+
899957 generated_id_current_index [decode_batch_id ] += 1
900958
901959 return decode_pause_time
@@ -922,7 +980,18 @@ def run_decode(
922980 self ._session .set_buffers ({"logits" : logits_out_placeholder })
923981 finished_sequences = decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
924982 num_token = 0
983+
984+ if self .comp_ctx_lengths_decode is not None :
985+ ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
986+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
987+
988+ cache_index = np .max (decode_inputs ["position_ids" ])
925989 for num_token in range (1 , generation_len ):
990+ if self .comp_ctx_lengths_decode is not None :
991+ if cache_index >= self .comp_ctx_lengths_decode [ccl_id ] - 1 :
992+ ccl_id = min (ccl_id + 1 , max_ccl_id )
993+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
994+
926995 if streamer :
927996 streamer .put (decode_inputs ["input_ids" ][0 ])
928997 outputs = self ._session .run (decode_inputs )
@@ -934,6 +1003,7 @@ def run_decode(
9341003 # Prepare inputs for next iteration
9351004 decode_inputs ["input_ids" ] = self ._fetch_next_token_id (outputs )
9361005 decode_inputs ["position_ids" ][:, - 1 ] += 1
1006+ cache_index += 1
9371007 self .generated_ids [:, num_token ] = decode_inputs ["input_ids" ][:, - 1 ]
9381008 finished_sequences |= decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
9391009 if self .include_sampler :
@@ -983,6 +1053,8 @@ def __init__(
9831053 qpc_path : str ,
9841054 full_batch_size : Optional [int ] = None ,
9851055 ctx_len : Optional [int ] = None ,
1056+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
1057+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
9861058 device_id : Optional [List [int ]] = None ,
9871059 enable_debug_logs : bool = False ,
9881060 write_io_dir : Optional [str ] = None ,
@@ -996,6 +1068,8 @@ def __init__(
9961068 qpc_path = qpc_path ,
9971069 full_batch_size = full_batch_size ,
9981070 ctx_len = ctx_len ,
1071+ comp_ctx_lengths_prefill = comp_ctx_lengths_prefill ,
1072+ comp_ctx_lengths_decode = comp_ctx_lengths_decode ,
9991073 device_id = device_id ,
10001074 enable_debug_logs = enable_debug_logs ,
10011075 write_io_dir = write_io_dir ,
@@ -1007,6 +1081,8 @@ def __init__(
10071081 self ._full_batch_size = self ._qaic_model .full_batch_size
10081082 self ._tokenizer = self ._qaic_model .tokenizer
10091083 self ._ctx_len = ctx_len
1084+ self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1085+ self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
10101086 self ._perf_metrics = None
10111087 self ._prompt_queue = None
10121088 self ._text_streamer = None
0 commit comments