diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 633a0b29d..784edf950 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1160,6 +1160,8 @@ def generate(
         streamer: Optional[TextStreamer] = None,
         device_ids: List[int] = None,
         runtime_ai100: bool = True,
+        multi_frame_inference: bool = False,
+        num_frames: int = None,
         generation_len: Optional[int] = None,
     ) -> Union[torch.Tensor, np.ndarray]:
         """
@@ -1196,10 +1198,229 @@ def generate(
         if not runtime_ai100:
             raise NotImplementedError("PyTorch execution is not supported yet for this model!")
 
+        if multi_frame_inference:
+            if not num_frames:
+                raise ValueError("num_frames is required when multi_frame_inference is enabled")
+            return self.multi_frame_generate(
+                inputs=inputs,
+                device_ids=device_ids,
+                streamer=streamer,
+                generation_len=generation_len,
+                num_frames=num_frames,
+            )
+
         return self.kv_offload_generate(
             inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len
         )
 
+    def multi_frame_generate(
+        self,
+        inputs: List[str] = None,
+        streamer: Optional[TextStreamer] = None,
+        device_ids: List[int] = None,
+        generation_len: int = None,
+        num_frames: int = None,
+    ):
+        """
+        Performs multi-frame generation for multimodal models with KV offloading to CPU.
+
+        This method orchestrates the inference by running the vision encoder once per frame
+        chunk (if compiled) and then iteratively running the language decoder, managing KV
+        cache states.
+
+        Parameters
+        ----------
+        inputs : Dict[str, Union[torch.Tensor, np.ndarray]]
+            Input tensors for the multimodal model.
+        streamer : TextStreamer, optional
+            A streamer object to display generated tokens in real-time. Default is None.
+        device_ids : List[int], optional
+            IDs of devices for running the QPC. Defaults to `[0]` if not specified.
+        generation_len : int, optional
+            The maximum number of tokens to generate. If None, it is inferred from `ctx_len`.
+        num_frames : int, optional
+            Number of frames the vision inputs are split into; the vision encoder is run once
+            per frame chunk and the resulting embeddings are concatenated.
+
+        Returns
+        -------
+        CloudAI100ExecInfoNew
+            Execution information including generated IDs and performance metrics.
+
+        Raises
+        ------
+        TypeError
+            If the language model QPC is not compiled.
+        AssertionError
+            If `generation_len` is not greater than zero.
+ """ + if not self.lang_model.qpc_path: + raise TypeError("Please run compile API for language model first!") + + lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) + + if self.vision_model.qpc_path: + vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) + + batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) + + pad_token_id = 1 + + # Skip inputs/outputs + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + # Read prompt and ctx len from session + batch_size = max( + [x[lang_session.binding_index_map["input_ids"]][1][0] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[0]] + ) + + prefill_seq_len = max( + [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]] + ) + + input_len = inputs["attention_mask"].sum(1, keepdims=True) + input_ids_length = inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + if generation_len is None: + generation_len = ctx_len - input_len.max() + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), pad_token_id) + + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + for k, v in inputs.items(): + inputs[k] = np.array(v) + + vision_inputs = { + k: v + for k, v in inputs.items() + if k + in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} + } + + vision_inputs_fp16 = {"pixel_values", "image_masks"} + vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) + + vision_start = perf_counter() + + vision_outputs = {} + if vision_inputs: + vision_size = vision_inputs["pixel_values"].shape[0] // num_frames + chunk_inputs = vision_inputs.copy() + for i in range(num_frames): + chunk_inputs["pixel_values"] = vision_inputs["pixel_values"][i * vision_size : (i + 1) * vision_size] + chunk_outputs = vision_session.run(chunk_inputs) + if i == 0: + vision_outputs = chunk_outputs + else: + vision_outputs["vision_embeds"] = np.concatenate( + (vision_outputs["vision_embeds"], chunk_outputs["vision_embeds"]), axis=1 + ) + vision_end = perf_counter() + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + if "position_ids" in inputs: + lang_inputs["position_ids"] = inputs["position_ids"] + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + + not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" + if not_mllama: + lang_inputs["image_idx"] = 
+
+        if self.vision_model.qpc_path:
+            vision_session.deactivate()
+        lang_session.activate()
+
+        lang_session.set_buffers(vision_outputs)
+
+        # Prepare inputs for prefill
+        chunk_inputs = lang_inputs.copy()
+        prefill_start = perf_counter()
+
+        # Run prefill
+        chunk_inputs = lang_inputs.copy()
+        for i in range(num_chunks):
+            chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
+            chunk_inputs["position_ids"] = lang_inputs["position_ids"][
+                ..., i * prefill_seq_len : (i + 1) * prefill_seq_len
+            ]
+            outputs = lang_session.run(chunk_inputs)
+            chunk_inputs["image_idx"] = outputs["image_idx_output"]
+
+        prefill_time = perf_counter() - prefill_start + vision_end - vision_start
+        # Skip inputs/outputs again
+        lang_session.skip_buffers(
+            [
+                x
+                for x in lang_session.input_names + lang_session.output_names
+                if x.startswith("past_") or x.endswith("_RetainedState")
+            ]
+        )
+        if not_mllama:
+            lang_session.skip_buffers(vision_outputs.keys())
+
+        # Get first token
+        lang_inputs["input_ids"] = outputs["logits"].argmax(2)
+        lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1
+        if "cross_attention_mask" in lang_inputs:
+            bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape
+            lang_inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64).numpy()
+        generated_ids[:, 0] = lang_inputs["input_ids"].squeeze(1)
+
+        if streamer:
+            streamer.put(lang_inputs["input_ids"][0])
+
+        # Decode loop
+        decode_start = perf_counter()
+        for num_token in range(1, generation_len):
+            outputs = lang_session.run(lang_inputs)
+
+            # Prepare inputs for next iteration
+            lang_inputs["input_ids"] = outputs["logits"].argmax(2)
+            lang_inputs["position_ids"] += 1
+            generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1)
+            if streamer:
+                streamer.put(lang_inputs["input_ids"][0])
+
+        decode_end = perf_counter()
+        if streamer:
+            streamer.end()
+
+        decode_perf = (num_token - 1) / (decode_end - decode_start)
+        total_time = decode_end - decode_start + prefill_time
+        total_perf = num_token / total_time
+
+        return CloudAI100ExecInfoNew(
+            batch_size=batch_size,
+            generated_ids=generated_ids,
+            perf_metrics=PerfMetrics(
+                prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time
+            ),
+        )
+
     def kv_offload_generate(
         self,
         inputs: List[str] = None,
@@ -1355,6 +1576,8 @@ def kv_offload_generate(
                 if x.startswith("past_") or x.endswith("_RetainedState")
             ]
         )
+        if not_mllama:
+            lang_session.skip_buffers(vision_outputs.keys())
 
         # Get first token
         lang_inputs["input_ids"] = outputs["logits"].argmax(2)
diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 030dd7a56..068b2b798 100644
--- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -775,6 +775,7 @@ def get_specializations(
         img_size: None,
         height: int = None,
         width: int = None,
+        num_frames: int = None,
         kv_offload: bool = False,
         **compiler_options,
     ):
@@ -784,6 +785,8 @@ def get_specializations(
             logger.warning(
                 "Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config"
             )
+        if not num_frames:
+            num_frames = 1
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
         channel = 3
@@ -845,6 +848,7 @@ def smart_resize(
         grid_width = patch_size * patch_size * temporal_patch_size * channel
         vision_size = grid_height // 4
         grid_height = grid_height * batch_size
+        vision_size = vision_size * num_frames
 
         vision = [
             {
diff --git a/examples/qwen2.5vl/example_script.py b/examples/qwen2.5vl/example_script.py
new file mode 100644
index 000000000..4e3dee09b
--- /dev/null
+++ b/examples/qwen2.5vl/example_script.py
@@ -0,0 +1,105 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import torch.nn.functional as F
+import transformers
+from qwen_vl_utils import process_vision_info
+from transformers import AutoConfig, AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+## For AWQ models, update the PyTorch version to 2.8.*
+model_id = "Qwen/Qwen2.5-VL-32B-Instruct"
+config = AutoConfig.from_pretrained(model_id)
+
+## Use the complete model without changing num_hidden_layers, as a reduced layer count does not work with transformers 4.55.0 for the Qwen2.5-VL model
+
+qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+    model_id, attn_implementation="eager", kv_offload=True, config=config
+)
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+processor = AutoProcessor.from_pretrained(model_id)
+
+batch_size = 1
+
+## Define num_frames ##
+num_frames = 16
+
+## Vision + Text ##
+qeff_model.compile(
+    batch_size=batch_size,
+    prefill_seq_len=128,
+    ctx_len=16384,
+    num_cores=16,
+    num_devices=8,
+    height=910,
+    width=512,
+    num_frames=num_frames,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+    aic_enable_depth_first=True,
+    mos=1,
+)
+
+content = []
+for i in range(1, num_frames + 1):
+    frame = {
+        "type": "image",
+        "image": f"./frame_{i}.jpg",
+    }
+    content.append(frame)
+
+content.append({"type": "text", "text": "Describe the video"})
+
+messages = [
+    {
+        "role": "user",
+        "content": content,
+    }
+]
+
+text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+image_inputs, video_inputs = process_vision_info(messages)
+
+inputs = processor(
+    text=[text],
+    images=image_inputs,
+    videos=video_inputs,
+    padding=True,
+    return_tensors="pt",
+)
+
+input_ids_length = inputs["input_ids"].shape[1]
+
+inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1)
+
+pos_ids, rope_deltas = qeff_model.model.model.get_rope_index(
+    inputs["input_ids"],
+    inputs["image_grid_thw"],
+    video_grid_thw=None,
+    second_per_grid_ts=None,
+    attention_mask=inputs["attention_mask"],
+)
+
+inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0)
+
+prefill_seq_len = 128
+num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
+padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+
+inputs["position_ids"] = F.pad(
+    inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
+)
+
+inputs.pop("image_grid_thw")
+streamer = TextStreamer(tokenizer)
+output = qeff_model.generate(inputs=inputs, generation_len=200, multi_frame_inference=True, num_frames=num_frames)
+print(output.generated_ids)
+print(tokenizer.batch_decode(output.generated_ids))
+print(output)
diff --git a/examples/qwen2.5vl/video_to_frames.py b/examples/qwen2.5vl/video_to_frames.py
new file mode 100644
index 000000000..7ed119482
--- /dev/null
+++ b/examples/qwen2.5vl/video_to_frames.py
@@ -0,0 +1,42 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import cv2
+
+video_path = "video_path"
+output_size = (910, 512)  # Replace with your desired (x, y) dimensions
+
+cap = cv2.VideoCapture(video_path)
+
+fps = cap.get(cv2.CAP_PROP_FPS)
+total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+duration = total_frames / fps
+
+# Calculate frame indices to extract
+num_frames_to_extract = 16
+interval = int(total_frames / num_frames_to_extract)
+frame_indices = [i * interval for i in range(num_frames_to_extract)]
+
+# === Extract and resize frames ===
+resized_frames = []
+
+for idx in frame_indices:
+    cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+    ret, frame = cap.read()
+    if ret:
+        resized = cv2.resize(frame, output_size)
+        resized_frames.append(resized)
+    else:
+        print(f"Failed to read frame at index {idx}")
+
+cap.release()
+
+## Save frames ##
+for i, frame in enumerate(resized_frames):
+    cv2.imwrite(f"frame_{i + 1}.jpg", frame)
+
+print(f"Extracted and resized {len(resized_frames)} frames to size {output_size}.")
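
The core of the new multi_frame_generate path is splitting the flattened vision patches into num_frames equal chunks, running the vision encoder once per chunk, and concatenating the resulting embeddings along the sequence axis before prefill. A minimal standalone sketch of that pattern follows; fake_vision_encoder, the tensor shapes, and num_frames=4 are illustrative stand-ins for the compiled vision QPC session and real pixel_values, not part of this patch.

import numpy as np


# Hypothetical stand-in for the compiled vision QPC session; in the patch above this role
# is played by vision_session.run(chunk_inputs), which returns {"vision_embeds": ...}.
def fake_vision_encoder(pixel_values: np.ndarray) -> dict:
    num_patches, embed_dim = pixel_values.shape[0], 8
    return {"vision_embeds": np.zeros((1, num_patches, embed_dim), dtype=np.float16)}


num_frames = 4  # illustrative value
pixel_values = np.random.rand(1024, 1176).astype(np.float16)  # illustrative patch grid covering all frames

# Each frame contributes an equal slice of the flattened patch grid.
vision_size = pixel_values.shape[0] // num_frames

vision_outputs = {}
for i in range(num_frames):
    chunk = pixel_values[i * vision_size : (i + 1) * vision_size]
    chunk_outputs = fake_vision_encoder(chunk)
    if i == 0:
        vision_outputs = chunk_outputs
    else:
        # Embeddings from successive frames are concatenated along the sequence axis (axis=1)
        # so the language decoder sees one contiguous vision prefix.
        vision_outputs["vision_embeds"] = np.concatenate(
            (vision_outputs["vision_embeds"], chunk_outputs["vision_embeds"]), axis=1
        )

print(vision_outputs["vision_embeds"].shape)  # (1, 1024, 8)

This mirrors the vision_size computation and np.concatenate call in multi_frame_generate; the actual embedding dimensions come from the compiled vision QPC, and the per-frame chunking is why compile() needs the matching num_frames specialization added in get_specializations.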