diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 33c6f5588..3d47945d6 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -11,6 +11,7 @@ import QEfficient.utils.model_registery # noqa: F401 from QEfficient.utils import custom_format_warning from QEfficient.utils.logging_utils import logger +from QEfficient.utils.patches import apply_torch_patches, is_patched # For faster downloads via hf_transfer # This code is put above import statements as this needs to be executed before @@ -22,6 +23,10 @@ # custom warning for the better logging experience warnings.formatwarning = custom_format_warning +# Apply patches +# TODO: Find a better way to do this, this is temp. fix. +apply_torch_patches() + def check_qaic_sdk(): """Check if QAIC SDK is installed""" @@ -70,6 +75,10 @@ def check_qaic_sdk(): "QEFFAutoModelForImageTextToText", "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", + "apply_torch_patches", + "is_patched", + "apply_torch_patches", + "is_patched", ] else: diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6ecbf0fc0..2637498fb 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -18,10 +18,13 @@ import onnx import torch -from QEfficient.base.onnx_transforms import OnnxTransform +from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile +from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc +from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export from QEfficient.utils import ( constants, create_json, @@ -243,7 +246,13 @@ def _export( input_names.append(param) try: + # Initialize the registry with your custom ops + CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm) + CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter) + CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather) + decoder_layer_classes = get_decoder_layer_classes_for_export(self.model) export_kwargs = {} if export_kwargs is None else export_kwargs + torch.onnx.export( self.model, (example_inputs,), @@ -252,15 +261,18 @@ def _export( output_names=output_names, dynamic_axes=dynamic_axes, opset_version=constants.ONNX_EXPORT_OPSET, + export_modules_as_functions=decoder_layer_classes, + do_constant_folding=True, **export_kwargs, ) logger.info("PyTorch export successful") _ = self._offload_model_weights(offload_pt_weights) - model = onnx.load(tmp_onnx_path, load_external_data=False) + transform_kwargs = { "onnx_base_dir": str(tmp_onnx_dir), + "temp_onnx_path": tmp_onnx_path, "model_name": self.model_name, } if onnx_transform_kwargs is not None: @@ -273,7 +285,6 @@ def _export( onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names())) ) logger.info("ONNX transforms applied") - onnx.save(model, onnx_path) logger.info("Transformed ONNX saved") diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 61b5c00f6..cdf1913ab 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -5,9 +5,12 @@ # # 
---------------------------------------------------------------------------- -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np +import onnx +import onnxslim +import torch from onnx import ModelProto, external_data_helper, numpy_helper @@ -99,3 +102,165 @@ def apply( current_file_size = tsize external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") return model, transformed + + +class OnnxSlimTransform(OnnxTransform): + """ + Applies onnx-slim transformations on the given ONNX graph. + """ + + @classmethod + def apply( + cls, + model: ModelProto, + *, + onnx_base_dir: Optional[str] = None, + **kwargs, + ) -> Tuple[ModelProto, bool]: + """ + :param enable_onnx_slim_transform: If True, applies onnx-slim transformations. + :param temp_onnx_path: Path to save the slimmed ONNX model. + """ + transformed = False + onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False) + temp_onnx_path = kwargs.get("temp_onnx_path", None) + if not temp_onnx_path: + err_str = "temp_onnx_path is required for onnx-slim transform." + raise RuntimeError(err_str) + if onnx_slim_transform: + transformed = True + slimmed_model = onnxslim.slim(model) + onnx.save(slimmed_model, temp_onnx_path) + return slimmed_model, transformed + return model, transformed + + +class CustomOpTransform(OnnxTransform): + """ + Transform to register custom operations and add their function protos to the ONNX model. + """ + + # Registry of custom operations + _custom_ops: Dict[str, Tuple[Any, Any]] = {} # op_name -> (func_class, onnxscript_func) + + @classmethod + def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any): + """Register a custom operation.""" + cls._custom_ops[op_name] = (func_class, onnxscript_func) + + @classmethod + def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]: + """ + Apply custom op registration and add function protos to the model. 
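+
+        The registry is expected to be populated beforehand via ``register_custom_op``.
+        An illustrative usage (mirroring how the export path wires this up; the op names
+        are taken from the surrounding changes):
+
+            CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
+            model, transformed = CustomOpTransform.apply(model, opset_version=17)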
+ + :param model: The ONNX model to transform + :param opset_version: ONNX opset version for symbolic registration + :returns: Transformed model and success flag + """ + transformed = False + + # Register all custom op symbolic functions with torch.onnx + for op_name, (func_class, _) in cls._custom_ops.items(): + if hasattr(func_class, "symbolic"): + torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version) + + # Add function protos for custom ops that are used in the model + used_protos = cls._get_function_protos_for_model(model) + + for proto in used_protos: + # Check if proto already exists to avoid duplicates + proto_name = proto.name + if not any(func.name == proto_name for func in model.functions): + model.functions.append(proto) + transformed = True + + return model, transformed + + @classmethod + def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]: + """Get function protos for custom ops that are actually used in the model.""" + used_protos = [] + + # Get all node op_types in the model + used_op_types = set() + for node in model.graph.node: + used_op_types.add(node.op_type) + + # Also check function calls + for func in model.functions: + for node in func.node: + used_op_types.add(node.op_type) + + # Check which custom ops are actually used + for op_name, (func_class, onnxscript_func) in cls._custom_ops.items(): + # Check if the custom op is referenced in the model + if cls._is_custom_op_used(model, op_name, used_op_types): + proto = onnxscript_func.to_function_proto() + used_protos.append(proto) + + return used_protos + + @classmethod + def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool: + """Check if a custom op is used in the model.""" + # Check if the op_name appears in node op_types + if op_name in used_op_types: + return True + + # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm") + custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}" + if custom_op_pattern in used_op_types: + return True + + # Heuristic checks based on op type + if "RMSNorm" in op_name: + # Check if any RMSNorm-related ops are present + return any("RMSNorm" in op_type for op_type in used_op_types) + + if "Ctx" in op_name: + # Check if Gather/Scatter operations are present (indicating KV cache usage) + return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types) + + return False + + +class RenameFunctionOutputsTransform(OnnxTransform): + """ + Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns. + """ + + @classmethod + def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: + """ + Rename function outputs in decoder layer nodes. 
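+
+        For example, a decoder-layer function output whose name contains
+        ``_InternalRetainedState`` is renamed to ``past_key.<layer>_RetainedState`` or
+        ``past_value.<layer>_RetainedState`` (depending on whether the name contains
+        ``key`` or ``value``), and any matching graph output is renamed to the same value.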
+ + :param model: The ONNX model to transform + :returns: Transformed model and boolean indicating whether transform was applied + """ + graph = model.graph + op_type_to_func_map = {func.name: func for func in model.functions} + decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"] + transformed = False + model_graph_outputs = [val.name for val in model.graph.output] + layer_index = 0 + for node in graph.node: + if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns): + func = op_type_to_func_map.get(node.op_type) + if func is None: + continue + + for i, out_name in enumerate(func.output): + if "_InternalRetainedState" in out_name: + transformed = True + tmp = node.output[i] + if "key" in out_name: + new_name = f"past_key.{layer_index}_RetainedState" + elif "value" in out_name: + new_name = f"past_value.{layer_index}_RetainedState" + # new_name = func.output[i].replace("Internal", "") + node.output[i] = new_name + # Update graph output name if it exists + if tmp in model_graph_outputs: + model.graph.output[model_graph_outputs.index(tmp)].name = new_name + layer_index = layer_index + 1 + return model, transformed diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index a20fc4cb3..e503a057f 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: class SplitGateUpWeightsTransform(PytorchTransform): """ - split fused Gate+Up weights and copy into the model + Split fused Gate+Up weights and copy into the model. + Handles both standard MoE models and GptOss models. For every transformer layer inside `model`: - • expects .experts.gate_up_proj in the *source* `sd` - • copies halves into - .experts.gate_proj <-- Gate [E,H,I] - .experts.up_proj <-- Up [E,H,I] + • expects .experts.gate_up_proj in the *source* `sd` + • copies halves into + .experts.gate_proj <-- Gate [E,H,I] + .experts.up_proj <-- Up [E,H,I] + + Handles both interleaved weights (GptOss) and concatenated weights (standard MoE). + Also handles bias terms when present. """ @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: transformed = False model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__ - if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS: return model, transformed model_tmp = model.language_model if hasattr(model, "language_model") else model - num_layers = len(model_tmp.model.layers) delete_fused_key = True sd = model_tmp.state_dict() + for layer_idx in range(num_layers): + # Determine if this is a GptOss model or standard MoE model + is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp") + # ---- build the textual prefix once per layer ---------- - prefix = f"model.layers.{layer_idx}.feed_forward.experts." + if is_gpt_oss: + prefix = f"model.layers.{layer_idx}.mlp.experts." + experts = model_tmp.model.layers[layer_idx].mlp.experts + else: + prefix = f"model.layers.{layer_idx}.feed_forward.experts." 
+ experts = model_tmp.model.layers[layer_idx].feed_forward.experts fused_key = prefix + "gate_up_proj" gate_key = prefix + "gate_proj" up_key = prefix + "up_proj" - # ---- split [E,H,2I] → two [E,H,I] tensors ---------------------- - fused = sd[fused_key] # [E, H, 2I] (no .weight here) + # Check if we have bias terms (GptOss case) + has_bias = fused_key + "_bias" in sd + if has_bias: + fused_bias_key = fused_key + "_bias" + gate_bias_key = gate_key + "_bias" + up_bias_key = up_key + "_bias" + + # ---- split weights based on model type ---------------------- + fused = sd[fused_key] # [E, H, 2I] E, H, two_I = fused.shape - ffn_dim = two_I // 2 - gate, up = fused.split(ffn_dim, dim=-1) # views – no copy - experts = model_tmp.model.layers[layer_idx].feed_forward.experts + if is_gpt_oss: + # For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...] + gate = fused[..., ::2] # [E, H, I] - even indices + up = fused[..., 1::2] # [E, H, I] - odd indices + else: + # For standard MoE, gate/up are concatenated: [gate, up] + ffn_dim = two_I // 2 + gate, up = fused.split(ffn_dim, dim=-1) # views – no copy + + # Copy weights to model experts.gate_proj.data.copy_(gate) experts.up_proj.data.copy_(up) + # Handle bias if present + if has_bias: + fused_bias = sd[fused_bias_key] # [E, 2I] + + if is_gpt_oss: + gate_bias = fused_bias[..., ::2] # [E, I] - even indices + up_bias = fused_bias[..., 1::2] # [E, I] - odd indices + else: + ffn_dim = fused_bias.shape[-1] // 2 + gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1) + + experts.gate_proj_bias.data.copy_(gate_bias) + experts.up_proj_bias.data.copy_(up_bias) + # ---- update the state-dict so load_state_dict sees the right keys sd[gate_key] = gate sd[up_key] = up + if has_bias: + sd[gate_bias_key] = gate_bias + sd[up_bias_key] = up_bias + + # Delete fused keys if delete_fused_key: del sd[fused_key] + if has_bias: + del sd[fused_bias_key] - logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") + logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") transformed = True if hasattr(model, "language_model"): model.language_model = model_tmp else: model = model_tmp + return model, transformed -VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"} +# Keep the existing list of supported models +VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"} diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index 8519d824c..5068c174e 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -90,8 +90,10 @@ def __init__( self.program = qaicrt.Program(self.context, None, qpc, prog_properties) if self.program.load() != qaicrt.QStatus.QS_SUCCESS: raise RuntimeError("Failed to load program") + self.is_active = False if activate: self.activate() + self.is_active = True # Create input qbuffers and buf_dims self.qbuffers = [qaicrt.QBuffer(bytes(binding.size)) for binding in self.bindings] self.buf_dims = qaicrt.BufferDimensionsVecRef( @@ -108,15 +110,17 @@ def output_names(self) -> List[str]: def activate(self): """Activate qpc""" - - self.program.activate() - self.execObj = qaicrt.ExecObj(self.context, self.program) + if not self.is_active: + self.program.activate() + self.execObj = qaicrt.ExecObj(self.context, self.program) + self.is_active = True def deactivate(self): """Deactivate qpc""" - - 
del self.execObj - self.program.deactivate() + if self.is_active: + del self.execObj + self.program.deactivate() + self.is_active = False def set_buffers(self, buffers: Dict[str, np.ndarray]): """ diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py new file mode 100644 index 000000000..76da7afc2 --- /dev/null +++ b/QEfficient/generation/embedding_handler.py @@ -0,0 +1,367 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Vision Handler for Vision-Language Models + +This module provides the VisionHandler class that encapsulates all vision model +operations, separating them from the main text generation logic. +""" + +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import requests +import torch +from PIL import Image +from transformers import AutoImageProcessor + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.logging_utils import logger + + +class VisionHandler: + """ + Handles all vision model operations for vision-language models. + + This class encapsulates vision preprocessing, inference, and output handling, + providing a clean separation between vision and language processing. + """ + + def __init__( + self, + qeff_model: Optional[QAICInferenceSession], + vision_session: Optional[QAICInferenceSession], + processor: Optional[AutoImageProcessor], + config: Optional[Dict[str, Any]] = None, + lang_session: Optional[QAICInferenceSession] = None, + ): + """ + Initialize vision handler + + Args: + vision_session: QAICInferenceSession for vision model + processor: AutoImageProcessor for image preprocessing + config: Configuration dictionary with vision model parameters + lang_session: Optional language session for coordination (to avoid resource conflicts) + """ + self._qeff_model = qeff_model + self._vision_session = vision_session + self._processor = processor + self._config = config or {} + self._lang_session = lang_session # Store language session for coordination + + # Cache for vision output shapes + self._vision_output_shapes = None + + if self._vision_session and not self._processor: + logger.warning("Vision session provided but no processor. Vision functionality may be limited.") + + def is_available(self) -> bool: + """ + Check if vision processing is available + + Returns: + True if both vision session and processor are available + """ + return self._vision_session is not None and self._processor is not None + + def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]: + """ + Download and preprocess image into model inputs + + Args: + image_url: URL or path to image + query: Text query to process with image + + Returns: + Dictionary of vision model inputs + + Raises: + ValueError: If vision handler is not properly initialized + RuntimeError: If image processing fails + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized. 
Need both vision_session and processor.") + + try: + # Download image + if image_url.startswith(("http://", "https://")): + image = Image.open(requests.get(image_url, stream=True).raw) + else: + image = Image.open(image_url) + + # Prepare conversation format + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) + + # Process image and text + inputs = self._processor(images=image, text=prompt, return_tensors="pt") + + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "qwen2_5_vl" + ): + inputs = self._qeff_model.model.prepare_inputs_for_generation( + inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) + + # Convert to float32 if needed + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs + + except Exception as e: + raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") + + def run_vision_inference(self, vision_inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Execute vision model inference with session coordination + + Args: + vision_inputs: Preprocessed vision inputs + + Returns: + Vision embeddings and metadata + + Raises: + ValueError: If vision session is not available + RuntimeError: If inference fails + """ + if not self._vision_session: + raise ValueError("Vision session not available") + + lang_was_active = False + try: + # Coordinate with language session to avoid resource conflicts + if self._lang_session and self._lang_session.is_active: + logger.debug("Deactivating language session before vision inference") + self._lang_session.deactivate() + lang_was_active = True + + # Activate vision session + logger.debug("Activating vision session for inference") + self._vision_session.activate() + + # Run inference + vision_outputs = self._vision_session.run(vision_inputs) + + # Deactivate vision session + logger.debug("Deactivating vision session after inference") + self._vision_session.deactivate() + + # Reactivate language session if it was active before + if lang_was_active and self._lang_session: + logger.debug("Reactivating language session after vision inference") + self._lang_session.activate() + + return vision_outputs + + except Exception as e: + # Ensure proper cleanup on error + if self._vision_session: + try: + self._vision_session.deactivate() + except Exception: + logger.warning("Deactivating vision session failed") + + # Restore language session if needed + if lang_was_active and self._lang_session: + try: + self._lang_session.activate() + except Exception: + logger.warning("Deactivating language session failed") + + raise RuntimeError(f"Vision inference failed: {str(e)}") + + def get_vision_output_shapes(self) -> Dict[str, Tuple[int, ...]]: + """ + Get 
vision output dimensions from config or session + + Returns: + Dictionary mapping output names to shapes + """ + if self._vision_output_shapes is not None: + return self._vision_output_shapes + + # Try to get from config first + if self._config and "vision_output_shapes" in self._config: + self._vision_output_shapes = self._config["vision_output_shapes"] + return self._vision_output_shapes + + # Try to derive from vision session + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], "dims"): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + + if shapes: + self._vision_output_shapes = shapes + return shapes + except Exception as e: + logger.warning(f"Could not derive vision output shapes from session: {e}") + + # Fallback to default shapes (these were hard-coded in original implementation) + default_shapes = { + "vision_embeds": (2448, 5120) # This should be derived from model config + } + + logger.warning("Using default vision output shapes. Consider providing shapes in config.") + self._vision_output_shapes = default_shapes + return default_shapes + + def setup_vision_buffers(self): + """ + Configure vision model output buffers + + Raises: + ValueError: If vision session is not available + """ + if not self._vision_session: + raise ValueError("Vision session not available") + + try: + shapes = self.get_vision_output_shapes() + + # Set up output buffers + buffers = {} + for output_name, shape in shapes.items(): + # Create placeholder with appropriate dtype + if "vision_embeds" in output_name: + buffers[output_name] = np.zeros(shape, dtype=np.float16) + else: + buffers[output_name] = np.zeros(shape, dtype=np.float32) + + self._vision_session.set_buffers(buffers) + + except Exception as e: + raise RuntimeError(f"Failed to setup vision buffers: {str(e)}") + + def prepare_complete_vision_language_inputs( + self, image_url: str, query: str + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """ + Complete pipeline: prepare inputs and run vision inference + + Args: + image_url: URL or path to image + query: Text query + + Returns: + Tuple of (vision_inputs, vision_outputs) + """ + # Prepare vision inputs + vision_inputs = self.prepare_vision_inputs(image_url, query) + + # Setup buffers + self.setup_vision_buffers() + + # Run vision inference + vision_outputs = self.run_vision_inference(vision_inputs) + + return vision_inputs, vision_outputs + + def get_processed_inputs( + self, image_url: str, query: str, prefill_seq_len: int + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """ + Process vision inputs and prepare language model inputs + + Args: + image_url: URL or path to image + query: Text query + padded_len: Padded sequence length for language model + + Returns: + Tuple of (language_inputs, vision_outputs) + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized") + + try: + ## Get vlm inputs ## + vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len) + + # Handle padding for language model + pad_token_id = 1 + input_ids_length = lang_inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -prefill_seq_len) + padded_len = num_chunks * prefill_seq_len + + 
lang_inputs["input_ids"] = torch.nn.functional.pad( + lang_inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + lang_inputs["attention_mask"] = torch.nn.functional.pad( + lang_inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + + if "cross_attention_mask" in lang_inputs: + lang_inputs["cross_attention_mask"] = torch.nn.functional.pad( + lang_inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + for k, v in lang_inputs.items(): + lang_inputs[k] = np.array(v) + + vision_outputs = {} + if vision_inputs: + self.setup_vision_buffers() + vision_outputs = self.run_vision_inference(vision_inputs) + + if "position_ids" in lang_inputs: + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where(lang_inputs.pop("attention_mask"), np.arange(padded_len), -1) + + lang_inputs["image_idx"] = np.array([[0]]) + + return lang_inputs, vision_outputs, num_chunks + + except Exception as e: + raise RuntimeError(f"Failed to process vision-language inputs: {str(e)}") diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 6d04cf573..e96908824 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -437,15 +437,19 @@ def __init__( include_sampler: bool = False, return_pdfs: bool = False, sampling_params: Optional[Dict[str, Any]] = None, + activate: bool = True, ) -> None: self._ctx_len = ctx_len self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs self.sampling_params = sampling_params + self._qpc_path = qpc_path # Store qpc_path for later use # Load QPC - self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) + self._session = QAICInferenceSession( + qpc_path, device_id, activate=activate, enable_debug_logs=enable_debug_logs + ) # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( @@ -778,6 +782,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id + if self.is_tlm: inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: @@ -808,6 +813,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] outputs = self._session.run(chunk_inputs) + if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) return ( diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py new file mode 100644 index 000000000..2e8f04f2b --- /dev/null +++ b/QEfficient/generation/vlm_generation.py @@ -0,0 +1,784 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +This module provides the VisionLanguageGeneration class that inherits from +QEffTextGenerationBase, enabling all advanced text generation features while +maintaining full API compatibility with the original VisionLanguageGeneration. 
+ +Key enhancements: +- Continuous batching support for vision models +- Advanced streaming capabilities +- On-device sampling support +- LoRA adapter support +- Better performance metrics +""" + +from collections import deque +from time import perf_counter +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.embedding_handler import VisionHandler +from QEfficient.generation.text_generation_inference import ( + CloudAI100ExecInfo, + PerfMetrics, + QEffTextGenerationBase, + TextGeneration, + calculate_latency, + write_io_files, +) +from QEfficient.utils import LRUCache +from QEfficient.utils.logging_utils import logger + + +class VisionLanguageGeneration(QEffTextGenerationBase): + """ + Enhanced vision-language generation class inheriting from QEffTextGenerationBase. + + This class maintains full API compatibility with VisionLanguageGeneration while + adding advanced features like continuous batching, streaming, and sampling. + + Example: + >>> # Drop-in replacement for VisionLanguageGeneration + >>> vlm = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0] + ... ) + >>> result = vlm.generate( + ... images=["image1.jpg"], + ... prompts=["Describe this image"], + ... generation_len=512 + ... ) + + >>> # Enhanced usage with new features + >>> vlm_enhanced = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0], + ... full_batch_size=8, # Enable continuous batching + ... include_sampler=True, # Enable on-device sampling + ... sampling_params=sampling_config + ... 
) + """ + + def __init__( + self, + qeff_model, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + processor: AutoImageProcessor, + lang_qpc_path: str, + vision_qpc_path: str, + device_id: Optional[List[int]] = None, + ctx_len: Optional[int] = None, + enable_debug_logs: bool = False, + write_io_dir: Optional[str] = None, + full_batch_size: Optional[int] = None, + is_tlm: bool = False, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize vision-language generation with enhanced capabilities + + Args: + qeff_model: QEff model instance + tokenizer: Text tokenizer + processor: Image processor + lang_qpc_path: Path to language model QPC + vision_qpc_path: Path to vision encoder QPC + device_id: Device IDs for execution (default: [0]) + ctx_len: Context length + enable_debug_logs: Enable debug logging + write_io_dir: Directory for I/O file writing + full_batch_size: Enable continuous batching (new feature) + is_tlm: Target language model flag + include_sampler: Enable on-device sampling (new feature) + return_pdfs: Return probability distributions + sampling_params: Sampling parameters for on-device sampling + """ + # Validate required parameters + if not lang_qpc_path: + raise TypeError("lang_qpc_path is required") + if not vision_qpc_path: + raise TypeError("vision_qpc_path is required") + + # Initialize base class with language QPC + # Pass activate=False to prevent premature activation before vision components are ready + super().__init__( + tokenizer=tokenizer, + qpc_path=lang_qpc_path, + full_batch_size=full_batch_size, + ctx_len=ctx_len, + device_id=device_id, + enable_debug_logs=enable_debug_logs, + write_io_dir=write_io_dir, + is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, + activate=False, # vision components need to be initialized first + ) + + # Vision-specific initialization + self.is_qwen2_5_vl = ( + hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl" + ) + self.qeff_model = qeff_model + self.processor = processor + self._vision_qpc_path = vision_qpc_path + self.device_id = device_id # Store device_id for vision components + self.enable_debug_logs = enable_debug_logs # Store for vision components + self._vision_outputs_cache = LRUCache(max_size=100) # LRU cache for vision outputs + self._vision_cache = {} # Cache for vision outputs across batches + self._init_vision_components() + + # Now that vision components are initialized, activate the text session + self._session.activate() + + logger.info( + f"VisionLanguageGeneration initialized: batch_size={self.batch_size}, " + f"prefill_seq_len={self._prefill_seq_len}, ctx_len={ctx_len}, " + f"continuous_batching={'enabled' if full_batch_size else 'disabled'}, " + f"sampling={'enabled' if include_sampler else 'disabled'}" + ) + + def _init_vision_components(self): + """Initialize vision-specific components""" + # Vision session (separate from base class language session) + self._vision_session = QAICInferenceSession( + self._vision_qpc_path, self.device_id, activate=False, enable_debug_logs=self.enable_debug_logs + ) + + # Vision handler with language session coordination + vision_config = self._get_vision_config() + self._vision_handler = VisionHandler( + qeff_model=self.qeff_model, + vision_session=self._vision_session, + processor=self.processor, + config=vision_config, + lang_session=self._session, # Pass language 
session for coordination + ) + + # Setup vision buffer skipping + self._setup_vision_buffer_skipping() + + def _get_vision_config(self) -> Dict[str, Any]: + """ + Derive vision config from session + + Returns: + Dictionary with vision configuration + """ + config = {} + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], "dims"): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + + if shapes: + config["vision_output_shapes"] = shapes + except Exception as e: + logger.warning(f"Could not derive vision config from session: {e}") + + return config + + def _setup_vision_buffer_skipping(self): + """Skip KV cache and retained state buffers for vision session""" + # Pre-compute skip buffers + self._vision_skip_buffers = [ + x + for x in self._vision_session.input_names + self._vision_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + self._vision_session.skip_buffers(self._vision_skip_buffers) + + # Pre-compute language skip buffers + self._lang_skip_buffers = [ + x + for x in self._session.input_names + self._session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + + def run_prefill_for_all_inputs(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs in the prompt queue and updates the decode input. + + Method iterates over the full batch size and for each decode batch ID, it pops the next prompt from the queue. It then runs prefill for the next prompt and updates the decode input with the outputs. + + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. + + """ + for decode_batch_id in range(self.full_batch_size): + next_prompt = prompt_queue.popleft() + + # run prefill for num_chunks + outputs, position_ids, generation_len = self.run_prefill( + next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) + ) + + if self.is_qwen2_5_vl: + _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) + else: + _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + + def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): + """ + Updates the decode input with the generated values. + Args: + outputs (dict): The outputs of the model. + position_ids (array): The position IDs. + generation_len (int): The generation length. + decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None. + + Returns: + next_token_id (array): The next token ID. + """ + next_token_id = self._fetch_next_token_id(outputs) + + # Store the generated values. 
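+        # Note: for Qwen2.5-VL, decode_pos_ids is allocated with a leading coordinate
+        # dimension of 4 (see _generate_continuous_batching), which is why the decode
+        # slot is indexed on dim 1 of decode_pos_ids here.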
+ self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id + self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) + self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) + self.generation_len[decode_batch_id or slice(None)] = generation_len + return next_token_id + + def _execute_chunked_prefill( + self, + lang_inputs: Dict[str, np.ndarray], + num_chunks: int, + decode_batch_id: Optional[np.ndarray] = None, + prefill_logit_bs: int = 1, + ) -> Dict[str, np.ndarray]: + """ + Execute chunked prefill with language inputs + + Args: + lang_inputs: Pre-processed language inputs with input_ids, position_ids, etc. + num_chunks: Number of chunks to process + decode_batch_id: Batch ID for continuous batching (optional) + prefill_logit_bs: Batch size for prefill logits + + Returns: + Final prefill outputs + """ + # Set output buffers + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) + + # Skip buffers for dual-QPC coordination + self._session.skip_buffers(self._lang_skip_buffers) + + # Run chunked prefill + outputs = None + chunk_image_idx = None + + for i in range(num_chunks): + input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] + position_ids_slice = lang_inputs["position_ids"][ + ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + + chunk_inputs = { + "input_ids": input_ids_slice, + "position_ids": position_ids_slice, + "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64), + } + + if decode_batch_id is not None: + chunk_inputs["batch_index"] = decode_batch_id + + if "cross_attention_mask" in lang_inputs: + chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] + + outputs = self._session.run(chunk_inputs) + + if "image_idx_output" in outputs: + chunk_image_idx = outputs["image_idx_output"] + + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + # Prepare decode-time cross_attention_mask + if "cross_attention_mask" in lang_inputs: + bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64) + else: + self._decode_cross_attention_mask = None + + return outputs + + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + """ + Override base class prefill to handle vision processing + + Args: + prompt: Can be string or tuple (image_path, text_prompt) + generation_len: Generation length + prefill_logit_bs: Prefill batch size + decode_batch_id: Batch ID for continuous batching + + Returns: + Same as base class: (outputs, position_ids, generation_len) + """ + # Normalize prompt: TextGeneration passes a list even for batch_size=1 + if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], tuple) and len(prompt[0]) == 2: + # Unwrap single (image_path, text_prompt) tuple + if len(prompt) == 1: + prompt = prompt[0] + else: + raise NotImplementedError( + "VisionLanguageGeneration.run_prefill currently supports a single (image, text) pair per call." 
+ ) + # Check if this is a vision-language prompt + if isinstance(prompt, tuple) and len(prompt) == 2: + image_path, text_prompt = prompt + + # Check cache for vision outputs + cache_key = image_path if isinstance(image_path, str) else str(image_path) + if cache_key in self._vision_cache: + lang_inputs, vision_outputs, num_chunks = self._vision_cache[cache_key] + logger.debug(f"Using cached vision outputs for {cache_key}") + else: + # Build language inputs with processor-aware vision/text integration + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len + ) + # Cache for future use + self._vision_cache[cache_key] = (lang_inputs, vision_outputs, num_chunks) + logger.debug(f"Cached vision outputs for {cache_key}") + + # Set vision buffers in language session + self._session.set_buffers(vision_outputs) + logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") + self._vision_processed = True + self._vision_outputs = vision_outputs + + # Calculate generation_len consistent with ctx_len + max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + generation_len = self._fetch_generation_len(generation_len, max_gen_len) + + # Execute chunked prefill + outputs = self._execute_chunked_prefill(lang_inputs, num_chunks, decode_batch_id, prefill_logit_bs) + + self._session.skip_buffers(vision_outputs) + + # Prepare position_ids for decode phase (next position after prefill) + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + return outputs, position_ids_decode, generation_len + else: + # Fall back to base class for text-only + return super().run_prefill(prompt, generation_len, prefill_logit_bs, decode_batch_id) + + def _prepare_vision_language_prompt(self, text_prompt, image_path): + """ + Prepare text prompt with vision context + + This method handles the integration of vision and text inputs + according to the specific model's requirements. + """ + # For most vision-language models, we need to apply the chat template + # that includes both image and text components + try: + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text_prompt}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + processed_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + + return processed_prompt + + except Exception as e: + logger.warning(f"Failed to apply chat template: {e}. 
Using original prompt.") + return text_prompt + + def generate( + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, stream: bool = True, **kwargs + ) -> CloudAI100ExecInfo: + """ + Main generation method maintaining API compatibility with VisionLanguageGeneration + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + stream: Enable streaming output + **kwargs: Additional arguments passed to base class + + Returns: + CloudAI100ExecInfo with results and metrics + + Raises: + ValueError: If images and prompts lengths don't match + """ + if len(images) != len(prompts): + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + + # Clear vision cache for fresh generation + self._vision_cache.clear() + + logger.info(f"Generating for {len(images)} image-prompt pairs") + + # Convert to base class format: list of (image, prompt) tuples + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + # Use base class generate method with vision prompts + if self.full_batch_size is not None: + # Continuous batching mode (new capability) + return self._generate_continuous_batching(vision_prompts, generation_len, stream, **kwargs) + else: + # Regular batching mode + return self._generate_regular_batching(vision_prompts, generation_len, stream, **kwargs) + + def _generate_regular_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Handle regular batching for vision-language generation without creating a second language session""" + batch_results = [] + for i in range(0, len(vision_prompts), self.batch_size): + batch = vision_prompts[i : i + self.batch_size] + + if stream: + print( + f"\nProcessing batch {i // self.batch_size + 1}/{(len(vision_prompts) - 1) // self.batch_size + 1}" + ) + for j, (img, prompt) in enumerate(batch): + print(f"Image: {img}") + print(f"Prompt: {prompt}") + print("Completion:", flush=True, end="") + + # Setup decode storage arrays for this batch (use ctx_len or generation_len whichever is larger) + exec_batch_size = self.batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + self.initialize_decode_inputs( + num_prompts=len(batch), execution_batch_size=exec_batch_size, max_gen_length=max_gen_length + ) + + # Prefill using VLM-aware run_prefill (batch is a list of (image, text)) + start = perf_counter() + outputs, position_ids, generation_len_final = self.run_prefill( + batch, generation_len, prefill_logit_bs=self.batch_size + ) + self.update_decode_input(outputs, position_ids, generation_len_final) + + # Prepare decode + decode_inputs = self.prepare_decode_inputs() + + # Decode loop + loop_start = perf_counter() + num_token = self.run_decode(decode_inputs, generation_len_final, automation=False, streamer=None) + end = perf_counter() + + # Decode generated texts + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + # Latency metrics + total_decode_tokens = num_token + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end + ) + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + # Package result for this batch + batch_results.append( + CloudAI100ExecInfo( + batch_size=self.batch_size, + generated_texts=generated_texts, + generated_ids=self.generated_ids, + perf_metrics=perf_metrics, + ) + ) + + # Aggregate results across 
batches + return self._aggregate_batch_results(batch_results) + + def _generate_continuous_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Enable continuous batching for vision-language models (new capability)""" + logger.info("Using continuous batching for vision-language generation") + + if stream: + logger.warning("Streaming output not fully supported with continuous batching") + + # Reset vision processing state for new generation + self._vision_processed = False + self._vision_outputs = None + self._vision_outputs_cache = {} + + # Initialize decode inputs + num_prompts = len(vision_prompts) + execution_batch_size = self.full_batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + + self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) + if self.is_qwen2_5_vl: + self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) + + # Create prompt queue + prompt_queue = deque(vision_prompts) + + start = perf_counter() + + # Pre-process ALL vision inputs and cache them + logger.info("Pre-processing all vision inputs...") + for batch_id in range(min(self.full_batch_size, len(vision_prompts))): + img, prompt = vision_prompts[batch_id] + + # Process vision for this slot + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len + ) + + # Cache vision outputs for this batch slot + self._vision_outputs_cache[batch_id] = { + "vision_outputs": vision_outputs, + "lang_inputs": lang_inputs, + "num_chunks": num_chunks, + } + + logger.debug(f"Cached vision outputs for batch_id {batch_id}") + + # Reset prompt queue for prefill + prompt_queue = deque(vision_prompts) + + self.batch_index = None + + # Run prefill for all inputs using cached vision + self.run_prefill_for_all_inputs_with_cached_vision(prompt_queue, generation_len) + + # Set vision buffers for decode (use first slot's vision for now) + # For identical images, any slot's vision works + cached_slot_0 = self._vision_outputs_cache.get(0) + if cached_slot_0: + self._session.set_buffers(cached_slot_0["vision_outputs"]) + logger.debug("Set vision buffers from slot 0 for decode phase") + + # Now set batch_index for decode phase + self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1) + + loop_start = perf_counter() + decode_pause_time = self.run_continuous_batching_decode(prompt_queue, generation_len) + end = perf_counter() + + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + total_decode_tokens = sum( + np.sum(self.generated_ids[i] != self.tokenizer.pad_token_id) - 1 for i in range(len(vision_prompts)) + ) + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end, decode_pause_time + ) + prefill_time /= len(vision_prompts) # Average prefill time for continuous batching + + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + return CloudAI100ExecInfo( + batch_size=1, generated_texts=generated_texts, generated_ids=self.generated_ids, perf_metrics=perf_metrics + ) + + def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs using pre-cached vision outputs. + + This avoids the vision buffer overwriting issue by using cached vision + outputs instead of processing vision during each prefill iteration. 
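+
+        Vision outputs are looked up from ``self._vision_outputs_cache`` keyed by decode
+        slot (``decode_batch_id``), as populated in ``_generate_continuous_batching``.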
+ + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. + """ + for decode_batch_id in range(self.full_batch_size): + # Pop the promt as we are processing + _ = prompt_queue.popleft() + + # Get cached vision outputs for this batch slot + cached = self._vision_outputs_cache.get(decode_batch_id) + if cached: + vision_outputs = cached["vision_outputs"] + lang_inputs = cached["lang_inputs"] + num_chunks = cached["num_chunks"] + + # Set vision buffers for THIS prefill + self._session.set_buffers(vision_outputs) + logger.debug(f"Set vision buffers for batch_id {decode_batch_id} prefill") + + # Run prefill with cached inputs + outputs = self._execute_chunked_prefill( + lang_inputs, + num_chunks, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + prefill_logit_bs=1, + ) + + self._session.skip_buffers(vision_outputs.keys()) + + # Calculate position_ids for decode + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + # Calculate generation_len + max_gen_len = ( + self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + ) + generation_len_final = self._fetch_generation_len(generation_len, max_gen_len) + + # Update decode inputs + if self.is_qwen2_5_vl: + self.update_decode_inputs_qwen2_5_vl( + outputs, position_ids_decode, generation_len_final, decode_batch_id + ) + else: + self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id) + else: + logger.error(f"No cached vision outputs for batch_id {decode_batch_id}") + raise RuntimeError(f"Vision outputs not cached for batch_id {decode_batch_id}") + + def prepare_decode_inputs(self): + """ + Override base class to handle vision-specific decode inputs + """ + decode_inputs = super().prepare_decode_inputs() + + # Add image_idx for vision-language models in CB mode during decode only + if self.batch_index is not None and hasattr(self, "_vision_outputs"): + # image_idx should be a single slot selector; decoder expects shape (1,1) + # Query binding dims if available to be robust + try: + if "image_idx" in getattr(self._session, "binding_index_map", {}): + idx = self._session.binding_index_map["image_idx"] + dims = tuple(self._session.bindings[idx].dims) + decode_inputs["image_idx"] = np.zeros(dims, dtype=np.int64) + else: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + except Exception: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + + # Include cross_attention_mask during decode if present/required + if hasattr(self, "_decode_cross_attention_mask") and self._decode_cross_attention_mask is not None: + # Decoder specialization expects a single mask (batch dim = 1) + decode_inputs["cross_attention_mask"] = self._decode_cross_attention_mask + + return decode_inputs + + def _aggregate_batch_results(self, batch_results): + """Aggregate results from multiple batches""" + if not batch_results: + raise ValueError("No batch results to aggregate") + + if len(batch_results) == 1: + return batch_results[0] + + # Aggregate multiple batch results + all_generated_texts = [] + all_generated_ids = [] + all_metrics = [] + + for result in batch_results: + if isinstance(result.generated_texts[0], list): + # Flatten nested lists + all_generated_texts.extend([text for batch in result.generated_texts for text in batch]) + else: + all_generated_texts.extend(result.generated_texts) + + if isinstance(result.generated_ids, list): + 
all_generated_ids.extend(result.generated_ids) + else: + all_generated_ids.append(result.generated_ids) + + all_metrics.append(result.perf_metrics) + + # Average metrics + avg_metrics = PerfMetrics( + prefill_time=np.mean([m.prefill_time for m in all_metrics]), + decode_perf=np.mean([m.decode_perf for m in all_metrics]), + total_perf=np.mean([m.total_perf for m in all_metrics]), + total_time=np.mean([m.total_time for m in all_metrics]), + ) + + return CloudAI100ExecInfo( + batch_size=batch_results[0].batch_size, + generated_texts=all_generated_texts, + generated_ids=all_generated_ids, + perf_metrics=avg_metrics, + ) + + def generate_stream_tokens( + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, **kwargs + ): + """ + Enable token-by-token streaming for vision models (new capability) + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + **kwargs: Additional arguments + + Yields: + List of decoded tokens for each batch position + + Raises: + NotImplementedError: If continuous batching is enabled + """ + if self.full_batch_size is not None: + raise NotImplementedError("Token streaming not supported with continuous batching for VLM") + + if len(images) != len(prompts): + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + + logger.info(f"Starting token streaming for {len(images)} image-prompt pairs") + + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + text_gen = TextGeneration( + tokenizer=self.tokenizer, + qpc_path=self._qpc_path, + ctx_len=self._ctx_len, + device_id=self.device_id, + enable_debug_logs=self.enable_debug_logs, + is_tlm=self.is_tlm, + include_sampler=self.include_sampler, + return_pdfs=self.return_pdfs, + sampling_params=self.sampling_params, + ) + + text_gen._qaic_model = self + + # Yield tokens as they're generated + for tokens in text_gen.generate_stream_tokens(vision_prompts, generation_len, **kwargs): + yield tokens + + def __repr__(self): + """String representation of the class""" + return ( + f"VisionLanguageGeneration(" + f"batch_size={self.batch_size}, " + f"ctx_len={self._ctx_len}, " + f"continuous_batching={'enabled' if self.full_batch_size else 'disabled'}, " + f"sampling={'enabled' if self.include_sampler else 'disabled'})" + ) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index bbd937d52..853567be9 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -537,3 +537,122 @@ def update( ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out + + +# This is a hack for now, until we get to merging this code with HybridCache class, +# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and +# ours are made to work with AIC +class QEffHybridCacheForGPTOSS: + def __init__(self, config, batch_size, max_cache_len, sliding_window_len): + self.max_cache_len = max_cache_len + self.batch_size = batch_size + self.sliding_window_len = sliding_window_len + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "HybridCache": + """Converts a 
cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls( + config, + batch_size=past_key_values[0][0].shape[0], + max_cache_len=past_key_values[1][0].shape[2], + sliding_window_len=past_key_values[0][0].shape[2], + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or len(self.key_cache[layer_idx]) == 0 # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + sliding_window = cache_kwargs.get("sliding_window") + batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs + + if is_sliding_layer: + kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window) + else: + kv_position_ids = position_ids + + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids) + else: + scatter_position_ids = kv_position_ids + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Original Gather + ctx_len = self.key_cache[layer_idx].shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] 
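A minimal, self-contained sketch of the wrap-around scatter and bounded gather that this update method implements, using plain tensor ops in place of the CtxScatter/CtxGather custom ops (toy shapes, illustrative only, not part of the change):

import torch

sliding_window, bs, heads, head_dim = 4, 1, 1, 2
cache = torch.zeros(bs, heads, sliding_window, head_dim)

# New tokens at absolute positions 4 and 5 wrap into slots 0 and 1 of the sliding cache.
position_ids = torch.tensor([[4, 5]])
new_kv = torch.arange(4.0).reshape(bs, heads, 2, head_dim)
slots = position_ids % sliding_window
cache[:, :, slots[0], :] = new_kv

# Bounded gather: read the whole window, then zero out slots beyond the newest position.
ctx_indices = torch.arange(sliding_window)[None, None, :]
gather_limit = slots.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
gathered = cache.gather(2, ctx_indices.unsqueeze(-1).expand(bs, heads, sliding_window, head_dim))
gathered = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0), gathered)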
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index c692d1beb..5337b44f5 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -185,6 +185,7 @@ ] ) +# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} # Define a transformers layers to QEff layers dictionary diff --git a/QEfficient/transformers/models/gpt_oss/__init__.py b/QEfficient/transformers/models/gpt_oss/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py new file mode 100644 index 000000000..62bc849b7 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -0,0 +1,736 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssConfig, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRotaryEmbedding, + repeat_kv, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE + + +class QEffGptOssExperts(GptOssExperts): + def __qeff_init__(self): + self.gate_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.gate_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + + +class QEffGptOssMLP(GptOssMLP): + def alt_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + # ------------------- Gather based, weights as activation approach --------------- + def forward_weights_as_activation(self, hidden_states): + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(bs * seq_len, 
self.experts.hidden_size) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect weights for selected experts + gate_up_proj = self.experts.gate_up_proj[router_indices.flatten()] + gate_up_proj_bias = self.experts.gate_up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Apply Chosen Experts (without routing weights first) + # expert_in = hidden_states.repeat_interleave(self.router.top_k, dim=0) + # expert_in = expert_in.view(-1, 1, self.experts.hidden_size) + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + gate_up = torch.bmm(expert_in, gate_up_proj) + gate_up_proj_bias.unsqueeze(1) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + glu = gate * torch.sigmoid(gate * self.experts.alpha) + gated_output = (up + 1) * glu + + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size) + + # Apply routing weights AFTER expert computation (This is before on Llama4) + experts_out = experts_out * router_top_value.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + + return experts_out, router_logits + + # ------------------- Gather based, weights as activation approach, With Seperate Gate, up Projections --------------- + def forward(self, hidden_states): + # print("Seperate Split, Up, Gate Projections") + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect weights for selected experts (separate gate and up projections) + gate_proj = self.experts.gate_proj[router_indices.flatten()] + gate_proj_bias = self.experts.gate_proj_bias[router_indices.flatten()] + up_proj = self.experts.up_proj[router_indices.flatten()] + up_proj_bias = self.experts.up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + # Apply gate and up projections separately using bmm + gate = torch.bmm(expert_in, gate_proj) + gate_proj_bias.unsqueeze(1) + up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1) + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * 
self.experts.alpha) + gated_output = (up + 1) * glu + + # Down projection + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size) + + # Apply routing weights AFTER expert computation + experts_out = experts_out * router_top_value.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + + return experts_out, router_logits + + def optimized_moe_forward(self, hidden_states: torch.Tensor): + B, S, H = hidden_states.shape + T = B * S + hidden_states = hidden_states.view(T, H) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + + # Top-k selection + top_w, selected_experts = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + # Creating experts mask and routing weights masked + awesome_experts_mask_1 = ( + torch.nn.functional.one_hot(selected_experts[:, 0], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_2 = ( + torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_3 = ( + torch.nn.functional.one_hot(selected_experts[:, 2], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_4 = ( + torch.nn.functional.one_hot(selected_experts[:, 3], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + + gateupout1 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout2 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout3 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout4 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + + # Gate and Up projections + gate = (hidden_states @ W_g) + b_g # [T, I] + up = (hidden_states @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + gateupout1 += torch.where(awesome_experts_mask_1[e], intermediate, torch.zeros_like(gateupout1)) + gateupout2 += torch.where(awesome_experts_mask_2[e], intermediate, torch.zeros_like(gateupout2)) + gateupout3 += torch.where(awesome_experts_mask_3[e], intermediate, torch.zeros_like(gateupout3)) + gateupout4 += torch.where(awesome_experts_mask_4[e], intermediate, torch.zeros_like(gateupout4)) + + concat_down = torch.zeros((self.router.top_k, T, H)) + concat_mask = torch.cat( + ( + awesome_experts_mask_1.unsqueeze(0), + awesome_experts_mask_2.unsqueeze(0), + awesome_experts_mask_3.unsqueeze(0), + awesome_experts_mask_4.unsqueeze(0), + ), + dim=0, + ) + + concat_gateout = torch.cat( + (gateupout1.unsqueeze(0), gateupout2.unsqueeze(0), gateupout3.unsqueeze(0), gateupout4.unsqueeze(0)), dim=0 + ) + + for e in range(self.experts.num_experts): + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # 
[H] + + # Down projection + down_out = (concat_gateout @ W_d) + b_d # [T, H] + + concat_down += torch.where(concat_mask[:, e, :], down_out, torch.zeros_like(concat_down)) + + downout1, downout2, downout3, downout4 = concat_down[0], concat_down[1], concat_down[2], concat_down[3] + hidden_states = ( + downout1 * top_w[:, 0].unsqueeze(-1) + + downout2 * top_w[:, 1].unsqueeze(-1) + + downout3 * top_w[:, 2].unsqueeze(-1) + + downout4 * top_w[:, 3].unsqueeze(-1) + ).reshape(B, S, H) + + # original shape [B, S, H] + return hidden_states, router_logits + + +# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology +class QEffGptOssRotaryEmbedding(GptOssRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: GptOssConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. 
For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. 
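A small stand-alone sketch of this sink handling, showing how the extra per-head logit absorbs probability mass and is then dropped after the softmax (arbitrary toy shapes, for illustration only):

import torch
import torch.nn.functional as F

bs, heads, q_len, kv_len = 1, 2, 3, 5
attn_logits = torch.randn(bs, heads, q_len, kv_len)
sinks = torch.randn(heads)  # one learned sink logit per head

# Append the sink as a virtual key position so it competes in the softmax.
sink_col = sinks.reshape(1, heads, 1, 1).expand(bs, heads, q_len, 1)
combined = torch.cat([attn_logits, sink_col], dim=-1)
combined = combined - combined.max(dim=-1, keepdim=True).values  # numerical stability
probs = F.softmax(combined, dim=-1)

# Drop the sink column; each row now sums to less than 1 by exactly the mass the sink absorbed.
scores = probs[..., :-1]
print(scores.sum(-1))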
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class QEffGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffGptOssDecoderLayer(GptOssDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + 
hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + # alth, _ = self.mlp.alt_forward(hidden_states) + hidden_states = hidden_states.reshape(residual.shape) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.sliding_window_len, + sliding_window=past_key_values.sliding_window_len, + ) + + hidden_states = inputs_embeds + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + 
use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class QEffGptOssForCausalLM(GptOssForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def get_pkv_dynamic_axes( + self, + ): + pkv_dynamic_axes = [] + for layer_type in self.config.layer_types: + if layer_type == "sliding_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) + elif layer_type == "full_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) + return pkv_dynamic_axes + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + ): + batch_size = batch_size if batch_size else 1 + prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN + ctx_len = ctx_len if ctx_len else constants.CTX_LEN + + specializations = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "sliding_window": 128, + }, + { + "batch_size": batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "sliding_window": 128, + }, + ] + return specializations diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 212fe16ae..b7b951101 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -820,7 +820,7 @@ def forward(self, pixel_values): ) vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.model.multi_modal_projector(vision_flat) - return projected_vision_flat + return projected_vision_flat # , pixel_values # This wrapper utilizes the 'vision_embeds', which contains vision embeddings, and an 'image_idx' index starting at 0. 
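For reference, the two helpers above drive export and compilation metadata on a per-layer basis; a rough sketch of what they yield for a toy config that alternates sliding and full attention layers (the numbers are placeholders, not real GPT-OSS values):

layer_types = ["sliding_attention", "full_attention"] * 2  # toy 4-layer config

# One dynamic-axes dict per KV layer: sliding layers key on "sliding_window", full layers on "ctx_len".
pkv_dynamic_axes = []
for layer_type in layer_types:
    if layer_type == "sliding_attention":
        pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"})
    else:
        pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"})

# Two specializations, prefill and decode, both carrying the fixed sliding_window dimension.
specializations = [
    {"batch_size": 1, "seq_len": 128, "ctx_len": 4096, "sliding_window": 128},
    {"batch_size": 1, "seq_len": 1, "ctx_len": 4096, "sliding_window": 128},
]
print(pkv_dynamic_axes[0], specializations[1])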
@@ -836,7 +836,15 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -846,7 +854,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -893,6 +905,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) @@ -941,28 +956,42 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - ] + + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -971,18 +1000,22 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic 
axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {0: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"} for i in range(self.language_model.config.num_hidden_layers): # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers. if int((i + 1) % 4 != 0): @@ -1011,6 +1044,7 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: + # vision_output_names.insert(1, "pixel_values_RetainedState") lang_output_names.insert(1, "vision_embeds_RetainedState") lang_output_names.insert(2, "image_idx_output") output_names["vision"] = vision_output_names @@ -1045,7 +1079,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1090,10 +1124,14 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV past_key_values = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -1102,6 +1140,8 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 633a0b29d..9da0e183c 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn from transformers import ( + AutoImageProcessor, AutoModel, AutoModelForCausalLM, AutoModelForCTC, @@ -26,7 +27,13 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel -from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.base.onnx_transforms import ( + CustomOpTransform, + FP16ClipTransform, + OnnxSlimTransform, + RenameFunctionOutputsTransform, + SplitTensorsTransform, +) from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import ( @@ -35,6 +42,7 @@ calculate_latency, get_compilation_dims, ) +from QEfficient.generation.vlm_generation import VisionLanguageGeneration from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH from 
QEfficient.transformers.models.pytorch_transforms import ( CustomOpsTransform, @@ -51,6 +59,7 @@ AwqToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, GPTQToMatmulNbitsTransform, + Mxfp4GptOssExpertDequantizeTransform, ) from QEfficient.utils import ( constants, @@ -856,6 +865,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + continuous_batching: bool = False, **kwargs, ): """ @@ -879,6 +889,7 @@ def __init__( self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None @property @@ -978,8 +989,15 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + # TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed. + try: + inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, continuous_batching=self.continuous_batching + ) + except TypeError: + inputs = self.model.get_dummy_inputs(kv_offload=True) + dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1011,7 +1029,6 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[int] = None, skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, **compiler_options, @@ -1068,14 +1085,20 @@ def compile( If `full_batch_size`, `kv_cache_batch_size`, or `num_speculative_tokens` are not None. If both `skip_lang` and `skip_vision` are True. """ - if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]): + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: raise ValueError( - f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: " - f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." 
) - if skip_lang and skip_vision: - raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names(kv_offload=True) @@ -1085,6 +1108,9 @@ def compile( ctx_len=ctx_len, img_size=img_size, kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, **compiler_options, ) @@ -1111,6 +1137,11 @@ def compile( ): self.export() + # TODO this hould be removed once the continous batching is supported for all the models. + compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) + if not skip_vision: self.vision_model._compile( compile_dir=compile_dir, @@ -1156,7 +1187,11 @@ def compile( def generate( self, - inputs: torch.Tensor, + inputs: Optional[torch.Tensor] = None, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + processor: Optional[AutoImageProcessor] = None, + images: List[str] = None, + prompts: List[str] = None, streamer: Optional[TextStreamer] = None, device_ids: List[int] = None, runtime_ai100: bool = True, @@ -1172,6 +1207,14 @@ def generate( inputs : Dict[str, Union[torch.Tensor, np.ndarray]] Inputs to run the execution, typically includes `pixel_values`, `input_ids`, `attention_mask`, etc. + tokenizer : PreTrainedTokenizer or PreTrainedTokenizerFast, optional + Tokenizer for the model. Used when images and prompts are provided. + processor : AutoImageProcessor, optional + Processor for the model. Used when images and prompts are provided. + images : List[str], optional + List of image paths or PIL images to process. + prompts : List[str], optional + List of text prompts corresponding to the images. streamer : TextStreamer, optional A streamer object to display generated tokens in real-time. Default is None. 
device_ids : List[int], optional @@ -1196,6 +1239,30 @@ def generate( if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + # Use VisionLanguageGeneration for image-prompt pairs + if (processor and images) or (tokenizer and prompts): + # Create VisionLanguageGeneration instance + batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path) + vlm_gen = VisionLanguageGeneration( + qeff_model=self, + lang_qpc_path=self.lang_model.qpc_path, + vision_qpc_path=self.vision_model.qpc_path, + tokenizer=tokenizer, + processor=processor, + device_id=device_ids, # if device_ids is not None else [0], + ctx_len=ctx_len_comp, + full_batch_size=fbs, + ) + + # Call generate method + return vlm_gen.generate( + images=images, + prompts=prompts, + generation_len=generation_len, + stream=streamer is not None, + ) + + # Fallback to kv_offload_generate for direct inputs (backward compatibility) return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) @@ -1332,9 +1399,7 @@ def kv_offload_generate( lang_session.set_buffers(vision_outputs) - # Prepare inputs for prefill - chunk_inputs = lang_inputs.copy() - prefill_start = perf_counter() + lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() @@ -1346,7 +1411,7 @@ def kv_offload_generate( outputs = lang_session.run(chunk_inputs) chunk_inputs["image_idx"] = outputs["image_idx_output"] - prefill_time = perf_counter() - prefill_start + vision_end - vision_start + prefill_time = perf_counter() - lang_start + vision_end - vision_start # Skip inputs/outputs again lang_session.skip_buffers( [ @@ -1355,6 +1420,8 @@ def kv_offload_generate( if x.startswith("past_") or x.endswith("_RetainedState") ] ) + if not_mllama: + lang_session.skip_buffers(vision_outputs.keys()) # Get first token lang_inputs["input_ids"] = outputs["logits"].argmax(2) @@ -1930,7 +1997,7 @@ class QEFFAutoModelForImageTextToText: _hf_auto_class = AutoModelForImageTextToText - def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs): + def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs): """ Instantiate the appropriate internal class for single or dual QPC mode. @@ -1951,13 +2018,19 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) The wrapped model instance, configured for either dual or single QPC. """ if kv_offload: - return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) + return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs) else: return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + **kwargs, + ): """ Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path. @@ -1986,18 +2059,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona If `continuous_batching` is provided as True. """ # TODO: add a check to see if kv_offload is allowed for given model by loading the config and checking architecture or type of config here. 
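Under the signatures introduced here, the dual-QPC vision-language flow with continuous batching would be driven roughly as follows (a hedged sketch: the checkpoint id, image paths, and sizes are placeholders, and keyword names follow the signatures in this diff where they are visible):

from transformers import AutoProcessor, AutoTokenizer

from QEfficient import QEFFAutoModelForImageTextToText

model_id = "org/some-vlm-checkpoint"  # placeholder, not a real model id
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
    model_id, kv_offload=True, continuous_batching=True
)
qeff_model.compile(
    prefill_seq_len=128,
    ctx_len=4096,
    num_cores=16,
    full_batch_size=4,  # required once continuous_batching=True
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
exec_info = qeff_model.generate(
    tokenizer=tokenizer,
    processor=processor,
    images=["img_0.jpg", "img_1.jpg"],  # placeholder image paths
    prompts=["Describe the image.", "What is shown here?"],
    generation_len=64,
)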
+ if continuous_batching and not kv_offload: + raise NotImplementedError("Continuous batching is not supported for kv_offload = False") + if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') if kwargs.get("low_cpu_mem_usage", None): logger.warning("Updating low_cpu_mem_usage=False") - if kwargs.pop("continuous_batching", None): - NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + continuous_batching=continuous_batching, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { @@ -2032,12 +2111,19 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, + Mxfp4GptOssExpertDequantizeTransform, CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, KVCacheExternalModuleMapperTransform, ] - _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + _onnx_transforms = [ + FP16ClipTransform, + CustomOpTransform, + RenameFunctionOutputsTransform, + OnnxSlimTransform, + SplitTensorsTransform, + ] def __init__( self, @@ -2285,14 +2371,25 @@ def export(self, export_dir: Optional[str] = None) -> str: for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32)) dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes - output_names.append(f"past_{kv}.{i}_RetainedState") + output_names.append(f"past_{kv}.{i}_InternalRetainedState") else: + # HACK: factor this and the if-branch above into a common helper + pkv_dynamic_axes = ( + self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes + ) + pkv_dynamic_axes = ( + [pkv_dynamic_axes] * self.model.config.num_hidden_layers + if isinstance(pkv_dynamic_axes, dict) + else pkv_dynamic_axes + ) + for i in range(self.num_layers): + pkv_dynamic_axes[i][0] = "full_batch_size" if self.continuous_batching else "batch_size" for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes - output_names.append(f"past_{kv}.{i}_RetainedState") + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] + output_names.append(f"past_{kv}.{i}_InternalRetainedState") if self.continuous_batching: example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) @@ -2425,12 +2522,19 @@ def build_prefill_specialization( Dict[str, Union[int, str]] A dictionary defining the prefill specialization.
""" - spec = { - "batch_size": 1 if self.continuous_batching else batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "num_logits_to_keep": 1 if self.is_tlm else None, - } + if hasattr(self.model, "get_specializations"): + spec = self.model.get_specializations( + batch_size=1 if self.continuous_batching else batch_size, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + )[0] + else: + spec = { + "batch_size": 1 if self.continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + } + spec["num_logits_to_keep"] = 1 if self.is_tlm else None if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -2474,12 +2578,20 @@ def build_decode_specialization( """ if prefill_seq_len == 1 and not self.continuous_batching: return None # Avoid duplication with prefill - spec = { - "batch_size": full_batch_size if self.continuous_batching else batch_size, - "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, - "ctx_len": ctx_len, - "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, - } + + if hasattr(self.model, "get_specializations"): + spec = self.model.get_specializations( + batch_size=full_batch_size if self.continuous_batching else batch_size, + prefill_seq_len=(num_speculative_tokens + 1) if self.is_tlm else 1, + ctx_len=ctx_len, + )[1] + else: + spec = { + "batch_size": full_batch_size if self.continuous_batching else batch_size, + "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, + "ctx_len": ctx_len, + } + spec["num_logits_to_keep"] = (num_speculative_tokens + 1) if self.is_tlm else None if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size @@ -2705,8 +2817,8 @@ def generate( raise TypeError("Please run compile API first!") generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( - tokenizer, - self.qpc_path, + tokenizer=tokenizer, + qpc_path=self.qpc_path, prompt=prompts, device_id=device_id, generation_len=generation_len, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 23ab2ca5f..15410da4d 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -51,6 +51,15 @@ GPTBigCodeForCausalLM, GPTBigCodeModel, ) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRMSNorm, +) from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel from transformers.models.granite.modeling_granite import ( GraniteAttention, @@ -243,6 +252,14 @@ QEffGPTBigCodeForCausalLM, QEffGPTBigCodeModel, ) +from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffGptOssAttention, + QEffGptOssDecoderLayer, + QEffGptOssExperts, + QEffGptOssForCausalLM, + QEffGptOssMLP, + QEffGptOssModel, +) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, QEffGPTJBlock, @@ -417,6 +434,7 @@ class CustomOpsTransform(ModuleMappingTransform): _module_mapping = { GemmaRMSNorm: GemmaCustomRMSNormAIC, Gemma2RMSNorm: GemmaCustomRMSNormAIC, + GptOssRMSNorm: CustomRMSNormAIC, LlamaRMSNorm: CustomRMSNormAIC, Llama4TextRMSNorm: CustomRMSNormAIC, MistralRMSNorm: CustomRMSNormAIC, @@ -502,6 +520,13 @@ class KVCacheTransform(ModuleMappingTransform): Gemma3TextModel: QEffGemma3TextModel, Gemma3ForCausalLM: 
QEffGemma3ForCausalLMModel, Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, + # GPT_OSS + GptOssAttention: QEffGptOssAttention, + GptOssDecoderLayer: QEffGptOssDecoderLayer, + GptOssModel: QEffGptOssModel, + GptOssForCausalLM: QEffGptOssForCausalLM, + GptOssMLP: QEffGptOssMLP, + GptOssExperts: QEffGptOssExperts, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM, @@ -796,3 +821,46 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu model = PooledModel(model, pooling_method) warnings.warn("Pooling is applied to the model.") return model, transformed + + +# def get_decoder_layer_classes_for_export(model: nn.Module) -> set: +# """ +# Dynamically determine which DecoderLayer classes should be exported as functions +# based on the model's architecture using the existing KVCacheTransform mapping. +# """ +# # Define patterns that identify decoder layer classes +# DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"] + +# # Get all QEff classes that are decoder layers from the existing mapping +# decoder_layer_classes = set() + +# for original_class, qeff_class in KVCacheTransform._module_mapping.items(): +# # Check if the QEff class name contains decoder layer patterns +# qeff_class_name = qeff_class.__name__ +# if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS): +# decoder_layer_classes.add(qeff_class) + +# # Filter to only include classes that are actually used in the current model +# model_decoder_classes = set() +# for module in model.modules(): +# if module.__class__ in decoder_layer_classes: +# model_decoder_classes.add(module.__class__) + +# return model_decoder_classes + + +def get_decoder_layer_classes_for_export(model: nn.Module) -> set: + """ + Dynamically determine which DecoderLayer classes should be exported as functions + based on the model's architecture using the existing KVCacheTransform mapping. + """ + DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block"] + + decoder_layer_classes = set() + + for module in model.modules(): + class_name = module.__class__.__name__ + if any(pattern in class_name for pattern in DECODER_LAYER_PATTERNS): + decoder_layer_classes.add(module.__class__) + + return decoder_layer_classes diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index e5e842e6f..445c15583 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import math +import os from typing import Callable, List, Optional, Tuple, Union import torch @@ -360,7 +361,7 @@ def forward(self, x, seq_len=None): ) -def eager_attention_forward( +def eager_attention_forward_q_blocked( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -368,22 +369,107 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], **kwargs, ): + """ + Q-blocked attention for Qwen2.5-VL. + Blocks only the query SL dimension. 
+ + Args: + query: (BS, NH, Q_LEN, DH) + key: (BS, NH_KV, KV_LEN, DH) + value: (BS, NH_KV, KV_LEN, DH) + attention_mask: (BS, NH, Q_LEN, KV_LEN) or broadcastable + """ + BS, NH, Q_LEN, DH = query.shape + _, _, KV_LEN, _ = key.shape + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) + target_blocks_q = int(os.environ.get("num_q_blocks", Q_LEN)) + q_block_positions = [(i * Q_LEN) // target_blocks_q for i in range(target_blocks_q)] + scaling = 1.0 / math.sqrt(module.head_dim) + + q_output_blocks = [] + q_attn_weights_blocks = [] - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # Process each Q block + for q_block_idx in range(target_blocks_q): + qi = q_block_positions[q_block_idx] + + # Calculate Q block size + if q_block_idx == target_blocks_q - 1: + real_q_len = Q_LEN - qi + else: + real_q_len = q_block_positions[q_block_idx + 1] - qi + + # Extract Q block + q_block = query[:, :, qi : qi + real_q_len, :] + attn_mask_block = None + if attention_mask is not None: + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + + # Compute attention scores for this Q block + attn_weights = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + if attn_mask_block is not None: + attn_weights = torch.where( + attn_mask_block, + torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=attn_weights.device), + attn_weights, + ) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # Compute output for this Q block + output_block = torch.matmul(attn_weights, value_states) + + q_output_blocks.append(output_block) + q_attn_weights_blocks.append(attn_weights) + + attn_output = torch.cat(q_output_blocks, dim=2) attn_output = attn_output.transpose(1, 2).contiguous() + # Concatenate attention weights + attn_weights = torch.cat(q_attn_weights_blocks, dim=2) + return attn_output, attn_weights +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + **kwargs, +): + """ + Wrapper that routes to blocked or default attention based on environment variable. + """ + blocking_mode = os.environ.get("ATTENTION_BLOCKING_MODE", "default").lower() + + if blocking_mode == "q": + return eager_attention_forward_q_blocked(module, query, key, value, attention_mask, **kwargs) + elif blocking_mode == "default": + # Original implementation + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + else: + raise ValueError(f"Invalid ATTENTION_BLOCKING_MODE: {blocking_mode}. 
Must be 'q' or 'default'") + + class QEffQwen2_5_VLAttention(Qwen2_5_VLAttention): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -680,7 +766,15 @@ def __init__(self, model): self.model = model self.language_model = self.model.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_id @@ -691,7 +785,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.model.model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) @@ -709,7 +807,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -745,10 +843,14 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): .repeat(4, 1, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( - config=self.model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -757,6 +859,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -775,7 +880,11 @@ def get_specializations( img_size: None, height: int = None, width: int = None, + num_frames: int = 1, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): if height is None or width is None: @@ -844,6 +953,7 @@ def smart_resize( grid_height = grid_h * grid_w grid_width = patch_size * patch_size * temporal_patch_size * channel vision_size = grid_height // 4 + vision_size = vision_size * num_frames grid_height = grid_height * batch_size vision = [ @@ -856,20 +966,37 @@ def smart_resize( "grid_w": grid_w, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - 
"seq_len": "1", - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - ] + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -878,9 +1005,11 @@ def smart_resize( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -892,12 +1021,21 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {1: "batch_size", 2: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} dynamic_axes = {} diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py index d647b73a6..dfadc00ef 100644 --- a/QEfficient/transformers/quantizers/__init__.py +++ b/QEfficient/transformers/quantizers/__init__.py @@ -4,3 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers + +__all__ = ["replace_transformers_quantizers"] diff --git a/QEfficient/transformers/quantizers/auto.py b/QEfficient/transformers/quantizers/auto.py index ba204e419..d73909211 100644 --- a/QEfficient/transformers/quantizers/auto.py +++ b/QEfficient/transformers/quantizers/auto.py @@ -11,7 +11,8 @@ from transformers.quantizers.quantizer_awq import AwqQuantizer from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer from transformers.quantizers.quantizer_gptq import GptqHfQuantizer -from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer +from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig, Mxfp4Config from QEfficient.transformers.quantizers.quantizer_awq import 
QEffAwqConfig, QEffAwqQuantizer from QEfficient.transformers.quantizers.quantizer_compressed_tensors import ( @@ -21,30 +22,35 @@ QEffFP8Quantizer, ) from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer +from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4Config, QEffMxfp4HfQuantizer QEFF_AUTO_QUANTIZER_MAPPING = { "awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer, "compressed-tensors": QEffCompressedTensorsFP8Quantizer, "fp8": QEffFP8Quantizer, + "mxfp4": QEffMxfp4HfQuantizer, } QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": QEffAwqConfig, "gptq": QEffGPTQConfig, "compressed-tensors": QEffCompressedTensorsConfig, "fp8": QEffFP8Config, + "mxfp4": QEffMxfp4Config, } DUPLICATE_AUTO_QUANTIZER_MAPPING = { "awq": AwqQuantizer, "gptq": GptqHfQuantizer, "compressed-tensors": CompressedTensorsHfQuantizer, "fp8": None, + "mxfp4": Mxfp4HfQuantizer, } DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": AwqConfig, "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, "fp8": None, + "mxfp4": Mxfp4Config, } diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py index 0427bca37..69d6380f0 100644 --- a/QEfficient/transformers/quantizers/quant_transforms.py +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -7,13 +7,19 @@ import torch from torch import nn +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts from QEfficient.base.pytorch_transforms import ModuleMutatorTransform from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear -from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq, unpack_weights +from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts +from QEfficient.transformers.quantizers.quantizer_utils import ( + convert_moe_packed_tensors, + dequantize_gptq, + unpack_weights, +) class AwqToMatmulNbitsTransform(ModuleMutatorTransform): @@ -115,3 +121,28 @@ def mutate(cls, original_module, parent_module): if original_module.bias is not None: dequant_linear_layer.bias = torch.nn.Parameter(original_module.bias.float()) return dequant_linear_layer + + +class Mxfp4GptOssExpertDequantizeTransform(ModuleMutatorTransform): + """ + Used to dequantize the weights of an Mxfp4GptOssExpert module and replace with transformers GptOssExperts with dequantized weights + """ + + _match_class = QEffMxfp4GptOssExperts + + @classmethod + def mutate(cls, original_module, parent_module): + dequant_module = GptOssExperts(original_module.config) + dequant_module.gate_up_proj = torch.nn.Parameter( + convert_moe_packed_tensors( + original_module.gate_up_proj_blocks, original_module.gate_up_proj_scales, dtype=torch.float32 + ) + ) + dequant_module.down_proj = torch.nn.Parameter( + convert_moe_packed_tensors( + original_module.down_proj_blocks, original_module.down_proj_scales, dtype=torch.float32 + ) + ) + dequant_module.gate_up_proj_bias = original_module.gate_up_proj_bias + dequant_module.down_proj_bias = original_module.down_proj_bias + return dequant_module diff --git a/QEfficient/transformers/quantizers/quantizer_mxfp4.py b/QEfficient/transformers/quantizers/quantizer_mxfp4.py new file mode 100644 index 000000000..2ffba1bea --- /dev/null +++ 
b/QEfficient/transformers/quantizers/quantizer_mxfp4.py @@ -0,0 +1,155 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import re +from typing import Optional + +import torch +import torch.nn as nn +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer +from transformers.utils.quantization_config import Mxfp4Config + +from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors, get_keys_to_not_convert +from QEfficient.utils.logging_utils import logger + + +class QEffMxfp4GptOssExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + + self.gate_up_proj_blocks = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False + ) + + self.down_proj_blocks = nn.Parameter( + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False + ) + self.alpha = 1.702 + self.limit = 7.0 + + self.gate_up_proj_precision_config = None + self.down_proj_precision_config = None + + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + gate_up_proj = convert_moe_packed_tensors( + self.gate_up_proj_blocks, self.gate_up_proj_scales, dtype=torch.float32 + ) + down_proj = convert_moe_packed_tensors(self.down_proj_blocks, self.down_proj_scales, dtype=torch.float32) + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[1] + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + + +def should_convert_module(current_key_name, patterns): + current_key_name_str = ".".join(current_key_name) + if not any( + re.match(f"{key}\\.", 
current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns + ): + return True + return False + + +class QEffMxfp4Config(Mxfp4Config): + """ + Currently there is not need to change the implementation of Mxfp4Config + This is placeholder for future when we would want to change this + """ + + pass + + +class QEffMxfp4HfQuantizer(Mxfp4HfQuantizer): + def validate_environment(self, *args, **kwargs): + return True + + def update_torch_dtype(self, torch_dtype): + if torch_dtype not in [None, torch.float32]: + logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") + return None + + def _process_model_before_weight_loading( + self, + model: torch.nn.Module, + keep_in_fp32_modules: Optional[list[str]] = None, + **kwargs, + ): + self.modules_to_not_convert = get_keys_to_not_convert(model) + self.modules_to_not_convert = ( + ["lm_head"] if self.modules_to_not_convert is None else self.modules_to_not_convert + ) + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + self.modules_to_not_convert = list(set(self.modules_to_not_convert)) + config = model.config + + # -- Defining local method as it uses lot of local variables -- + def _replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + ): + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + if not should_convert_module(current_key_name, modules_to_not_convert): + current_key_name.pop(-1) + continue + if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: + model._modules[name] = QEffMxfp4GptOssExperts(config) + has_been_replaced = True + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_mxfp4_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + current_key_name.pop(-1) + return model, has_been_replaced + + _replace_with_mxfp4_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config diff --git a/QEfficient/transformers/quantizers/quantizer_utils.py b/QEfficient/transformers/quantizers/quantizer_utils.py index a318fb8e4..424692d08 100644 --- a/QEfficient/transformers/quantizers/quantizer_utils.py +++ b/QEfficient/transformers/quantizers/quantizer_utils.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import copy +import math import torch from torch import nn @@ -378,3 +379,70 @@ def repack_zeros(qzeros, bits): break qzeros = qzeros.T return qzeros + + +FP4_VALUES = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def convert_moe_packed_tensors( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + """ + reference for this function is taken from: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gpt_oss#L98 + """ + + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + 
blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + out = out.to(dtype).permute(0, 2, 1).contiguous() + return out diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index e487d4af4..49f0ad30b 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -10,6 +10,7 @@ undo_transformers_quantizers, ) from QEfficient.utils._utils import ( # noqa: F401 + LRUCache, check_and_assign_cache_dir, create_json, create_model_params, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index abe383556..d58f54952 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -33,6 +33,36 @@ from QEfficient.utils.logging_utils import logger +class LRUCache: + """Simple LRU cache with size limit for vision outputs""" + + def __init__(self, max_size=100): + self._cache = {} + self._access_order = [] + self._max_size = max_size + + def get(self, key): + if key in self._cache: + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + return None + + def put(self, key, value): + if key in self._cache: + self._access_order.remove(key) + elif len(self._cache) >= self._max_size: + oldest = self._access_order.pop(0) + del self._cache[oldest] + + self._cache[key] = value + self._access_order.append(key) + + def clear(self): + self._cache.clear() + self._access_order.clear() + + class DownloadRetryLimitExceeded(Exception): """ Used for raising error when hf_download fails to download the model after given max_retries. 
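# --- Editor's note: minimal usage sketch for the new LRUCache helper above (not part of the
# patch). The keys/values are hypothetical placeholders for hashed image inputs and cached
# vision-encoder outputs; the import is supported by the QEfficient/utils/__init__.py change shown.
from QEfficient.utils import LRUCache

cache = LRUCache(max_size=2)
cache.put("img_a", [0.1, 0.2])   # hypothetical image-hash key -> cached vision output
cache.put("img_b", [0.3, 0.4])
cache.get("img_a")               # access marks "img_a" as most recently used
cache.put("img_c", [0.5, 0.6])   # evicts "img_b", the least recently used entry
assert cache.get("img_b") is None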
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 5f7a4db7b..5aab14520 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -17,7 +17,7 @@ ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 17 ONNX_EXPORT_MAX_NUM_IMAGES = 1 ONNX_EXPORT_MAX_IMAGE_TILES = 4 ONNX_EXPORT_IMAGE_WIDTH = 560 @@ -84,7 +84,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 17 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"] DEFAULT_AIC_HW_VERSION = "ai100" diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index eb1f7c8e6..7d07db530 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -87,13 +87,20 @@ def prepare_pytorch_inputs(self): if self.full_batch_size: inputs["input_ids"] = input_ids - inputs["position_ids"] = torch.arange(input_len).view(1, input_len) - inputs["batch_index"] = torch.arange(1).view(-1, 1) + inputs["position_ids"] = position_ids + inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) past_key_values = [] for i in range(self.n_layer): - past_key = torch.zeros((self.padding_shape), dtype=torch.float32) - past_value = torch.zeros((self.padding_shape), dtype=torch.float32) + if ( + all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) + and self.config.layer_types[i] == "sliding_attention" + ): + pad_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] + else: + pad_shape = self.padding_shape + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) inputs["past_key_values"] = tuple(past_key_values) @@ -113,18 +120,15 @@ def update_pytorch_inputs(self, inputs, pt_outputs): """ updated_inputs = {} if self.full_batch_size: - batch_index = torch.arange(1).view(-1, 1) - input_ids = pt_outputs.logits.detach().argmax(2) updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id) - updated_inputs["input_ids"][batch_index.view(-1)] = input_ids + updated_inputs["input_ids"][inputs["batch_index"].view(-1)] = input_ids position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1 updated_inputs["position_ids"] = torch.full((self.full_batch_size, 1), 0) - updated_inputs["position_ids"][batch_index.view(-1)] = position_ids - - updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) + updated_inputs["position_ids"][inputs["batch_index"].view(-1)] = position_ids + updated_inputs["batch_index"] = inputs["batch_index"] else: updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1) updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 @@ -169,8 +173,17 @@ def prepare_ort_inputs(self): inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32) else: for i in range(self.n_layer): - inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32) - inputs["past_value." 
+ str(i)] = np.zeros((self.padding_shape), dtype=np.float32) + if ( + all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) + and self.config.layer_types[i] == "sliding_attention" + ): + pad_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] + else: + pad_shape = self.padding_shape + inputs["past_key." + str(i)] = np.zeros((pad_shape), dtype=np.float32) + inputs["past_value." + str(i)] = np.zeros((pad_shape), dtype=np.float32) + if self.full_batch_size: + inputs["batch_index"] = np.arange(self.full_batch_size).reshape(-1, 1) return inputs def update_ort_inputs(self, inputs, ort_outputs): @@ -191,7 +204,8 @@ def update_ort_inputs(self, inputs, ort_outputs): for i in range(self.n_layer): updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2] updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1] - + if self.full_batch_size: + updated_inputs["batch_index"] = inputs["batch_index"] return updated_inputs def update_ort_outputs(self, ort_outputs): diff --git a/QEfficient/utils/patches.py b/QEfficient/utils/patches.py new file mode 100644 index 000000000..a652bbb2a --- /dev/null +++ b/QEfficient/utils/patches.py @@ -0,0 +1,120 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""Monkey patches for torch.onnx.utils to fix ONNX export issues.""" + +from typing import Collection, Set, Type, Union + +import torch +import torch.onnx.utils as onnx_utils +from torch import _C + + +def _setup_trace_module_map_patched( + model: Union[torch.nn.Module, torch.jit.ScriptModule], + export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]], +) -> Set[str]: + """Patched version of _setup_trace_module_map that fixes onnx_attrs type mismatch.""" + + def __register_attribute_hook(): + attr_name = "_onnx_attrs" + + def _track_module_attributes_forward_pre_hook(module, input): + setattr(module, attr_name, _get_module_attributes(module)) + + def _track_module_attributes_forward_hook(module, input, output): + tracing_state = _C._get_tracing_state() + if not tracing_state: + return + graph = tracing_state.graph() + onnx_attrs = {} + if hasattr(module, attr_name): + onnx_attrs = getattr(module, attr_name) + delattr(module, attr_name) + # FIX: use empty dict to avoid type mismatch with _jit_pass_onnx_track_scope_attributes + # Observed in transformers v4.55 and above + onnx_attrs = {} + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + + for m in model.modules(): + m.register_forward_hook(_track_module_attributes_forward_hook) + m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) + + def _unqualified_variable_name(qualified_name: str) -> str: + """ + Parse qualified variable name and return the unqualified version. + Pure numeric atoms are considered inadequate, so this function will look past them, + and start from the first non-numeric atom. 
+ """ + name_atoms = qualified_name.split(".") + for i, atom in reversed(list(enumerate(name_atoms))): + if not atom.isnumeric(): + return ".".join(name_atoms[i:]) + return qualified_name + + trace_module_map = { + _m: torch._C._jit_onnx_create_full_scope_name(torch.typename(type(_m)), _unqualified_variable_name(_n)) + for _n, _m in model.named_modules() + } + torch.jit._trace._trace_module_map = trace_module_map + + if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: + module_typenames = {torch.typename(type(module)) for module in trace_module_map} + elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: + + def _find_typename(v): + if isinstance(v, type): + return torch.typename(v) + else: + raise RuntimeError( + "Only type of the `nn.Module` should be " + "passed in the set for argument `export_modules_as_functions`. " + f"Got `{type(v).__name__}`." + ) + + module_typenames = {_find_typename(v) for v in export_modules_as_functions} + else: + module_typenames = set() + + if module_typenames: + __register_attribute_hook() + + return module_typenames + + +def _get_module_attributes(module): + """Helper function to get module attributes safely.""" + import typing + + annotations = typing.get_type_hints(type(module)) + base_m_annotations = typing.get_type_hints(torch.nn.Module) + [annotations.pop(k, None) for k in base_m_annotations] + + attrs = {} + for k in annotations: + try: + attrs[k] = getattr(module, k) + except AttributeError: + _C._jit_onnx_log(f"Skipping module attribute '{k}'") + continue + return attrs + + +def apply_torch_patches(): + """Apply all necessary torch patches for ONNX export.""" + # Monkey patch the function + onnx_utils._setup_trace_module_map = _setup_trace_module_map_patched + + if hasattr(onnx_utils, "_get_module_attributes"): + onnx_utils._get_module_attributes = _get_module_attributes + + print("Applied torch ONNX export patches for export_modules_as_functions compatibility") + + +def is_patched(): + """Check if patches have been applied.""" + return onnx_utils._setup_trace_module_map == _setup_trace_module_map_patched diff --git a/examples/gpt_oss.py b/examples/gpt_oss.py new file mode 100644 index 000000000..24d050e97 --- /dev/null +++ b/examples/gpt_oss.py @@ -0,0 +1,35 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +onnx_model_path = qeff_model.export() +qpc_path = qeff_model.compile( + prefill_seq_len=1, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. + ctx_len=256, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=8, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + prompts="Who is your creator? 
and What all you are allowed to do?", + device_id=[0, 1, 2, 3], +) diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py new file mode 100644 index 000000000..f285ea278 --- /dev/null +++ b/examples/llama4_CB_example_vision_lang.py @@ -0,0 +1,93 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +continious_batching = False +if continious_batching: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) +else: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + device_ids=[0, 1, 2, 3], + generation_len=100, +) + +# print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) diff --git a/examples/qwen2_5_vl_CB.py b/examples/qwen2_5_vl_CB.py new file mode 100644 index 000000000..96ef4898a --- /dev/null +++ b/examples/qwen2_5_vl_CB.py @@ -0,0 +1,72 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# If we want to enable QBlocking Run below command:, default is without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +config.text_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_urls = [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) diff --git a/examples/qwen2_5_vl_example.py b/examples/qwen2_5_vl_example.py index 374f70ad2..d5d943c9c 100644 --- a/examples/qwen2_5_vl_example.py +++ b/examples/qwen2_5_vl_example.py @@ -5,6 +5,9 @@ # # ----------------------------------------------------------------------------- +# If we want to enable QBlocking Run below command:, default is without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py + import requests import transformers from PIL import Image diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 86bce4441..321a466ab 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -25,6 +25,7 @@ from QEfficient.utils.test_utils import ModelConfig test_models_causal = [ + "openai/gpt-oss-20b", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "gpt2", "Salesforce/codegen-350M-mono", @@ -76,11 +77,11 @@ def get_custom_n_layers(model_name): :return n_layer """ - if model_name in {"microsoft/Phi-3-mini-4k-instruct", "neuralmagic/Qwen2-0.5B-Instruct-FP8"}: + if model_name in {"microsoft/Phi-3-mini-4k-instruct", "neuralmagic/Qwen2-0.5B-Instruct-FP8", "openai/gpt-oss-20b"}: return 2 elif model_name in ModelConfig.SWIFTKV_MODELS: return None - return 16 + return 1 def load_causal_lm_model(model_name, n_layer=1, config=None): @@ -157,6 +158,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( """ replace_transformers_quantizers() if config is None: + n_layer = get_custom_n_layers(model_name) model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer) else: model_hf, _ = 
load_causal_lm_model(model_name, config=config) diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index bdc15519e..0810ac6ba 100644 --- a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -33,6 +33,7 @@ ("starcoder2", 256, 2, 4, 128, 512, 127, {}), ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ @@ -177,12 +178,23 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size", 2: "ctx_len", } + pkv_dynamic_axes = ( + qeff_model.model.get_pkv_dynamic_axes() + if hasattr(qeff_model.model, "get_pkv_dynamic_axes") + else pkv_dynamic_axes + ) + pkv_dynamic_axes = ( + [pkv_dynamic_axes] * qeff_model.model.config.num_hidden_layers + if isinstance(pkv_dynamic_axes, dict) + else pkv_dynamic_axes + ) output_names = [] output_names.append("logits") for i in range(qeff_model.num_layers): + pkv_dynamic_axes[i][0] = "full_batch_size" if qeff_model.continuous_batching else "batch_size" for kv in ["key", "value"]: - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] output_names.append(f"past_{kv}.{i}_RetainedState") if qeff_model.continuous_batching:
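# --- Editor's note: hedged sketch (illustrative values only, not part of the patch) of the
# per-layer KV padding-shape rule that the generate_inputs.py and test changes above rely on:
# layers marked "sliding_attention" in config.layer_types get a cache axis of
# config.sliding_window, while all other layers keep the full ctx_len.
padding_shape = [1, 2, 512, 64]  # [batch, n_kv_heads, ctx_len, head_dim] (placeholder values)
layer_types = ["full_attention", "sliding_attention", "full_attention"]
sliding_window = 128

per_layer_kv_shapes = [
    padding_shape[:2] + [sliding_window] + [padding_shape[-1]]
    if layer_type == "sliding_attention"
    else padding_shape
    for layer_type in layer_types
]
# per_layer_kv_shapes == [[1, 2, 512, 64], [1, 2, 128, 64], [1, 2, 512, 64]]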