@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
318318 prompts_txt_file_path : Optional [str ] = None ,
319319 device_id : Optional [List [int ]] = None ,
320320 generation_len : Optional [int ] = None ,
321+ comp_ctx_lengths : Optional [List [int ]] = None ,
322+ prefill_ccl_len : Optional [int ] = 1 ,
321323 enable_debug_logs : bool = False ,
322324 stream : bool = True ,
323325 write_io_dir : Optional [str ] = None ,
@@ -382,6 +384,8 @@ def cloud_ai_100_exec_kv(
382384 qpc_path = qpc_path ,
383385 device_id = device_id ,
384386 ctx_len = ctx_len ,
387+ comp_ctx_lengths = comp_ctx_lengths ,
388+ prefill_ccl_len = prefill_ccl_len ,
385389 enable_debug_logs = enable_debug_logs ,
386390 write_io_dir = write_io_dir ,
387391 full_batch_size = full_batch_size ,
@@ -424,6 +428,8 @@ def __init__(
424428 qpc_path : str ,
425429 full_batch_size : Optional [int ] = None ,
426430 ctx_len : Optional [int ] = None ,
431+ comp_ctx_lengths : Optional [List [int ]] = None ,
432+ prefill_ccl_len : Optional [int ] = 1 ,
427433 device_id : Optional [List [int ]] = None ,
428434 enable_debug_logs : bool = False ,
429435 write_io_dir : Optional [str ] = None ,
@@ -433,6 +439,8 @@ def __init__(
433439 sampling_params : Optional [Dict [str , Any ]] = None ,
434440 ) -> None :
435441 self ._ctx_len = ctx_len
442+ self .comp_ctx_lengths = comp_ctx_lengths
443+ self .prefill_ccl_len = prefill_ccl_len
436444 self ._write_io_dir = write_io_dir
437445 self .is_tlm = is_tlm
438446 self .return_pdfs = return_pdfs
@@ -791,7 +799,23 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
791799 batch_lora_ids = [self ._prompt_to_lora_id_mapping_prefill .popleft () for i in range (self .batch_size )]
792800 inputs ["lora_ids" ] = np .array (batch_lora_ids , dtype = np .int64 ).reshape (self .batch_size , 1 )
793801
802+ if self .comp_ctx_lengths is not None :
803+ self .list_of_comp_ctx_lengths = [np .zeros (length ) for length in self .comp_ctx_lengths ]
804+ prefill_ccl_id = 0
805+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [prefill_ccl_id ]
806+
794807 for i in range (num_chunks ):
808+ if (i + 1 ) * self ._prefill_seq_len > self .comp_ctx_lengths [prefill_ccl_id ]:
809+ prefill_ccl_id += 1
810+ if prefill_ccl_id >= self .prefill_ccl_len :
811+ prefill_ccl_id = (
812+ (self .prefill_ccl_len - 1 )
813+ if self .prefill_ccl_len != 0
814+ else min (prefill_ccl_id , len (self .comp_ctx_lengths ) - 1 )
815+ )
816+
817+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [prefill_ccl_id ]
818+
795819 chunk_inputs = inputs .copy ()
796820 chunk_inputs ["input_ids" ] = inputs ["input_ids" ][
797821 :, i * self ._prefill_seq_len : (i + 1 ) * self ._prefill_seq_len
@@ -810,6 +834,18 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
810834 generation_len ,
811835 )
812836
def initialize_ccl(self, decode_inputs):
    """
    Pick the index into ``self.comp_ctx_lengths`` to use for the first decode step.

    Entries of ``self.comp_ctx_lengths`` are scanned starting at index
    ``self.prefill_ccl_len``; the first entry that is strictly greater than the largest
    value in ``decode_inputs["position_ids"]`` is selected. When no entry qualifies
    (every candidate is too small, or the starting index is already past the end of the
    list) the last index is returned, so the result can always be used to index
    ``self.comp_ctx_lengths``.

    NOTE(review): this assumes ``self.comp_ctx_lengths`` is sorted in ascending order so
    that the last entry is the largest one -- confirm against the caller that builds it.

    Args:
        :decode_inputs (dict): Decode inputs; must contain a ``"position_ids"`` array.

    Returns:
        :tuple: ``(ccl_id, max_ccl_id)`` where ``ccl_id`` is the selected index and
        ``max_ccl_id`` is ``len(self.comp_ctx_lengths) - 1``.
    """
    max_ccl_id = len(self.comp_ctx_lengths) - 1
    max_position_id = np.max(decode_inputs["position_ids"])

    # Fall back to the last bucket instead of `self.prefill_ccl_len`: that keeps the
    # index in range and agrees with the `min(ccl_id + 1, max_ccl_id)` clamp used by
    # the decode loops when they advance to the next bucket.
    ccl_id = max_ccl_id
    for i in range(self.prefill_ccl_len, len(self.comp_ctx_lengths)):
        if max_position_id < self.comp_ctx_lengths[i]:
            ccl_id = i
            break

    return ccl_id, max_ccl_id
848+
813849 def run_continuous_batching_decode (self , prompt_queue , generation_len ):
814850 """
815851 Runs continuous batching decode for the given prompt queue and generation length.
@@ -841,6 +877,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
841877 # Prepare decode inputs inputs.
842878 decode_inputs = self .prepare_decode_inputs ()
843879
880+ if self .comp_ctx_lengths is not None :
881+ ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
882+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
883+
844884 while prompt_queue or current_decode_ongoing .any ():
845885 outputs = self ._session .run (decode_inputs )
846886
@@ -878,6 +918,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
878918 batch_id_map [decode_batch_id ]
879919 ]
880920
921+ if self .comp_ctx_lengths is not None :
922+ ###Recalculate ccl_id based on position ids###
923+ # Determine the maximum value of position_ids across all batch elements
924+ max_position_id = np .max (decode_inputs ["position_ids" ])
925+
926+ # Update ccl_id and comp_ctx_lengths based on the maximum position id
927+ ccl_id_initial = self .prefill_ccl_len
928+ ccl_id = ccl_id_initial
929+ for i in range (ccl_id_initial , len (self .comp_ctx_lengths )):
930+ if max_position_id < self .comp_ctx_lengths [i ]:
931+ ccl_id = i
932+ break
933+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
934+
881935 else :
882936 current_decode_ongoing [decode_batch_id ] = False
883937 else :
@@ -890,6 +944,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
890944 if self .include_sampler :
891945 decode_inputs ["last_accepted_output_tokens" ] = decode_inputs ["input_ids" ]
892946
947+ if self .comp_ctx_lengths is not None :
948+ # Update ccl_id and comp_ctx_lengths based on the maximum position id
949+ if decode_inputs ["position_ids" ][decode_batch_id , - 1 ] >= self .comp_ctx_lengths [ccl_id ] - 1 :
950+ ccl_id = min (ccl_id + 1 , max_ccl_id )
951+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
952+
893953 generated_id_current_index [decode_batch_id ] += 1
894954
895955 return decode_pause_time
@@ -914,7 +974,18 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
914974 self ._session .set_buffers ({"logits" : logits_out_placeholder })
915975 finished_sequences = decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
916976 num_token = 0
977+
978+ if self .comp_ctx_lengths is not None :
979+ ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
980+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
981+
982+ cache_index = np .max (decode_inputs ["position_ids" ])
917983 for num_token in range (1 , generation_len ):
984+ if self .comp_ctx_lengths is not None :
985+ if cache_index >= self .comp_ctx_lengths [ccl_id ] - 1 :
986+ ccl_id = min (ccl_id + 1 , max_ccl_id )
987+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
988+
918989 if streamer :
919990 streamer .put (decode_inputs ["input_ids" ][0 ])
920991 outputs = self ._session .run (decode_inputs )
@@ -926,6 +997,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
926997 # Prepare inputs for next iteration
927998 decode_inputs ["input_ids" ] = self ._fetch_next_token_id (outputs )
928999 decode_inputs ["position_ids" ][:, - 1 ] += 1
1000+ cache_index += 1
9291001 self .generated_ids [:, num_token ] = decode_inputs ["input_ids" ][:, - 1 ]
9301002 finished_sequences |= decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
9311003 if self .include_sampler :
@@ -975,6 +1047,8 @@ def __init__(
9751047 qpc_path : str ,
9761048 full_batch_size : Optional [int ] = None ,
9771049 ctx_len : Optional [int ] = None ,
1050+ comp_ctx_lengths : Optional [List [int ]] = None ,
1051+ prefill_ccl_len : Optional [int ] = 1 ,
9781052 device_id : Optional [List [int ]] = None ,
9791053 enable_debug_logs : bool = False ,
9801054 write_io_dir : Optional [str ] = None ,
@@ -988,6 +1062,8 @@ def __init__(
9881062 qpc_path = qpc_path ,
9891063 full_batch_size = full_batch_size ,
9901064 ctx_len = ctx_len ,
1065+ comp_ctx_lengths = comp_ctx_lengths ,
1066+ prefill_ccl_len = prefill_ccl_len ,
9911067 device_id = device_id ,
9921068 enable_debug_logs = enable_debug_logs ,
9931069 write_io_dir = write_io_dir ,
@@ -999,6 +1075,8 @@ def __init__(
9991075 self ._full_batch_size = self ._qaic_model .full_batch_size
10001076 self ._tokenizer = self ._qaic_model .tokenizer
10011077 self ._ctx_len = ctx_len
1078+ self .comp_ctx_lengths = comp_ctx_lengths
1079+ self .prefill_ccl_len = prefill_ccl_len
10021080 self ._perf_metrics = None
10031081 self ._prompt_queue = None
10041082 self ._text_streamer = None
0 commit comments