@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
318318 prompts_txt_file_path : Optional [str ] = None ,
319319 device_id : Optional [List [int ]] = None ,
320320 generation_len : Optional [int ] = None ,
321+ comp_ctx_lengths : Optional [List [int ]] = None ,
322+ prefill_ccl_len : Optional [int ] = 1 ,
321323 enable_debug_logs : bool = False ,
322324 stream : bool = True ,
323325 write_io_dir : Optional [str ] = None ,
@@ -382,6 +384,8 @@ def cloud_ai_100_exec_kv(
382384 qpc_path = qpc_path ,
383385 device_id = device_id ,
384386 ctx_len = ctx_len ,
387+ comp_ctx_lengths = comp_ctx_lengths ,
388+ prefill_ccl_len = prefill_ccl_len ,
385389 enable_debug_logs = enable_debug_logs ,
386390 write_io_dir = write_io_dir ,
387391 full_batch_size = full_batch_size ,
@@ -424,6 +428,8 @@ def __init__(
424428 qpc_path : str ,
425429 full_batch_size : Optional [int ] = None ,
426430 ctx_len : Optional [int ] = None ,
431+ comp_ctx_lengths : Optional [List [int ]] = None ,
432+ prefill_ccl_len : Optional [int ] = 1 ,
427433 device_id : Optional [List [int ]] = None ,
428434 enable_debug_logs : bool = False ,
429435 write_io_dir : Optional [str ] = None ,
@@ -433,6 +439,8 @@ def __init__(
433439 sampling_params : Optional [Dict [str , Any ]] = None ,
434440 ) -> None :
435441 self ._ctx_len = ctx_len
442+ self .comp_ctx_lengths = comp_ctx_lengths
443+ self .prefill_ccl_len = prefill_ccl_len
436444 self ._write_io_dir = write_io_dir
437445 self .is_tlm = is_tlm
438446 self .return_pdfs = return_pdfs
@@ -791,7 +799,18 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
791799 batch_lora_ids = [self ._prompt_to_lora_id_mapping_prefill .popleft () for i in range (self .batch_size )]
792800 inputs ["lora_ids" ] = np .array (batch_lora_ids , dtype = np .int64 ).reshape (self .batch_size , 1 )
793801
802+ if self .comp_ctx_lengths is not None :
803+ self .list_of_comp_ctx_lengths = [np .zeros (length ) for length in self .comp_ctx_lengths ]
804+ prefill_ccl_id = 0
805+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [prefill_ccl_id ]
806+
794807 for i in range (num_chunks ):
808+ if (i + 1 ) * self ._prefill_seq_len > self .comp_ctx_lengths [prefill_ccl_id ]:
809+ prefill_ccl_id += 1
810+ if prefill_ccl_id >= self .prefill_ccl_len :
811+ prefill_ccl_id = self .prefill_ccl_len - 1
812+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [prefill_ccl_id ]
813+
795814 chunk_inputs = inputs .copy ()
796815 chunk_inputs ["input_ids" ] = inputs ["input_ids" ][
797816 :, i * self ._prefill_seq_len : (i + 1 ) * self ._prefill_seq_len
@@ -810,6 +829,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
810829 generation_len ,
811830 )
812831
832+ def initialize_ccl (self , decode_inputs ):
833+ max_ccl_id = len (self .comp_ctx_lengths ) - 1
834+ max_position_id = np .max (decode_inputs ["position_ids" ])
835+ ccl_id_initial = self .prefill_ccl_len
836+ ccl_id = ccl_id_initial
837+ for i in range (ccl_id_initial , len (self .comp_ctx_lengths )):
838+ if max_position_id < self .comp_ctx_lengths [i ]:
839+ ccl_id = i
840+ break
841+
842+ print (f"Decode CCL: { self .comp_ctx_lengths [ccl_id ]} " )
843+ return ccl_id , max_ccl_id
844+
813845 def run_continuous_batching_decode (self , prompt_queue , generation_len ):
814846 """
815847 Runs continuous batching decode for the given prompt queue and generation length.
@@ -841,6 +873,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
841873 # Prepare decode inputs inputs.
842874 decode_inputs = self .prepare_decode_inputs ()
843875
876+ if self .comp_ctx_lengths is not None :
877+ ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
878+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
879+
844880 while prompt_queue or current_decode_ongoing .any ():
845881 outputs = self ._session .run (decode_inputs )
846882
@@ -878,6 +914,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
878914 batch_id_map [decode_batch_id ]
879915 ]
880916
917+ if self .comp_ctx_lengths is not None :
918+ ###Recalculate ccl_id based on position ids###
919+ # Determine the maximum value of position_ids across all batch elements
920+ max_position_id = np .max (decode_inputs ["position_ids" ])
921+
922+ # Update ccl_id and comp_ctx_lengths based on the maximum position id
923+ ccl_id_initial = self .prefill_ccl_len
924+ ccl_id = ccl_id_initial
925+ for i in range (ccl_id_initial , len (self .comp_ctx_lengths )):
926+ if max_position_id < self .comp_ctx_lengths [i ]:
927+ ccl_id = i
928+ break
929+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
930+
881931 else :
882932 current_decode_ongoing [decode_batch_id ] = False
883933 else :
@@ -890,6 +940,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
890940 if self .include_sampler :
891941 decode_inputs ["last_accepted_output_tokens" ] = decode_inputs ["input_ids" ]
892942
943+ if self .comp_ctx_lengths is not None :
944+ # Update ccl_id and comp_ctx_lengths based on the maximum position id
945+ if decode_inputs ["position_ids" ][decode_batch_id , - 1 ] >= self .comp_ctx_lengths [ccl_id ] - 1 :
946+ ccl_id = min (ccl_id + 1 , max_ccl_id )
947+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
948+
893949 generated_id_current_index [decode_batch_id ] += 1
894950
895951 return decode_pause_time
@@ -914,7 +970,18 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
914970 self ._session .set_buffers ({"logits" : logits_out_placeholder })
915971 finished_sequences = decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
916972 num_token = 0
973+
974+ if self .comp_ctx_lengths is not None :
975+ ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
976+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
977+
978+ cache_index = np .max (decode_inputs ["position_ids" ])
917979 for num_token in range (1 , generation_len ):
980+ if self .comp_ctx_lengths is not None :
981+ if cache_index >= self .comp_ctx_lengths [ccl_id ] - 1 :
982+ ccl_id = min (ccl_id + 1 , max_ccl_id )
983+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
984+
918985 if streamer :
919986 streamer .put (decode_inputs ["input_ids" ][0 ])
920987 outputs = self ._session .run (decode_inputs )
@@ -926,6 +993,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
926993 # Prepare inputs for next iteration
927994 decode_inputs ["input_ids" ] = self ._fetch_next_token_id (outputs )
928995 decode_inputs ["position_ids" ][:, - 1 ] += 1
996+ cache_index += 1
929997 self .generated_ids [:, num_token ] = decode_inputs ["input_ids" ][:, - 1 ]
930998 finished_sequences |= decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
931999 if self .include_sampler :
@@ -975,6 +1043,8 @@ def __init__(
9751043 qpc_path : str ,
9761044 full_batch_size : Optional [int ] = None ,
9771045 ctx_len : Optional [int ] = None ,
1046+ comp_ctx_lengths : Optional [List [int ]] = None ,
1047+ prefill_ccl_len : Optional [int ] = 1 ,
9781048 device_id : Optional [List [int ]] = None ,
9791049 enable_debug_logs : bool = False ,
9801050 write_io_dir : Optional [str ] = None ,
@@ -988,6 +1058,8 @@ def __init__(
9881058 qpc_path = qpc_path ,
9891059 full_batch_size = full_batch_size ,
9901060 ctx_len = ctx_len ,
1061+ comp_ctx_lengths = comp_ctx_lengths ,
1062+ prefill_ccl_len = prefill_ccl_len ,
9911063 device_id = device_id ,
9921064 enable_debug_logs = enable_debug_logs ,
9931065 write_io_dir = write_io_dir ,
@@ -999,6 +1071,8 @@ def __init__(
9991071 self ._full_batch_size = self ._qaic_model .full_batch_size
10001072 self ._tokenizer = self ._qaic_model .tokenizer
10011073 self ._ctx_len = ctx_len
1074+ self .comp_ctx_lengths = comp_ctx_lengths
1075+ self .prefill_ccl_len = prefill_ccl_len
10021076 self ._perf_metrics = None
10031077 self ._prompt_queue = None
10041078 self ._text_streamer = None
0 commit comments