@@ -322,6 +322,7 @@ def cloud_ai_100_exec_kv(
     stream: bool = True,
     write_io_dir: Optional[str] = None,
     automation=False,
+    iteration: int = 1,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
     is_tlm: bool = False,
     include_sampler: bool = False,
@@ -346,6 +347,7 @@ def cloud_ai_100_exec_kv(
         :stream (bool): If True, enable streamer, which returns tokens one by one as the model generates them. ``Defaults to True``.
         :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
         :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
+        :iteration (int): Number of iterations to run the inference. ``Defaults to 1``.
         :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
         :include_sampler (bool, default=False): Enable/Disable sampling of next tokens.
         :return_pdfs (bool, default=False): Return probability distributions along with sampled
@@ -390,30 +392,34 @@ def cloud_ai_100_exec_kv(
         return_pdfs=return_pdfs,
         sampling_params=sampling_params,
     )
-    if full_batch_size is None:
-        exec_info = [
-            generate_text.generate(prompt[i : i + batch_size], generation_len, stream, prompt_to_lora_id_mapping)
-            for i in range(0, len(prompt), batch_size)
-        ]
-        prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info])
-        decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info])
-        total_perf = np.average([info.perf_metrics.total_perf for info in exec_info])
-        total_time = np.average([info.perf_metrics.total_time for info in exec_info])
-        generated_texts = [info.generated_texts for info in exec_info]
-        generated_ids = [info.generated_ids for info in exec_info]
-
-        exec_info = CloudAI100ExecInfo(
-            batch_size=batch_size,
-            generated_texts=generated_texts,
-            generated_ids=generated_ids,
-            perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
-        )
-    else:
-        exec_info = generate_text.generate(
-            prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
-        )

-    print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
+    for _ in range(0, int(iteration)):
+        if full_batch_size is None:
+            exec_info = [
+                generate_text.generate(prompt[i : i + batch_size], generation_len, stream, prompt_to_lora_id_mapping)
+                for i in range(0, len(prompt), batch_size)
+            ]
+            prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info])
+            decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info])
+            total_perf = np.average([info.perf_metrics.total_perf for info in exec_info])
+            total_time = np.average([info.perf_metrics.total_time for info in exec_info])
+            generated_texts = [info.generated_texts for info in exec_info]
+            generated_ids = [info.generated_ids for info in exec_info]
+
+            exec_info = CloudAI100ExecInfo(
+                batch_size=batch_size,
+                generated_texts=generated_texts,
+                generated_ids=generated_ids,
+                perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
+            )
+        else:
+            exec_info = generate_text.generate(
+                prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
+            )
+
+        print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
+
+    # TODO: Need to handle the case where exec_info is given for n iterations
     return exec_info


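For reference, a minimal usage sketch of `cloud_ai_100_exec_kv` with the new `iteration` argument. The import path, tokenizer choice, and `qpc_path` value are assumptions for illustration and are not taken from this commit:

```python
from transformers import AutoTokenizer

# Assumed import path; adjust to wherever cloud_ai_100_exec_kv lives in your checkout.
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/compiled_qpc",  # placeholder path to the compiled QPC
    prompt=["My name is"],
    generation_len=32,
    automation=True,  # print input, output, and perf stats for each run
    iteration=3,      # new: repeat the generate-and-print loop three times
)

# Per the TODO in the diff, exec_info currently reflects only the last iteration.
print(exec_info.generated_texts)
```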
@@ -894,7 +900,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

         return decode_pause_time

-    def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
+    def run_decode(
+        self, decode_inputs, generation_len, automation, streamer: Optional[transformers.TextStreamer] = None
+    ):
         """
         Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length.

@@ -931,11 +939,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
             if self.include_sampler:
                 decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

-            if finished_sequences.all():
+            if finished_sequences.all() and not automation:
                 break
         return num_token

-    def generate_decode_stream(self, decode_inputs, generation_len):
+    def generate_decode_stream(self, decode_inputs, generation_len, automation):
         """
         Generator method for yielding decode tokens. Executes the decoding process for a given set of inputs and a specified generation length.

@@ -963,7 +971,7 @@ def generate_decode_stream(self, decode_inputs, generation_len):
             self.generated_ids[:, num_token] = decode_inputs["input_ids"].squeeze(1)
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id

-            if finished_sequences.all():
+            if finished_sequences.all() and not automation:
                 break
         yield decode_inputs["input_ids"]  # yield the last token

@@ -1040,6 +1048,7 @@ def _regular_model_execution(
         prompt: List[str],
         generation_len: Optional[int] = None,
         stream: Optional[bool] = True,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """
@@ -1067,7 +1076,7 @@ def _regular_model_execution(
         decode_inputs = self._qaic_model.prepare_decode_inputs()

         loop_start = perf_counter()  # Start decode loop timer
-        num_token = self._qaic_model.run_decode(decode_inputs, generation_len, self._text_streamer)
+        num_token = self._qaic_model.run_decode(decode_inputs, generation_len, automation, self._text_streamer)
         end = perf_counter()
         generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True)

@@ -1121,6 +1130,7 @@ def generate_stream_tokens(
         self,
         prompt: List[str],
         generation_len: Optional[int] = None,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """
@@ -1150,7 +1160,7 @@ def generate_stream_tokens(

         loop_start = perf_counter()  # Start decode loop timer
         num_token = 0
-        for token_id in self._qaic_model.generate_decode_stream(decode_inputs, generation_len):
+        for token_id in self._qaic_model.generate_decode_stream(decode_inputs, generation_len, automation):
             decoded_tokens = []
             for idx in range(self._qaic_model.batch_size):
                 decoded_tokens.append(self._tokenizer.decode(token_id[idx], skip_special_tokens=True))
@@ -1169,6 +1179,7 @@ def generate(
         prompt: List[str],
         generation_len: Optional[int] = None,
         stream: bool = True,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """
@@ -1192,7 +1203,7 @@ def generate(
         if stream:
             print("\nPrompt : " + prompt[0] + "\nCompletion :", flush=True, end="")
         perf_metrics, generated_texts = self._regular_model_execution(
-            prompt, generation_len, stream, prompt_to_lora_id_mapping
+            prompt, generation_len, stream, automation, prompt_to_lora_id_mapping
         )

         if stream:
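A brief usage sketch of the public-facing change: `generate` and `generate_stream_tokens` now accept `automation`. The wrapper class is not named in this excerpt, so `text_generator` below stands for an already constructed instance, and the prompts and lengths are placeholders:

```python
def run_automation_pass(text_generator, prompts, generation_len=64):
    """Drive both entry points with automation=True; with the flag set, decoding
    continues to generation_len even after every sequence has produced EOS."""
    # Non-streaming path.
    text_generator.generate(prompt=prompts, generation_len=generation_len, stream=False, automation=True)

    # Streaming path: one batch of decoded tokens is yielded per decode step.
    for tokens in text_generator.generate_stream_tokens(prompts, generation_len, automation=True):
        print(tokens[0], end="", flush=True)
```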