Commit e905575

Updated KV_pytorch and ORT inference for VLMs to incorporate image_idx (#557)
Updated the run_vlm_kv_model_on_pytorch and run_vlm_kv_model_on_ort methods to run with the latest dual-QPC setup, together with the corresponding changes to the VLM input handler. Also updated how head_dim is computed when creating past_key_values, since certain models now provide an explicit head_dim in their config; we fall back to the previous hidden_size // num_attention_heads derivation when that parameter is absent. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
1 parent 0182d95 commit e905575
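
The head_dim fallback is the smallest but most widely applied change in this commit: some newer text configs declare head_dim explicitly, and it no longer has to equal hidden_size // num_attention_heads. A minimal sketch of the pattern applied in all four prepare_* methods below, using a throwaway SimpleNamespace in place of a real HuggingFace config (the numbers are illustrative only):

from types import SimpleNamespace

# Hypothetical text config: the derived value would be 4096 // 32 = 128,
# but the model declares an explicit head_dim of 256.
txt_cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=256)

# Prefer the explicit head_dim; fall back to the old derivation if it is absent.
head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)
print(head_dim)  # 256 here; would print 128 if head_dim were missing from the config

Since the past_key_values placeholders are shaped from this value, every prepare_* method that builds them needed the same one-line change.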

File tree

QEfficient/utils/generate_inputs.py
QEfficient/utils/run_utils.py

2 files changed: +32 / -11 lines

QEfficient/utils/generate_inputs.py

Lines changed: 7 additions & 4 deletions
@@ -249,7 +249,7 @@ def prepare_pytorch_inputs(self):

         num_hidden_layers = txt_cfg.num_hidden_layers
         num_key_value_heads = txt_cfg.num_key_value_heads
-        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)
         if hasattr(txt_cfg, "cross_attention_layers"):
             cross_attention_layers = txt_cfg.cross_attention_layers

@@ -287,7 +287,7 @@ def prepare_vlm_ort_inputs(self):
         txt_cfg = self.config.llm_config
         num_hidden_layers = txt_cfg.num_hidden_layers
         num_key_value_heads = txt_cfg.num_key_value_heads
-        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)
         if hasattr(txt_cfg, "cross_attention_layers"):
             cross_attention_layers = txt_cfg.cross_attention_layers
         vis_cfg = self.config.vision_config
@@ -298,6 +298,7 @@ def prepare_vlm_ort_inputs(self):
         if "attention_mask" in inputs.keys():
             inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
         inputs["past_key_values"] = []
+        inputs["image_idx"] = np.array([[0]])

         vision_inputs = {
             k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
@@ -349,6 +350,7 @@ def update_vlm_ort_outputs(self, ort_outputs):
         outputs["image_features_RetainedState"] = (
             ort_outputs["image_features_RetainedState"] if "image_features_RetainedState" in ort_outputs else None
         )
+        outputs["image_idx"] = ort_outputs["image_idx_output"]
         return outputs

     def update_vlm_ort_inputs(self, inputs, ort_outputs):
@@ -414,7 +416,7 @@ def prepare_pytorch_inputs(self):

         num_hidden_layers = txt_cfg.num_hidden_layers
         num_key_value_heads = txt_cfg.num_key_value_heads
-        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)

         inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
         inputs["past_key_values"] = []
@@ -435,7 +437,7 @@ def prepare_vlm_ort_inputs(self):
         txt_cfg = self.config.llm_config
         num_hidden_layers = txt_cfg.num_hidden_layers
         num_key_value_heads = txt_cfg.num_key_value_heads
-        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)

         question = "<image>\n" + self.prompt
         pixel_values = self.processor.load_image(self.image, max_num=12)
@@ -449,6 +451,7 @@ def prepare_vlm_ort_inputs(self):
         if "attention_mask" in inputs.keys():
             inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
         inputs["past_key_values"] = []
+        inputs["image_idx"] = np.array([[0]])

         vision_inputs = {
             k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
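
In the dual-QPC flow the language decoder exposes an image_idx input and returns an image_idx_output, so the input handler now seeds image_idx at prefill and copies the retained value forward on every decode step. A rough sketch of that plumbing (the helper names below are illustrative, not the real input-handler API):

import numpy as np

def seed_image_idx(lang_inputs):
    # Prefill: no image features consumed yet, so start at index 0.
    # Shape (1, 1) to match the 2-D input the exported graph expects.
    lang_inputs["image_idx"] = np.array([[0]])
    return lang_inputs

def carry_image_idx(lang_inputs, ort_outputs):
    # Decode: feed the retained image_idx_output back in as the next image_idx,
    # mirroring update_vlm_ort_outputs / update_vlm_ort_inputs in the diff above.
    lang_inputs["image_idx"] = ort_outputs["image_idx_output"]
    return lang_inputs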

QEfficient/utils/run_utils.py

Lines changed: 25 additions & 7 deletions
@@ -129,7 +129,6 @@ def run_kv_model_on_pytorch(self, model):

         generated_ids = []
         inputs = self.input_handler.prepare_pytorch_inputs()
-
         pt_outputs = model(**inputs)
         for _ in range(1, self.gen_len):
             generated_ids.append(pt_outputs["logits"].argmax(-1).reshape(-1, 1))
@@ -291,9 +290,11 @@ def run_vlm_kv_model_on_pytorch(self, model):
         generation_len = self.gen_len
         generated_ids = torch.full((self.batch_size, generation_len), self.processor.tokenizer.pad_token_id)
         inputs = self.input_handler_vlm.prepare_pytorch_inputs()
+        inputs["image_idx"] = torch.tensor([[0]])

         outputs = model(**inputs)
         inputs["input_ids"] = outputs[0].argmax(2)
+        inputs["image_idx"] = outputs[2]
         if "cross_attention_mask" in inputs:
             bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape
             inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64)
@@ -308,6 +309,7 @@ def run_vlm_kv_model_on_pytorch(self, model):
         for num_token in range(1, self.gen_len):
             outputs = model(**inputs)
             inputs["input_ids"] = outputs[0].argmax(2)
+            inputs["image_idx"] = outputs[2]
             inputs["position_ids"] += 1
             streamer.put(inputs["input_ids"])
             generated_ids[:, num_token] = inputs["input_ids"].squeeze(1)
@@ -363,15 +365,23 @@ def run_vlm_kv_model_on_ort(self, model_path):

         added_initializers, decoder_session = self.setup_ort_session(decoder_path)
         generated_ids = []
+        finished_sequences = lang_inputs["input_ids"] == self.processor.tokenizer.eos_token_id

         ort_outputs = self.run_ort_session(lang_inputs, session=decoder_session)
         ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
+        generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+        lang_inputs = self.input_handler_vlm.update_vlm_ort_inputs(lang_inputs, ort_outputs)
+
         for _ in range(1, self.gen_len):
-            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
-            lang_inputs = self.input_handler_vlm.update_vlm_ort_inputs(lang_inputs, ort_outputs)
+            finished_sequences |= lang_inputs["input_ids"] == self.processor.tokenizer.eos_token_id
+            if finished_sequences.all():
+                break
+
             ort_outputs = self.run_ort_session(lang_inputs, decoder_session)
             ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
-        generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            lang_inputs = self.input_handler_vlm.update_vlm_ort_inputs(lang_inputs, ort_outputs)
+
         generated_ids = np.concatenate(generated_ids, axis=1)
         predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         print("ORT KV_OFFLOAD Session Outputs:")
@@ -383,14 +393,22 @@ def run_vlm_kv_model_on_ort(self, model_path):
         added_initializers, session = self.setup_ort_session(model_path)
         generated_ids = []
         inputs = {**vision_inputs, **lang_inputs}
+        finished_sequences = inputs["input_ids"] == self.processor.tokenizer.eos_token_id
+
         ort_outputs = self.run_ort_session(inputs, session=session)
         ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
+        generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+        inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs)
+
         for _ in range(1, self.gen_len):
-            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
-            inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs)
+            finished_sequences |= inputs["input_ids"] == self.processor.tokenizer.eos_token_id
+            if finished_sequences.all():
+                break
             ort_outputs = self.run_ort_session(inputs, session)
             ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
-        generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs)
+
         generated_ids = np.concatenate(generated_ids, axis=1)
         predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         print("ORT Session Outputs:")
