@@ -324,6 +324,7 @@ def cloud_ai_100_exec_kv(
     stream: bool = True,
     write_io_dir: Optional[str] = None,
     automation=False,
+    iteration: int = 1,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
     is_tlm: bool = False,
     include_sampler: bool = False,
@@ -348,6 +349,7 @@ def cloud_ai_100_exec_kv(
         :stream (bool): If True, enable streamer, which returns tokens one by one as the model generates them. ``Defaults to True``.
         :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
         :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
+        :iteration (int): Number of iterations to run the inference. ``Defaults to 1``.
         :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
         :include_sampler (bool, default=False): Enable/Disable sampling of next tokens.
         :return_pdfs (bool, default=False): Return probability distributions along with sampled
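
For reference, a minimal call sketch for the new `iteration` argument. The surrounding keyword arguments (`tokenizer`, `qpc_path`, `prompt`, `generation_len`) and the import path are assumed from the current module; the tokenizer and QPC path below are placeholders.

from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv  # assumed module path

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/qpcs",  # placeholder: pre-compiled QPC directory
    prompt=["My name is"],
    generation_len=32,
    iteration=3,  # new: repeat the full inference three times (e.g. for perf soak runs)
)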
@@ -394,30 +396,34 @@ def cloud_ai_100_exec_kv(
         return_pdfs=return_pdfs,
         sampling_params=sampling_params,
     )
-    if full_batch_size is None:
-        exec_info = [
-            generate_text.generate(prompt[i : i + batch_size], generation_len, stream, prompt_to_lora_id_mapping)
-            for i in range(0, len(prompt), batch_size)
-        ]
-        prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info])
-        decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info])
-        total_perf = np.average([info.perf_metrics.total_perf for info in exec_info])
-        total_time = np.average([info.perf_metrics.total_time for info in exec_info])
-        generated_texts = [info.generated_texts for info in exec_info]
-        generated_ids = [info.generated_ids for info in exec_info]
-
-        exec_info = CloudAI100ExecInfo(
-            batch_size=batch_size,
-            generated_texts=generated_texts,
-            generated_ids=generated_ids,
-            perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
-        )
-    else:
-        exec_info = generate_text.generate(
-            prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
-        )
 
-    print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
+    for _ in range(0, int(iteration)):
+        if full_batch_size is None:
+            exec_info = [
+                generate_text.generate(prompt[i : i + batch_size], generation_len, stream, prompt_to_lora_id_mapping)
+                for i in range(0, len(prompt), batch_size)
+            ]
+            prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info])
+            decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info])
+            total_perf = np.average([info.perf_metrics.total_perf for info in exec_info])
+            total_time = np.average([info.perf_metrics.total_time for info in exec_info])
+            generated_texts = [info.generated_texts for info in exec_info]
+            generated_ids = [info.generated_ids for info in exec_info]
+
+            exec_info = CloudAI100ExecInfo(
+                batch_size=batch_size,
+                generated_texts=generated_texts,
+                generated_ids=generated_ids,
+                perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
+            )
+        else:
+            exec_info = generate_text.generate(
+                prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
+            )
+
+        print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
+
+    # TODO: Need to handle the case where exec_info is given for n iterations
     return exec_info
 
 
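The TODO above flags that, with multiple iterations, only the last iteration's `exec_info` is returned. One possible follow-up (a sketch only, not part of this change) is to collect one `CloudAI100ExecInfo` per iteration and average the perf metrics before returning; the `PerfMetrics` import location is assumed from this module.

import numpy as np

from QEfficient.generation.text_generation_inference import PerfMetrics  # assumed location, as used above

def average_perf_metrics(all_exec_info):
    # `all_exec_info` would be the CloudAI100ExecInfo objects collected inside the iteration loop.
    prefill_time = np.average([info.perf_metrics.prefill_time for info in all_exec_info])
    decode_perf = np.average([info.perf_metrics.decode_perf for info in all_exec_info])
    total_perf = np.average([info.perf_metrics.total_perf for info in all_exec_info])
    total_time = np.average([info.perf_metrics.total_time for info in all_exec_info])
    return PerfMetrics(prefill_time, decode_perf, total_perf, total_time)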
@@ -951,7 +957,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
 
         return decode_pause_time
 
-    def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
+    def run_decode(
+        self, decode_inputs, generation_len, automation, streamer: Optional[transformers.TextStreamer] = None
+    ):
         """
         Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length.
 
@@ -1000,11 +1008,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
             if self.include_sampler:
                 decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
 
-            if finished_sequences.all():
+            if finished_sequences.all() and not automation:
                 break
         return num_token
 
-    def generate_decode_stream(self, decode_inputs, generation_len):
+    def generate_decode_stream(self, decode_inputs, generation_len, automation):
         """
         Generator method for yielding decode tokens. Executes the decoding process for a given set of inputs and a specified generation length.
 
@@ -1032,7 +1040,7 @@ def generate_decode_stream(self, decode_inputs, generation_len):
             self.generated_ids[:, num_token] = decode_inputs["input_ids"].squeeze(1)
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
 
-            if finished_sequences.all():
+            if finished_sequences.all() and not automation:
                 break
         yield decode_inputs["input_ids"]  # yield the last token
 
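The two `and not automation` changes above alter the decode-loop exit condition: with `automation=True`, decoding no longer stops once every sequence in the batch has emitted EOS, so each run executes the full `generation_len` steps and timings stay comparable across prompts. A toy, self-contained illustration of that stop condition (not the library's code; the helper below is hypothetical):

import numpy as np

def steps_executed(eos_at_step, generation_len, automation):
    # eos_at_step[i] is the decode step at which sequence i first emits EOS.
    finished = np.zeros(len(eos_at_step), dtype=bool)
    for step in range(1, generation_len):
        finished |= np.asarray(eos_at_step) <= step
        if finished.all() and not automation:
            return step  # early exit, as before this change
    return generation_len - 1  # with automation=True the loop always runs to the end

print(steps_executed([3, 5], generation_len=32, automation=False))  # -> 5
print(steps_executed([3, 5], generation_len=32, automation=True))   # -> 31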
@@ -1115,6 +1123,7 @@ def _regular_model_execution(
         prompt: List[str],
         generation_len: Optional[int] = None,
         stream: Optional[bool] = True,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """
@@ -1142,7 +1151,7 @@ def _regular_model_execution(
         decode_inputs = self._qaic_model.prepare_decode_inputs()
 
         loop_start = perf_counter()  # Start decode loop timer
-        num_token = self._qaic_model.run_decode(decode_inputs, generation_len, self._text_streamer)
+        num_token = self._qaic_model.run_decode(decode_inputs, generation_len, automation, self._text_streamer)
         end = perf_counter()
         generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True)
 
@@ -1196,6 +1205,7 @@ def generate_stream_tokens(
         self,
         prompt: List[str],
         generation_len: Optional[int] = None,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """
@@ -1225,7 +1235,7 @@ def generate_stream_tokens(
 
         loop_start = perf_counter()  # Start decode loop timer
         num_token = 0
-        for token_id in self._qaic_model.generate_decode_stream(decode_inputs, generation_len):
+        for token_id in self._qaic_model.generate_decode_stream(decode_inputs, generation_len, automation):
             decoded_tokens = []
             for idx in range(self._qaic_model.batch_size):
                 decoded_tokens.append(self._tokenizer.decode(token_id[idx], skip_special_tokens=True))
@@ -1244,6 +1254,7 @@ def generate(
         prompt: List[str],
         generation_len: Optional[int] = None,
         stream: bool = True,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """
@@ -1267,7 +1278,7 @@ def generate(
         if stream:
             print("\nPrompt : " + prompt[0] + "\nCompletion :", flush=True, end="")
         perf_metrics, generated_texts = self._regular_model_execution(
-            prompt, generation_len, stream, prompt_to_lora_id_mapping
+            prompt, generation_len, stream, automation, prompt_to_lora_id_mapping
         )
 
         if stream: