QEfficient/transformers/models/modeling_auto.py (223 additions, 0 deletions)

@@ -1160,6 +1160,8 @@ def generate(
streamer: Optional[TextStreamer] = None,
device_ids: List[int] = None,
runtime_ai100: bool = True,
multi_frame_inference: bool = False,
num_frames: int = None,
generation_len: Optional[int] = None,
) -> Union[torch.Tensor, np.ndarray]:
"""
@@ -1196,10 +1198,229 @@ def generate(
if not runtime_ai100:
raise NotImplementedError("PyTorch execution is not supported yet for this model!")

if multi_frame_inference:
if not num_frames:
raise ValueError("For multi_frames_inference num_frames is required")
return self.multi_frame_generate(
inputs=inputs,
device_ids=device_ids,
streamer=streamer,
generation_len=generation_len,
num_frames=num_frames,
)

return self.kv_offload_generate(
inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len
)

def multi_frame_generate(
self,
inputs: Dict[str, Union[torch.Tensor, np.ndarray]] = None,
streamer: Optional[TextStreamer] = None,
device_ids: List[int] = None,
generation_len: int = None,
num_frames: int = None,
):
"""
Performs generation for multimodal models over multi-frame (video) inputs with KV offloading.

This method orchestrates inference by running the vision encoder once per frame (if compiled),
concatenating the per-frame vision embeddings, and then iteratively running the language decoder
while managing KV cache states.

Parameters
----------
inputs : Dict[str, Union[torch.Tensor, np.ndarray]]
Input tensors for the multimodal model.
streamer : TextStreamer, optional
A streamer object to display generated tokens in real-time. Default is None.
device_ids : List[int], optional
IDs of devices for running the QPC. Defaults to `[0]` if not specified.
generation_len : int, optional
The maximum number of tokens to generate. If None, it's inferred from `ctx_len`.
num_frames : int, optional
Number of frames in the vision inputs; `pixel_values` is split into this many equal chunks and
encoded frame by frame.

Returns
-------
CloudAI100ExecInfoNew
Execution information including generated IDs and performance metrics.

Raises
------
TypeError
If the language model QPC is not compiled.
AssertionError
If `generation_len` is not greater than zero.
"""
if not self.lang_model.qpc_path:
raise TypeError("Please run compile API for language model first!")

lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False)

if self.vision_model.qpc_path:
vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids)

batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path)

pad_token_id = 1

# Skip inputs/outputs
lang_session.skip_buffers(
[
x
for x in lang_session.input_names + lang_session.output_names
if x.startswith("past_") or x.endswith("_RetainedState")
]
)

# Read prompt and ctx len from session
batch_size = max(
[x[lang_session.binding_index_map["input_ids"]][1][0] for x in lang_session.allowed_shapes]
+ [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[0]]
)

prefill_seq_len = max(
[x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes]
+ [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]]
)

input_len = inputs["attention_mask"].sum(1, keepdims=True)
input_ids_length = inputs["input_ids"].shape[1]
num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float
padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len

if generation_len is None:
generation_len = ctx_len - input_len.max()
assert generation_len > 0, "generation length should be greater than zero"
generated_ids = np.full((batch_size, generation_len + 1), pad_token_id)

inputs["input_ids"] = torch.nn.functional.pad(
inputs["input_ids"],
(0, padded_len - input_ids_length),
"constant",
pad_token_id,
)
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0
)
if "cross_attention_mask" in inputs:
inputs["cross_attention_mask"] = torch.nn.functional.pad(
inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length)
)

for k, v in inputs.items():
inputs[k] = np.array(v)

vision_inputs = {
k: v
for k, v in inputs.items()
if k
in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"}
}

vision_inputs_fp16 = {"pixel_values", "image_masks"}
vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})

vision_start = perf_counter()

vision_outputs = {}
if vision_inputs:
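# Encode the vision inputs frame by frame: pixel_values is split into num_frames equal
# chunks, and the per-frame vision embeddings are concatenated along axis 1.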
vision_size = vision_inputs["pixel_values"].shape[0] // num_frames
chunk_inputs = vision_inputs.copy()
for i in range(num_frames):
chunk_inputs["pixel_values"] = vision_inputs["pixel_values"][i * vision_size : (i + 1) * vision_size]
chunk_outputs = vision_session.run(chunk_inputs)
if i == 0:
vision_outputs = chunk_outputs
else:
vision_outputs["vision_embeds"] = np.concatenate(
(vision_outputs["vision_embeds"], chunk_outputs["vision_embeds"]), axis=1
)
vision_end = perf_counter()

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

if "position_ids" in inputs:
lang_inputs["position_ids"] = inputs["position_ids"]
lang_inputs.pop("attention_mask")
else:
lang_inputs["position_ids"] = np.where(
lang_inputs.pop("attention_mask"), np.arange(padded_len), -1
) # Need to use -1 as position_ids for invalid tokens

not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama"
if not_mllama:
lang_inputs["image_idx"] = np.array([[0]])

if self.vision_model.qpc_path:
vision_session.deactivate()
lang_session.activate()

lang_session.set_buffers(vision_outputs)

# Prepare inputs for prefill
chunk_inputs = lang_inputs.copy()
prefill_start = perf_counter()

# Run prefill
for i in range(num_chunks):
chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
chunk_inputs["position_ids"] = lang_inputs["position_ids"][
..., i * prefill_seq_len : (i + 1) * prefill_seq_len
]
outputs = lang_session.run(chunk_inputs)
chunk_inputs["image_idx"] = outputs["image_idx_output"]

prefill_time = perf_counter() - prefill_start + vision_end - vision_start
# Skip inputs/outputs again
lang_session.skip_buffers(
[
x
for x in lang_session.input_names + lang_session.output_names
if x.startswith("past_") or x.endswith("_RetainedState")
]
)
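# Also drop the vision embedding buffers after prefill; re-sending these large inputs on
# every decode step significantly slows down decoding (see the review discussion below).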
if not_mllama:
lang_session.skip_buffers(vision_outputs.keys())

@quic-xiyushi commented on Oct 20, 2025:

Shouldn't we skip vision embedding buffers after the prefill stage? Otherwise, those large inputs will dramatically increase the decode time. I suggested to Chulhee that he add the following lines here:

Suggested change
if not_mllama:
lang_session.skip_buffers(vision_outputs.keys())

And with this change, he was able to increase the decode rate from 13 t/s to 20 t/s.

Could you incorporate this change into the PR, and possibly apply a similar update to the kv_offload_generate function as well? Thank you!

Contributor replied:

Thanks @quic-xiyushi, this is resolved. I have pushed the changes, please check.

# Get first token
lang_inputs["input_ids"] = outputs["logits"].argmax(2)
lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1
if "cross_attention_mask" in lang_inputs:
bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape
lang_inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64).numpy()
generated_ids[:, 0] = lang_inputs["input_ids"].squeeze(1)

if streamer:
streamer.put(lang_inputs["input_ids"][0])

# Decode loop
decode_start = perf_counter()
for num_token in range(1, generation_len):
outputs = lang_session.run(lang_inputs)

# Prepare inputs for next iteration
lang_inputs["input_ids"] = outputs["logits"].argmax(2)
lang_inputs["position_ids"] += 1
generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1)
if streamer:
streamer.put(lang_inputs["input_ids"][0])

decode_end = perf_counter()
if streamer:
streamer.end()

decode_perf = (num_token - 1) / (decode_end - decode_start)
total_time = decode_end - decode_start + prefill_time
total_perf = num_token / total_time

return CloudAI100ExecInfoNew(
batch_size=batch_size,
generated_ids=generated_ids,
perf_metrics=PerfMetrics(
prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time
),
)

def kv_offload_generate(
self,
inputs: List[str] = None,
@@ -1355,6 +1576,8 @@ def kv_offload_generate(
if x.startswith("past_") or x.endswith("_RetainedState")
]
)
if not_mllama:
lang_session.skip_buffers(vision_outputs.keys())

# Get first token
lang_inputs["input_ids"] = outputs["logits"].argmax(2)
@@ -775,6 +775,7 @@ def get_specializations(
img_size: None,
height: int = None,
width: int = None,
num_frames: int = None,
kv_offload: bool = False,
**compiler_options,
):
@@ -784,6 +785,8 @@
logger.warning(
"Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config"
)
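# Default to single-frame behaviour when num_frames is not provided.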
if not num_frames:
num_frames = 1
prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
channel = 3
@@ -845,6 +848,7 @@ def smart_resize(
grid_width = patch_size * patch_size * temporal_patch_size * channel
vision_size = grid_height // 4
grid_height = grid_height * batch_size
vision_size = vision_size * num_frames
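# The vision specialization scales linearly with the number of frames: each frame
# contributes vision_size rows of vision embeddings.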

vision = [
{
examples/qwen2.5vl/example_script.py (105 additions, 0 deletions)
@@ -0,0 +1,105 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import torch
import torch.nn.functional as F
import transformers
from qwen_vl_utils import process_vision_info
from transformers import AutoConfig, AutoProcessor, TextStreamer

from QEfficient import QEFFAutoModelForImageTextToText

## For AWQ models, update the PyTorch version to 2.8.*
model_id = "Qwen/Qwen2.5-VL-32B-Instruct"
config = AutoConfig.from_pretrained(model_id)

## Use the complete model without changing num_hidden_layers; a reduced layer count does not work with transformers 4.55.0 for the Qwen2.5-VL model

qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
model_id, attn_implementation="eager", kv_offload=True, config=config
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

batch_size = 1

## Define Num_frames ##
num_frames = 16

## Vision + Text ##
qeff_model.compile(
batch_size=batch_size,
prefill_seq_len=128,
ctx_len=16384,
num_cores=16,
num_devices=8,
height=910,
width=512,
num_frames=num_frames,
mxfp6_matmul=True,
mxint8_kv_cache=True,
aic_enable_depth_first=True,
mos=1,
)

content = []
for i in range(1, num_frames + 1):
frame = {
"type": "image",
"image": f"./frame_{i}.jpg",
}
content.append(frame)

content.append({"type": "text", "text": "Describe the video"})

messages = [
{
"role": "user",
"content": content,
}
]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

image_inputs, video_inputs = process_vision_info(messages)

inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)

input_ids_length = inputs["input_ids"].shape[1]

inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1)

pos_ids, rope_deltas = qeff_model.model.model.get_rope_index(
inputs["input_ids"],
inputs["image_grid_thw"],
video_grid_thw=None,
second_per_grid_ts=None,
attention_mask=inputs["attention_mask"],
)

inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0)

prefill_seq_len = 128
num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float
padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len

inputs["position_ids"] = F.pad(
inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
)
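# Pad with -1 so padded slots are treated as invalid positions, matching the position_ids
# convention used during prefill in multi_frame_generate.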

inputs.pop("image_grid_thw")
streamer = TextStreamer(tokenizer)
output = qeff_model.generate(inputs=inputs, generation_len=200, multi_frame_inference=True, num_frames=num_frames)
print(output.generated_ids)
print(tokenizer.batch_decode(output.generated_ids))
print(output)