@@ -129,7 +129,6 @@ def run_kv_model_on_pytorch(self, model):
129129
130130 generated_ids = []
131131 inputs = self .input_handler .prepare_pytorch_inputs ()
132-
133132 pt_outputs = model (** inputs )
134133 for _ in range (1 , self .gen_len ):
135134 generated_ids .append (pt_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
@@ -291,9 +290,11 @@ def run_vlm_kv_model_on_pytorch(self, model):
291290 generation_len = self .gen_len
292291 generated_ids = torch .full ((self .batch_size , generation_len ), self .processor .tokenizer .pad_token_id )
293292 inputs = self .input_handler_vlm .prepare_pytorch_inputs ()
293+ inputs ["image_idx" ] = torch .tensor ([[0 ]])
294294
295295 outputs = model (** inputs )
296296 inputs ["input_ids" ] = outputs [0 ].argmax (2 )
297+ inputs ["image_idx" ] = outputs [2 ]
297298 if "cross_attention_mask" in inputs :
298299 bs , _ , num_images , img_tiles = inputs ["cross_attention_mask" ].shape
299300 inputs ["cross_attention_mask" ] = torch .ones ((bs , 1 , num_images , img_tiles ), dtype = torch .int64 )
@@ -308,6 +309,7 @@ def run_vlm_kv_model_on_pytorch(self, model):
308309 for num_token in range (1 , self .gen_len ):
309310 outputs = model (** inputs )
310311 inputs ["input_ids" ] = outputs [0 ].argmax (2 )
312+ inputs ["image_idx" ] = outputs [2 ]
311313 inputs ["position_ids" ] += 1
312314 streamer .put (inputs ["input_ids" ])
313315 generated_ids [:, num_token ] = inputs ["input_ids" ].squeeze (1 )
@@ -363,15 +365,23 @@ def run_vlm_kv_model_on_ort(self, model_path):
363365
364366 added_initializers , decoder_session = self .setup_ort_session (decoder_path )
365367 generated_ids = []
368+ finished_sequences = lang_inputs ["input_ids" ] == self .processor .tokenizer .eos_token_id
366369
367370 ort_outputs = self .run_ort_session (lang_inputs , session = decoder_session )
368371 ort_outputs = self .input_handler_vlm .update_vlm_ort_outputs (ort_outputs )
372+ generated_ids .append (ort_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
373+ lang_inputs = self .input_handler_vlm .update_vlm_ort_inputs (lang_inputs , ort_outputs )
374+
369375 for _ in range (1 , self .gen_len ):
370- generated_ids .append (ort_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
371- lang_inputs = self .input_handler_vlm .update_vlm_ort_inputs (lang_inputs , ort_outputs )
376+ finished_sequences |= lang_inputs ["input_ids" ] == self .processor .tokenizer .eos_token_id
377+ if finished_sequences .all ():
378+ break
379+
372380 ort_outputs = self .run_ort_session (lang_inputs , decoder_session )
373381 ort_outputs = self .input_handler_vlm .update_vlm_ort_outputs (ort_outputs )
374- generated_ids .append (ort_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
382+ generated_ids .append (ort_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
383+ lang_inputs = self .input_handler_vlm .update_vlm_ort_inputs (lang_inputs , ort_outputs )
384+
375385 generated_ids = np .concatenate (generated_ids , axis = 1 )
376386 predicted_string = self .processor .tokenizer .batch_decode (generated_ids , skip_special_tokens = True )
377387 print ("ORT KV_OFFLOAD Session Outputs:" )
@@ -383,14 +393,22 @@ def run_vlm_kv_model_on_ort(self, model_path):
383393 added_initializers , session = self .setup_ort_session (model_path )
384394 generated_ids = []
385395 inputs = {** vision_inputs , ** lang_inputs }
396+ finished_sequences = inputs ["input_ids" ] == self .processor .tokenizer .eos_token_id
397+
386398 ort_outputs = self .run_ort_session (inputs , session = session )
387399 ort_outputs = self .input_handler_vlm .update_vlm_ort_outputs (ort_outputs )
400+ generated_ids .append (ort_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
401+ inputs = self .input_handler_vlm .update_vlm_ort_inputs (inputs , ort_outputs )
402+
388403 for _ in range (1 , self .gen_len ):
389- generated_ids .append (ort_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
390- inputs = self .input_handler_vlm .update_vlm_ort_inputs (inputs , ort_outputs )
404+ finished_sequences |= inputs ["input_ids" ] == self .processor .tokenizer .eos_token_id
405+ if finished_sequences .all ():
406+ break
391407 ort_outputs = self .run_ort_session (inputs , session )
392408 ort_outputs = self .input_handler_vlm .update_vlm_ort_outputs (ort_outputs )
393- generated_ids .append (ort_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
409+ generated_ids .append (ort_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ))
410+ inputs = self .input_handler_vlm .update_vlm_ort_inputs (inputs , ort_outputs )
411+
394412 generated_ids = np .concatenate (generated_ids , axis = 1 )
395413 predicted_string = self .processor .tokenizer .batch_decode (generated_ids , skip_special_tokens = True )
396414 print ("ORT Session Outputs:" )
0 commit comments