@@ -318,8 +318,8 @@ def cloud_ai_100_exec_kv(
318318 prompts_txt_file_path : Optional [str ] = None ,
319319 device_id : Optional [List [int ]] = None ,
320320 generation_len : Optional [int ] = None ,
321- comp_ctx_lengths : Optional [List [int ]] = None ,
322- prefill_ccl_len : Optional [int ] = 1 ,
321+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
322+ comp_ctx_lengths_decode : Optional [List [ int ]] = None ,
323323 enable_debug_logs : bool = False ,
324324 stream : bool = True ,
325325 write_io_dir : Optional [str ] = None ,
@@ -384,8 +384,8 @@ def cloud_ai_100_exec_kv(
384384 qpc_path = qpc_path ,
385385 device_id = device_id ,
386386 ctx_len = ctx_len ,
387- comp_ctx_lengths = comp_ctx_lengths ,
388- prefill_ccl_len = prefill_ccl_len ,
387+ comp_ctx_lengths_prefill = comp_ctx_lengths_prefill ,
388+ comp_ctx_lengths_decode = comp_ctx_lengths_decode ,
389389 enable_debug_logs = enable_debug_logs ,
390390 write_io_dir = write_io_dir ,
391391 full_batch_size = full_batch_size ,
@@ -428,8 +428,8 @@ def __init__(
428428 qpc_path : str ,
429429 full_batch_size : Optional [int ] = None ,
430430 ctx_len : Optional [int ] = None ,
431- comp_ctx_lengths : Optional [List [int ]] = None ,
432- prefill_ccl_len : Optional [int ] = 1 ,
431+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
432+ comp_ctx_lengths_decode : Optional [List [ int ]] = None ,
433433 device_id : Optional [List [int ]] = None ,
434434 enable_debug_logs : bool = False ,
435435 write_io_dir : Optional [str ] = None ,
@@ -439,8 +439,8 @@ def __init__(
439439 sampling_params : Optional [Dict [str , Any ]] = None ,
440440 ) -> None :
441441 self ._ctx_len = ctx_len
442- self .comp_ctx_lengths = comp_ctx_lengths
443- self .prefill_ccl_len = prefill_ccl_len
442+ self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
443+ self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
444444 self ._write_io_dir = write_io_dir
445445 self .is_tlm = is_tlm
446446 self .return_pdfs = return_pdfs
@@ -799,22 +799,15 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
799799 batch_lora_ids = [self ._prompt_to_lora_id_mapping_prefill .popleft () for i in range (self .batch_size )]
800800 inputs ["lora_ids" ] = np .array (batch_lora_ids , dtype = np .int64 ).reshape (self .batch_size , 1 )
801801
802- if self .comp_ctx_lengths is not None :
803- self .list_of_comp_ctx_lengths = [np .zeros (length ) for length in self .comp_ctx_lengths ]
802+ if self .comp_ctx_lengths_prefill is not None :
803+ self .list_of_comp_ctx_lengths_prefill = [np .zeros (length ) for length in self .comp_ctx_lengths_prefill ]
804804 prefill_ccl_id = 0
805- inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [prefill_ccl_id ]
805+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_prefill [prefill_ccl_id ]
806806
807807 for i in range (num_chunks ):
808- if (i + 1 ) * self ._prefill_seq_len > self .comp_ctx_lengths [prefill_ccl_id ]:
809- prefill_ccl_id += 1
810- if prefill_ccl_id >= self .prefill_ccl_len :
811- prefill_ccl_id = (
812- (self .prefill_ccl_len - 1 )
813- if self .prefill_ccl_len != 0
814- else min (prefill_ccl_id , len (self .comp_ctx_lengths ) - 1 )
815- )
816-
817- inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [prefill_ccl_id ]
808+ if (i + 1 ) * self ._prefill_seq_len > self .comp_ctx_lengths_prefill [prefill_ccl_id ]:
809+ prefill_ccl_id = min (prefill_ccl_id + 1 , len (self .comp_ctx_lengths_prefill ) - 1 )
810+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_prefill [prefill_ccl_id ]
818811
819812 chunk_inputs = inputs .copy ()
820813 chunk_inputs ["input_ids" ] = inputs ["input_ids" ][
@@ -835,12 +828,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
835828 )
836829
837830 def initialize_ccl (self , decode_inputs ):
838- max_ccl_id = len (self .comp_ctx_lengths ) - 1
831+ self .list_of_comp_ctx_lengths_decode = [np .zeros (length ) for length in self .comp_ctx_lengths_decode ]
832+ max_ccl_id = len (self .comp_ctx_lengths_decode ) - 1
839833 max_position_id = np .max (decode_inputs ["position_ids" ])
840- ccl_id_initial = self . prefill_ccl_len
834+ ccl_id_initial = 0
841835 ccl_id = ccl_id_initial
842- for i in range (ccl_id_initial , len (self .comp_ctx_lengths )):
843- if max_position_id < self .comp_ctx_lengths [i ]:
836+ for i in range (ccl_id_initial , len (self .comp_ctx_lengths_decode )):
837+ if max_position_id < self .comp_ctx_lengths_decode [i ]:
844838 ccl_id = i
845839 break
846840
@@ -877,9 +871,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
877871 # Prepare decode inputs inputs.
878872 decode_inputs = self .prepare_decode_inputs ()
879873
880- if self .comp_ctx_lengths is not None :
874+ if self .comp_ctx_lengths_decode is not None :
881875 ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
882- decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
876+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
883877
884878 while prompt_queue or current_decode_ongoing .any ():
885879 outputs = self ._session .run (decode_inputs )
@@ -918,19 +912,19 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
918912 batch_id_map [decode_batch_id ]
919913 ]
920914
921- if self .comp_ctx_lengths is not None :
915+ if self .comp_ctx_lengths_decode is not None :
922916 ###Recalculate ccl_id based on position ids###
923917 # Determine the maximum value of position_ids across all batch elements
924918 max_position_id = np .max (decode_inputs ["position_ids" ])
925919
926- # Update ccl_id and comp_ctx_lengths based on the maximum position id
920+ # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
927921 ccl_id_initial = self .prefill_ccl_len
928922 ccl_id = ccl_id_initial
929- for i in range (ccl_id_initial , len (self .comp_ctx_lengths )):
930- if max_position_id < self .comp_ctx_lengths [i ]:
923+ for i in range (ccl_id_initial , len (self .comp_ctx_lengths_decode )):
924+ if max_position_id < self .comp_ctx_lengths_decode [i ]:
931925 ccl_id = i
932926 break
933- decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
927+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
934928
935929 else :
936930 current_decode_ongoing [decode_batch_id ] = False
@@ -944,11 +938,14 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
944938 if self .include_sampler :
945939 decode_inputs ["last_accepted_output_tokens" ] = decode_inputs ["input_ids" ]
946940
947- if self .comp_ctx_lengths is not None :
948- # Update ccl_id and comp_ctx_lengths based on the maximum position id
949- if decode_inputs ["position_ids" ][decode_batch_id , - 1 ] >= self .comp_ctx_lengths [ccl_id ] - 1 :
941+ if self .comp_ctx_lengths_decode is not None :
942+ # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
943+ if (
944+ decode_inputs ["position_ids" ][decode_batch_id , - 1 ]
945+ >= self .comp_ctx_lengths_decode [ccl_id ] - 1
946+ ):
950947 ccl_id = min (ccl_id + 1 , max_ccl_id )
951- decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
948+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
952949
953950 generated_id_current_index [decode_batch_id ] += 1
954951
@@ -975,16 +972,16 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
975972 finished_sequences = decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
976973 num_token = 0
977974
978- if self .comp_ctx_lengths is not None :
975+ if self .comp_ctx_lengths_decode is not None :
979976 ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
980- decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
977+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
981978
982979 cache_index = np .max (decode_inputs ["position_ids" ])
983980 for num_token in range (1 , generation_len ):
984- if self .comp_ctx_lengths is not None :
985- if cache_index >= self .comp_ctx_lengths [ccl_id ] - 1 :
981+ if self .comp_ctx_lengths_decode is not None :
982+ if cache_index >= self .comp_ctx_lengths_decode [ccl_id ] - 1 :
986983 ccl_id = min (ccl_id + 1 , max_ccl_id )
987- decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths [ccl_id ]
984+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
988985
989986 if streamer :
990987 streamer .put (decode_inputs ["input_ids" ][0 ])
@@ -1047,8 +1044,8 @@ def __init__(
10471044 qpc_path : str ,
10481045 full_batch_size : Optional [int ] = None ,
10491046 ctx_len : Optional [int ] = None ,
1050- comp_ctx_lengths : Optional [List [int ]] = None ,
1051- prefill_ccl_len : Optional [int ] = 1 ,
1047+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
1048+ comp_ctx_lengths_decode : Optional [List [ int ]] = None ,
10521049 device_id : Optional [List [int ]] = None ,
10531050 enable_debug_logs : bool = False ,
10541051 write_io_dir : Optional [str ] = None ,
@@ -1062,8 +1059,8 @@ def __init__(
10621059 qpc_path = qpc_path ,
10631060 full_batch_size = full_batch_size ,
10641061 ctx_len = ctx_len ,
1065- comp_ctx_lengths = comp_ctx_lengths ,
1066- prefill_ccl_len = prefill_ccl_len ,
1062+ comp_ctx_lengths_prefill = comp_ctx_lengths_prefill ,
1063+ comp_ctx_lengths_decode = comp_ctx_lengths_decode ,
10671064 device_id = device_id ,
10681065 enable_debug_logs = enable_debug_logs ,
10691066 write_io_dir = write_io_dir ,
@@ -1075,8 +1072,8 @@ def __init__(
10751072 self ._full_batch_size = self ._qaic_model .full_batch_size
10761073 self ._tokenizer = self ._qaic_model .tokenizer
10771074 self ._ctx_len = ctx_len
1078- self .comp_ctx_lengths = comp_ctx_lengths
1079- self .prefill_ccl_len = prefill_ccl_len
1075+ self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1076+ self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
10801077 self ._perf_metrics = None
10811078 self ._prompt_queue = None
10821079 self ._text_streamer = None
0 commit comments