From 26243cc428f41c4d69c97da48567f306977033e2 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Wed, 5 Nov 2025 09:21:36 +0000 Subject: [PATCH 1/5] Continuous Batching for VLMs Signed-off-by: Asmita Goswami --- .../models/gemma3/modeling_gemma3.py | 93 +++++++++++++----- .../models/llava/modeling_llava.py | 95 ++++++++++++++----- 2 files changed, 136 insertions(+), 52 deletions(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 20b7036fd..1dbce208b 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -592,7 +592,15 @@ def __init__(self, model): self.config = self.model.config self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_index @@ -603,7 +611,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -648,6 +660,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): prefill_seq_len = prefill_seq_len if prefill_seq_len else 32 @@ -667,24 +682,39 @@ def get_specializations( "ctx_len": ctx_len, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - }, - ] + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + "vision_batch_size": batch_size, + } + if 
continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) + specializations = {} if kv_offload: @@ -692,19 +722,23 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"} + lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "mm_tokens_per_image"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} - pkv_dynamic_sliding_axes = {0: "batch_size", 2: "sliding_window"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} + pkv_dynamic_sliding_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"} layer_switch = ( self.language_model.config.sliding_window_pattern if hasattr(self.language_model.config, "sliding_window_pattern") @@ -767,7 +801,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: @@ -806,13 +840,20 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV lang_inputs["past_key_values"] = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index e260beb05..f3c4304e8 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import Optional + import torch import torch.nn as nn import torch.utils.checkpoint @@ -16,6 +18,7 @@ from QEfficient.utils.logging_utils import logger BS = 1 +FBS = 4 NUM_CHANNEL = 3 SEQ_LEN = 592 CTX_LEN = 1024 @@ -51,7 +54,15 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, 
past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -65,6 +76,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + batch_index=batch_index, return_dict=True, ) @@ -120,7 +132,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -145,11 +157,13 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for i in range(num_layers): lang_inputs["past_key_values"].append( ( - torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim), - torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim), + torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim), + torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim), ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(BS).view(BS, 1) inputs = {} if kv_offload: @@ -167,6 +181,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_images = compiler_options.pop("max_num_images", 1) @@ -187,24 +204,40 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + 
lang.append(lang_decode) + specializations = {} if kv_offload: @@ -212,9 +245,11 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -224,11 +259,19 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } dynamic_axes = {} if kv_offload: From 21b18d7febf8b521400718e26be3a41c8acdc010 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Mon, 10 Nov 2025 12:30:54 +0000 Subject: [PATCH 2/5] Added CB support for InternVL Signed-off-by: Asmita Goswami --- QEfficient/generation/embedding_handler.py | 85 +++++++++++++++- QEfficient/generation/vlm_generation.py | 8 ++ .../models/internvl/modeling_internvl.py | 89 ++++++++++++----- .../transformers/models/modeling_auto.py | 10 +- examples/internvl_CB_example.py | 98 +++++++++++++++++++ 5 files changed, 262 insertions(+), 28 deletions(-) create mode 100644 examples/internvl_CB_example.py diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index 76da7afc2..f18e84179 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -12,13 +12,14 @@ operations, separating them from the main text generation logic. 
""" -from typing import Any, Dict, Optional, Tuple +from io import BytesIO +from typing import Any, Dict, List, Optional, Tuple import numpy as np import requests import torch from PIL import Image -from transformers import AutoImageProcessor +from transformers import AutoImageProcessor, AutoTokenizer from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils.logging_utils import logger @@ -37,6 +38,9 @@ def __init__( qeff_model: Optional[QAICInferenceSession], vision_session: Optional[QAICInferenceSession], processor: Optional[AutoImageProcessor], + tokenizer: Optional[AutoTokenizer], + image_height: Optional[int] = None, + image_width: Optional[int] = None, config: Optional[Dict[str, Any]] = None, lang_session: Optional[QAICInferenceSession] = None, ): @@ -46,12 +50,16 @@ def __init__( Args: vision_session: QAICInferenceSession for vision model processor: AutoImageProcessor for image preprocessing + tokenizer: AutoTokenizer for text tokenization config: Configuration dictionary with vision model parameters lang_session: Optional language session for coordination (to avoid resource conflicts) """ self._qeff_model = qeff_model self._vision_session = vision_session self._processor = processor + self._tokenizer = tokenizer + self._image_height = image_height + self._image_width = image_width self._config = config or {} self._lang_session = lang_session # Store language session for coordination @@ -70,6 +78,71 @@ def is_available(self) -> bool: """ return self._vision_session is not None and self._processor is not None + def prepare_internVL_inputs(self, img_url: str, query: str) -> Dict[str, np.ndarray]: + """ + Prepare inputs for InternVL model + + Args: + image_url: URL or path to image + query: Text query to process with image + prompt = [query] + """ + if not self._tokenizer: + raise ValueError("Tokenizer is required for InternVL input preparation") + prompt = query + pixel_values = [] + num_patches_list = [] + questions = [] + img = requests.get(img_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + + if self._image_height and self._image_width: + image = image.resize((self._image_height, self._image_width)) + else: + logger.warning("Height and Width not specified. 
Using default image size for num_patches = 13.") + image = image.resize((1000, 747)) + + # preprocess the resized image + pixel_value = self._processor.load_image(image, max_num=12) + num_patches_list.append(pixel_value.shape[0]) + pixel_values.append(pixel_value) + + question = "\n" + prompt + questions.append(question) + + pixel_values = torch.cat(pixel_values, dim=0) + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list) + + inputs = self._tokenizer(prompt, return_tensors="pt") + inputs["pixel_values"] = pixel_values.clone() + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs + def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]: """ Download and preprocess image into model inputs @@ -323,7 +396,13 @@ def get_processed_inputs( try: ## Get vlm inputs ## - vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len) + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "internvl_chat" + ): + vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query) + else: + vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len) # Handle padding for language model pad_token_id = 1 diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 2e8f04f2b..b00556ab9 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -86,6 +86,8 @@ def __init__( enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, full_batch_size: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, @@ -139,6 +141,9 @@ def __init__( ) self.qeff_model = qeff_model self.processor = processor + self.tokenizer = tokenizer + self.image_height = image_height + self.image_width = image_width self._vision_qpc_path = vision_qpc_path self.device_id = device_id # Store device_id for vision components self.enable_debug_logs = enable_debug_logs # Store for vision components @@ -169,6 +174,9 @@ def _init_vision_components(self): qeff_model=self.qeff_model, vision_session=self._vision_session, processor=self.processor, + tokenizer=self.tokenizer, + image_height=self.image_height, + image_width=self.image_width, config=vision_config, lang_session=self._session, # Pass language session for coordination ) diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 38d0fe167..422d6a2b9 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -5,6 +5,8 @@ # # 
----------------------------------------------------------------------------- +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F @@ -34,7 +36,15 @@ def __init__(self, model): self.config = self.model.language_model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape image_input_embeds = input_embeds.reshape(B * N, C) @@ -55,7 +65,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return outputs.logits, vision_embeds, image_idx, outputs.past_key_values @@ -75,6 +89,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): num_patches = compiler_options.pop("num_patches", None) @@ -104,24 +121,38 @@ def get_specializations( "batched_num_patches": batch_size * num_patches, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -130,18 +161,22 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} 
lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {1: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} for i in range(self.language_model.config.num_hidden_layers): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes @@ -173,7 +208,7 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -222,10 +257,13 @@ def get_dummy_inputs(self, kv_offload: bool = False): ) lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64) + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -234,6 +272,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 60f60c768..0e4409445 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1190,6 +1190,8 @@ def generate( device_ids: List[int] = None, runtime_ai100: bool = True, generation_len: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, ) -> Union[torch.Tensor, np.ndarray]: """ Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. @@ -1246,6 +1248,8 @@ def generate( device_id=device_ids, # if device_ids is not None else [0], ctx_len=ctx_len_comp, full_batch_size=fbs, + image_height=image_height, + image_width=image_width, ) # Call generate method @@ -2273,7 +2277,11 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, + kv_offload=kv_offload, + continuous_batching=continuous_batching, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, ) return cls( model, diff --git a/examples/internvl_CB_example.py b/examples/internvl_CB_example.py new file mode 100644 index 000000000..486f9db6c --- /dev/null +++ b/examples/internvl_CB_example.py @@ -0,0 +1,98 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.test_utils import InternProcessor + +model_id = "OpenGVLab/InternVL2_5-1B" +config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) +# For Testing Purpose Only +config.llm_config.num_hidden_layers = 2 +config.vision_config.num_hidden_layers = 2 + +model_hf = AutoModelForCausalLM.from_pretrained( + model_id, + low_cpu_mem_usage=False, + trust_remote_code=True, + config=config, +) + +tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False) +processor = InternProcessor(model_hf, tokenizer) + + +continuous_batching = True +if continuous_batching: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + trust_remote_code=True, + ) + + qeff_model.compile( + num_patches=13, # Set num_patches according to image_height and image_width, default is 13 (747 x 1000) + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + batch_size=1, + full_batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) +else: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config, trust_remote_code=True + ) + + qeff_model.compile( + num_patches=13, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + ) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + device_ids=[0, 1, 2, 3], + generation_len=10, + image_height=747, + image_width=1000, +) + +print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) From 9bf2b07cad90a468f8c811aaadad0afdc98917ee Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Mon, 10 Nov 2025 12:42:23 +0000 Subject: [PATCH 3/5] Added CB support for Mistral3 Signed-off-by: Asmita Goswami --- QEfficient/generation/embedding_handler.py | 3 + .../models/mistral3/modeling_mistral3.py | 88 ++++++++++++++----- examples/internvl_CB_example.py | 2 +- 3 files changed, 69 insertions(+), 24 deletions(-) diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index f18e84179..d196a23a2 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -168,6 +168,9 @@ def prepare_vlm_inputs(self, image_url: str, query: str, 
prefill_seq_len: int) - else: image = Image.open(image_url) + if "mistral3" in self._qeff_model.model.config.model_type: + image = image.resize((1540, 1540)) + # Prepare conversation format conversation = [ { diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 735eec9e5..62219a71d 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -166,7 +166,15 @@ def __init__(self, model): self.config = self.model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -179,6 +187,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds_1, position_ids=position_ids, past_key_values=past_key_values, + batch_index=batch_index, ) # Cast to int32 to avoid ONNXRT issue @@ -230,7 +239,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size @@ -270,10 +279,14 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( - config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -282,6 +295,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -299,6 +315,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): if img_size is None and hasattr(self.config.vision_config, "image_size"): @@ -323,22 +342,36 @@ def get_specializations( "vision_size": vision_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "image_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "image_size": img_size, - "vision_size": vision_size, - }, - ] + lang_prefill = { + "batch_size": 
1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -351,7 +384,7 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -364,9 +397,18 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): "vision_embeds": {0: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } dynamic_axes = {} if kv_offload: diff --git a/examples/internvl_CB_example.py b/examples/internvl_CB_example.py index 486f9db6c..29cb9a5c4 100644 --- a/examples/internvl_CB_example.py +++ b/examples/internvl_CB_example.py @@ -45,7 +45,7 @@ num_cores=16, num_devices=4, batch_size=1, - full_batch_size=1, + full_batch_size=4, mxfp6_matmul=True, mxint8_kv_cache=True, aic_enable_depth_first=True, From c80a19a0d56c052dd3f6a4ca5be26012ba10e88c Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Tue, 11 Nov 2025 08:15:11 +0000 Subject: [PATCH 4/5] Updated test_image_text_to_text for CB tests Signed-off-by: Asmita Goswami --- QEfficient/utils/run_utils.py | 48 +++++++++++++ .../models/test_image_text_to_text_models.py | 67 ++++++++++++++++++- 2 files changed, 112 insertions(+), 3 deletions(-) diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index c54dadeac..0f82fb027 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -276,6 +276,54 @@ def __init__( self.config = config self.gen_len = max_gen_len + @torch.no_grad() + def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries): + """ + Function responsible for running HuggingFace ``PyTorch`` model for continuous batching + and return the output tokens for each prompt/image pair. 
+ + ``Mandatory`` Args: + :model (torch.nn.module): Original ``PyTorch`` model + :images (List[PIL.Image]): List of input images + :queries (List[str]): List of input queries + + Return: + :List[numpy.ndarray]: List of generated output tokens for each prompt + """ + generated_ids = [] + + for idx, (image, query) in enumerate(zip(images, queries)): + # Prepare conversation format for each image-query pair + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + + # Process inputs + inputs = self.processor(images=image, text=prompt, return_tensors="pt") + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Generate tokens + output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False) + offset_output = output[0, inputs["input_ids"].shape[1]:] + + # Decode and print output + py_output = self.processor.tokenizer.decode(offset_output).strip() + print(f"Original HF Model Outputs (Torch CPU) for prompt {idx}:") + print("Query:", repr(query)) + print("Completion:", repr(py_output)) + + generated_ids.append(offset_output.numpy()) + + return generated_ids + @torch.no_grad() def run_vlm_hf_model_on_pytorch(self, model, inputs): output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False) diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py index e6a145195..5d095fe87 100644 --- a/tests/transformers/models/test_image_text_to_text_models.py +++ b/tests/transformers/models/test_image_text_to_text_models.py @@ -38,6 +38,7 @@ # model_name, # kv_offload, # batch_size, + # full_batch_size, # prompt_len, # ctx_len, # img_size, @@ -49,6 +50,7 @@ "llava-hf/llava-1.5-7b-hf", True, 1, + 4, 784, 1024, 336, @@ -60,6 +62,7 @@ "llava-hf/llava-1.5-7b-hf", False, 1, + 4, 784, 1024, 336, @@ -72,6 +75,7 @@ # "meta-llama/Llama-4-Scout-17B-16E-Instruct", # True, # 1, + # 4, # 128, # 3072, # 336, @@ -83,6 +87,7 @@ # "meta-llama/Llama-4-Scout-17B-16E-Instruct", # False, # 1, + # 4, # 128, # 3072, # 336, @@ -94,6 +99,7 @@ "google/gemma-3-4b-it", True, 1, + 4, 128, 3072, 896, @@ -105,6 +111,7 @@ "google/gemma-3-4b-it", False, 1, + 4, 128, 3072, 896, @@ -116,6 +123,7 @@ "mistralai/Mistral-Small-3.1-24B-Instruct-2503", True, 1, + 4, 128, 4096, 1540, @@ -127,6 +135,7 @@ "mistralai/Mistral-Small-3.1-24B-Instruct-2503", False, 1, + 4, 128, 4096, 1540, @@ -138,6 +147,7 @@ "Qwen/Qwen2.5-VL-3B-Instruct", True, 1, + 4, 128, 4096, 1540, @@ -149,6 +159,7 @@ # "meta-llama/Llama-3.2-11B-Vision-Instruct", # True, # 1, + # 4, # 32, # 512, # 560, @@ -256,6 +267,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( query: str, prompt_len: int, ctx_len: int, + full_batch_size: int, max_gen_len: int = 20, batch_size: int = 1, n_layer: int = 1, @@ -341,8 +353,56 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( output = qeff_model.generate(inputs=inputs, generation_len=NEW_GENERATION_TOKENS, streamer=streamer) qpc_tokens = output.generated_ids[:, :-1] assert (pytorch_hf_tokens == qpc_tokens).all(), "Tokens don't match for pytorch HF output and QPC output" - return + # testing for CB models + if not kv_offload: # CB not yet enabled for Single QPC + return + images = [image] * full_batch_size + queries = [query] * full_batch_size + + streamer = TextStreamer(processor.tokenizer) + pytorch_hf_tokens = 
api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, images, queries) + + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_config["model_name"], + kv_offload=kv_offload, + config=config, + continuous_batching=True, + ) + + qeff_model.export() + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + qeff_model.compile( + img_size=model_config["img_size"], + num_cores=16, + num_devices=num_devices, + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + batch_size=batch_size, + full_batch_size=full_batch_size, + mxfp6_matmul=True, + enable_qnn=enable_qnn, + qnn_config=qnn_config, + ) + + print("QPC Outputs (QAIC):") + exec_info = qeff_model.generate( + tokenizer=processor.tokenizer, + processor=processor, + images=[img_url] * full_batch_size, + prompts=queries, + generation_len=max_gen_len, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), f"Tokens don't match for prompt {i} between HF and QPC output" + + return def check_molmo_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name: str, @@ -527,10 +587,10 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.on_qaic @pytest.mark.multimodal @pytest.mark.parametrize( - "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config + "model_name, kv_offload, batch_size, full_batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config ) def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( - model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer + model_name, kv_offload, batch_size, full_batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer ): """ Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching. 
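For reference, the continuous-batching flow that the updated test exercises corresponds to roughly the following standalone usage, mirroring examples/internvl_CB_example.py added earlier in this series. This is a minimal sketch assembled from the patch's own test and example code; the llava checkpoint, device count, sequence lengths, image URL, and prompts are illustrative assumptions rather than values fixed by this change.

from transformers import AutoProcessor

from QEfficient import QEFFAutoModelForImageTextToText

# Assumed checkpoint: any VLM from the CB-enabled test matrix is driven the same way.
model_id = "llava-hf/llava-1.5-7b-hf"
full_batch_size = 4

processor = AutoProcessor.from_pretrained(model_id)

# Continuous batching is currently exercised only for the dual-QPC (kv_offload=True) path.
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
    model_id,
    kv_offload=True,
    continuous_batching=True,
)

# With continuous batching, the prefill specialization uses batch_size=1 while the KV cache
# is allocated for full_batch_size sequences; decode then runs with batch_size=full_batch_size.
qeff_model.compile(
    img_size=336,
    num_cores=16,
    num_devices=1,
    prefill_seq_len=784,
    ctx_len=1024,
    batch_size=1,
    full_batch_size=full_batch_size,
    mxfp6_matmul=True,
)

image_urls = [
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
] * full_batch_size
prompts = ["Can you describe the image in detail?"] * full_batch_size

# generate() accepts lists of image URLs and prompts and schedules them across the
# full_batch_size decode slots, as in the updated test above.
exec_info = qeff_model.generate(
    tokenizer=processor.tokenizer,
    processor=processor,
    images=image_urls,
    prompts=prompts,
    generation_len=20,
)
print(exec_info.generated_texts)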
@@ -547,6 +607,7 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( query=query, n_layer=n_layer, batch_size=batch_size, + full_batch_size=full_batch_size, kv_offload=kv_offload, ) From b89ea66032047e8ab891dbd471691d52e2c055af Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Tue, 11 Nov 2025 08:18:34 +0000 Subject: [PATCH 5/5] Ruff format Signed-off-by: Asmita Goswami --- QEfficient/utils/run_utils.py | 2 +- .../models/test_image_text_to_text_models.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index 0f82fb027..59e3f9bf4 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -312,7 +312,7 @@ def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries): # Generate tokens output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False) - offset_output = output[0, inputs["input_ids"].shape[1]:] + offset_output = output[0, inputs["input_ids"].shape[1] :] # Decode and print output py_output = self.processor.tokenizer.decode(offset_output).strip() diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py index 5d095fe87..11fcf6857 100644 --- a/tests/transformers/models/test_image_text_to_text_models.py +++ b/tests/transformers/models/test_image_text_to_text_models.py @@ -355,7 +355,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( assert (pytorch_hf_tokens == qpc_tokens).all(), "Tokens don't match for pytorch HF output and QPC output" # testing for CB models - if not kv_offload: # CB not yet enabled for Single QPC + if not kv_offload: # CB not yet enabled for Single QPC return images = [image] * full_batch_size queries = [query] * full_batch_size @@ -400,10 +400,13 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( qpc_tokens = exec_info.generated_ids[:, :max_gen_len] for i in range(full_batch_size): - assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), f"Tokens don't match for prompt {i} between HF and QPC output" + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output" + ) return + def check_molmo_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name: str, img_url: str, @@ -587,7 +590,8 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.on_qaic @pytest.mark.multimodal @pytest.mark.parametrize( - "model_name, kv_offload, batch_size, full_batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config + "model_name, kv_offload, batch_size, full_batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", + test_models_config, ) def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name, kv_offload, batch_size, full_batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer