diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index 33c6f5588..4989e6784 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -11,6 +11,7 @@
 import QEfficient.utils.model_registery  # noqa: F401
 from QEfficient.utils import custom_format_warning
 from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.patches import apply_torch_patches, is_patched
 
 # For faster downloads via hf_transfer
 # This code is put above import statements as this needs to be executed before
@@ -18,10 +19,13 @@
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 # Placeholder for all non-transformer models registered in QEfficient
-
 # custom warning for the better logging experience
 warnings.formatwarning = custom_format_warning
 
+# Apply patches
+# TODO: Find a better way to do this; this is a temporary fix.
+apply_torch_patches()
+
 
 def check_qaic_sdk():
     """Check if QAIC SDK is installed"""
@@ -70,6 +74,8 @@ def check_qaic_sdk():
         "QEFFAutoModelForImageTextToText",
         "QEFFAutoModelForSpeechSeq2Seq",
         "QEFFCommonLoader",
+        "apply_torch_patches",
+        "is_patched",
     ]
 
 else:
diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index 6ecbf0fc0..3364ecbc3 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -18,10 +18,13 @@
 import onnx
 import torch
 
-from QEfficient.base.onnx_transforms import OnnxTransform
+from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
+from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
+from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
 from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
 from QEfficient.utils import (
     constants,
     create_json,
@@ -243,6 +246,11 @@ def _export(
             input_names.append(param)
 
         try:
+            # Register the built-in custom ops so CustomOpTransform can emit their function protos
+            CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
+            CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter)
+            CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)
+            decoder_layer_classes = get_decoder_layer_classes_for_export(self.model)
             export_kwargs = {} if export_kwargs is None else export_kwargs
             torch.onnx.export(
                 self.model,
@@ -252,6 +260,8 @@
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
                 opset_version=constants.ONNX_EXPORT_OPSET,
+                export_modules_as_functions=decoder_layer_classes,
+                do_constant_folding=True,
                 **export_kwargs,
             )
             logger.info("PyTorch export successful")
@@ -261,6 +271,7 @@
             model = onnx.load(tmp_onnx_path, load_external_data=False)
             transform_kwargs = {
                 "onnx_base_dir": str(tmp_onnx_dir),
+                "temp_onnx_path": tmp_onnx_path,
                 "model_name": self.model_name,
             }
             if onnx_transform_kwargs is not None:
diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py
index 61b5c00f6..65287426a 100644
--- a/QEfficient/base/onnx_transforms.py
+++ b/QEfficient/base/onnx_transforms.py
@@ -5,9 +5,12 @@
 #
 # ----------------------------------------------------------------------------
 
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
+import onnx
+import onnxslim
+import torch
 from onnx import ModelProto, external_data_helper, numpy_helper
@@ -99,3 +102,167 @@ def apply(
                     current_file_size = tsize
                 external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
         return model, transformed
+
+
+class OnnxSlimTransform(OnnxTransform):
+    """
+    Applies onnx-slim transformations on the given ONNX graph.
+    """
+
+    @classmethod
+    def apply(
+        cls,
+        model: ModelProto,
+        *,
+        onnx_base_dir: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[ModelProto, bool]:
+        """
+        :param enable_onnx_slim_transform: If True, applies onnx-slim transformations.
+        :param temp_onnx_path: Path to save the slimmed ONNX model.
+        """
+        transformed = False
+        onnx_slim_transform = kwargs.get("enable_onnx_slim_transform", True)
+        temp_onnx_path = kwargs.get("temp_onnx_path", None)
+        if not temp_onnx_path:
+            raise RuntimeError("temp_onnx_path is required for the onnx-slim transform.")
+        if onnx_slim_transform:
+            transformed = True
+            slimmed_model = onnxslim.slim(model)
+            onnx.save(slimmed_model, temp_onnx_path)
+            return slimmed_model, transformed
+        return model, transformed
+
+
+class CustomOpTransform(OnnxTransform):
+    """
+    Transform to register custom operations and add their function protos to the ONNX model.
+    """
+
+    # Registry of custom operations
+    _custom_ops: Dict[str, Tuple[Any, Any]] = {}  # op_name -> (func_class, onnxscript_func)
+
+    @classmethod
+    def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
+        """Register a custom operation."""
+        cls._custom_ops[op_name] = (func_class, onnxscript_func)
+
+    @classmethod
+    def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Apply custom op registration and add function protos to the model.
+
+        :param model: The ONNX model to transform
+        :param opset_version: ONNX opset version for symbolic registration
+        :returns: Transformed model and success flag
+        """
+        transformed = False
+
+        # Register all custom op symbolic functions with torch.onnx
+        for op_name, (func_class, _) in cls._custom_ops.items():
+            if hasattr(func_class, "symbolic"):
+                torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)
+
+        # Add function protos for custom ops that are used in the model
+        used_protos = cls._get_function_protos_for_model(model)
+
+        for proto in used_protos:
+            # Check if proto already exists to avoid duplicates
+            proto_name = proto.name
+            if not any(func.name == proto_name for func in model.functions):
+                model.functions.append(proto)
+                transformed = True
+
+        return model, transformed
+
+    @classmethod
+    def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
+        """Get function protos for custom ops that are actually used in the model."""
+        used_protos = []
+
+        # Get all node op_types in the model
+        used_op_types = set()
+        for node in model.graph.node:
+            used_op_types.add(node.op_type)
+
+        # Also check function calls
+        for func in model.functions:
+            for node in func.node:
+                used_op_types.add(node.op_type)
+
+        # Check which custom ops are actually used
+        for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
+            # Check if the custom op is referenced in the model
+            if cls._is_custom_op_used(model, op_name, used_op_types):
+                proto = onnxscript_func.to_function_proto()
+                used_protos.append(proto)
+
+        return used_protos
+
+    @classmethod
+    def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
+        """Check if a custom op is used in the model."""
+        # Check if the op_name appears in node op_types
+        if op_name in used_op_types:
+            return True
+
+        # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
+        custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
+        if custom_op_pattern in used_op_types:
+            return True
+
+        # Heuristic checks based on op type
+        if "RMSNorm" in op_name:
+            # Check if any RMSNorm-related ops are present
+            return any("RMSNorm" in op_type for op_type in used_op_types)
+
+        if "Ctx" in op_name:
+            # Check if Gather/Scatter operations are present (indicating KV cache usage)
+            return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)
+
+        return False
+
+
+class RenameFunctionOutputsTransform(OnnxTransform):
+    """
+    Renames decoder-layer function outputs, rewriting '_InternalRetainedState' outputs to 'past_<key|value>.<i>_RetainedState'.
+    """
+
+    @classmethod
+    def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Rename function outputs in decoder layer nodes.
+
+        :param model: The ONNX model to transform
+        :returns: Transformed model and boolean indicating whether transform was applied
+        """
+        graph = model.graph
+        op_type_to_func_map = {func.name: func for func in model.functions}
+        decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
+        transformed = False
+        model_graph_outputs = [val.name for val in model.graph.output]
+        layer_index = 0
+        for node in graph.node:
+            if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
+                func = op_type_to_func_map.get(node.op_type)
+                if func is None:
+                    continue
+
+                for i, out_name in enumerate(func.output):
+                    if "_InternalRetainedState" not in out_name:
+                        continue
+                    if "key" in out_name:
+                        new_name = f"past_key.{layer_index}_RetainedState"
+                    elif "value" in out_name:
+                        new_name = f"past_value.{layer_index}_RetainedState"
+                    else:
+                        # Neither a key nor a value cache output; leave it unchanged
+                        continue
+                    transformed = True
+                    old_name = node.output[i]
+                    node.output[i] = new_name
+                    # Update the matching graph output name if it exists
+                    if old_name in model_graph_outputs:
+                        model.graph.output[model_graph_outputs.index(old_name)].name = new_name
+                layer_index += 1
+        return model, transformed
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 853567be9..c713d4ee9 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -6,6 +6,7 @@
 #
 # -----------------------------------------------------------------------------
 
+import os
 from collections.abc import Iterable
 from typing import Any, Dict, List, Optional, Tuple
 
@@ -24,6 +25,30 @@
 )
 
 
+def _get_invalid_idx_value():
+    """
+    Get the appropriate invalid index value for CtxGather operations.
+
+    For ONNX export with functions, we use 0 to avoid INT32_MAX constants
+    that cause issues when functions are inlined at runtime.
+
+    Returns:
+        int: 0 for ONNX-function export and at runtime; INT32_MAX for regular ONNX export
+    """
+    if torch.onnx.is_in_onnx_export():
+        # Check if ONNX functions are being used
+        use_onnx_functions = os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() == "true"
+        if use_onnx_functions:
+            # For ONNX functions: use 0 to avoid function inlining issues
+            return 0
+        else:
+            # For regular ONNX export: use INT32_MAX as before
+            return torch.iinfo(torch.int32).max
+    else:
+        # For runtime: use 0
+        return 0
+
+
 class QEffDynamicLayer(DynamicLayer):
     def read_only(self, cache_kwargs):
         """
@@ -45,10 +70,7 @@ def read_only(self, cache_kwargs):
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
 
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = _get_invalid_idx_value()
 
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
@@ -142,10 +164,7 @@ def update(
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = _get_invalid_idx_value()
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
 
         if batch_index is not None:
@@ -418,10 +437,7 @@ def update(
         ctx_indices = torch.arange(ctx_len)[None, None, ...]
         gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = _get_invalid_idx_value()
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
 
         all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
index 20b7036fd..a611de0df 100644
--- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
+++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -80,7 +80,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-        if hasattr(config, "rope_scaling") and "factor" in config.rope_scaling:
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:
             factor = config.rope_scaling["factor"]
             inv_freq /= factor
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 60f60c768..9da0e183c 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -27,7 +27,13 @@
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.base.onnx_transforms import (
+    CustomOpTransform,
+    FP16ClipTransform,
+    OnnxSlimTransform,
+    RenameFunctionOutputsTransform,
+    SplitTensorsTransform,
+)
 from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.generation.text_generation_inference import (
@@ -2111,7 +2117,13 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
         SplitGateUpWeightsTransform,
         KVCacheExternalModuleMapperTransform,
     ]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [
+        FP16ClipTransform,
+        CustomOpTransform,
+        RenameFunctionOutputsTransform,
+        OnnxSlimTransform,
+        SplitTensorsTransform,
+    ]
 
     def __init__(
         self,
@@ -2359,7 +2371,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
                 for kv in ["key", "value"]:
                     example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
                     dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
-                    output_names.append(f"past_{kv}.{i}_RetainedState")
+                    output_names.append(f"past_{kv}.{i}_InternalRetainedState")
         else:
 
             # HACK: create common function for this including above if condition code
@@ -2377,7 +2389,7 @@
                 for kv in ["key", "value"]:
                     example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
                     dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
-                    output_names.append(f"past_{kv}.{i}_RetainedState")
+                    output_names.append(f"past_{kv}.{i}_InternalRetainedState")
 
         if self.continuous_batching:
             example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 773ce178c..62a873b9e 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -821,3 +821,29 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu
         model = PooledModel(model, pooling_method)
         warnings.warn("Pooling is applied to the model.")
         return model, transformed
+
+
+def get_decoder_layer_classes_for_export(model: nn.Module) -> set:
+    """
+    Dynamically determine which DecoderLayer classes should be exported as functions
+    based on the model's architecture, using the existing KVCacheTransform mapping.
+    """
+    # Define patterns that identify decoder layer classes
+    DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]
+
+    # Get all QEff classes that are decoder layers from the existing mapping
+    decoder_layer_classes = set()
+
+    for original_class, qeff_class in KVCacheTransform._module_mapping.items():
+        # Check if the QEff class name contains decoder layer patterns
+        qeff_class_name = qeff_class.__name__
+        if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS):
+            decoder_layer_classes.add(qeff_class)
+
+    # Filter to only include classes that are actually used in the current model
+    model_decoder_classes = set()
+    for module in model.modules():
+        if module.__class__ in decoder_layer_classes:
+            model_decoder_classes.add(module.__class__)
+
+    return model_decoder_classes
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 5f7a4db7b..5aab14520 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -17,7 +17,7 @@
 ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32
 ONNX_EXPORT_EXAMPLE_FBS = 4
 ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
-ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_OPSET = 17
 ONNX_EXPORT_MAX_NUM_IMAGES = 1
 ONNX_EXPORT_MAX_IMAGE_TILES = 4
 ONNX_EXPORT_IMAGE_WIDTH = 560
@@ -84,7 +84,7 @@ def get_models_dir():
 ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512
 ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80
 ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99
-ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_OPSET = 17
 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"]
 DEFAULT_AIC_HW_VERSION = "ai100"
diff --git a/QEfficient/utils/patches.py b/QEfficient/utils/patches.py
new file mode 100644
index 000000000..a652bbb2a
--- /dev/null
+++ b/QEfficient/utils/patches.py
@@ -0,0 +1,120 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""Monkey patches for torch.onnx.utils to fix ONNX export issues."""
+
+from typing import Collection, Set, Type, Union
+
+import torch
+import torch.onnx.utils as onnx_utils
+from torch import _C
+
+from QEfficient.utils.logging_utils import logger
+
+
+def _setup_trace_module_map_patched(
+    model: Union[torch.nn.Module, torch.jit.ScriptModule],
+    export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]],
+) -> Set[str]:
+    """Patched version of _setup_trace_module_map that fixes an onnx_attrs type mismatch."""
+
+    def __register_attribute_hook():
+        attr_name = "_onnx_attrs"
+
+        def _track_module_attributes_forward_pre_hook(module, input):
+            setattr(module, attr_name, _get_module_attributes(module))
+
+        def _track_module_attributes_forward_hook(module, input, output):
+            tracing_state = _C._get_tracing_state()
+            if not tracing_state:
+                return
+            graph = tracing_state.graph()
+            if hasattr(module, attr_name):
+                delattr(module, attr_name)
+            # FIX: always pass an empty dict to avoid a type mismatch inside
+            # _jit_pass_onnx_track_scope_attributes (observed in transformers v4.55 and above)
+            _C._jit_pass_onnx_track_scope_attributes(graph, {})
+
+        for m in model.modules():
+            m.register_forward_hook(_track_module_attributes_forward_hook)
+            m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)
+
+    def _unqualified_variable_name(qualified_name: str) -> str:
+        """
+        Parse qualified variable name and return the unqualified version.
+        Pure numeric atoms are considered inadequate, so this function will look past them,
+        and start from the first non-numeric atom.
+        """
+        name_atoms = qualified_name.split(".")
+        for i, atom in reversed(list(enumerate(name_atoms))):
+            if not atom.isnumeric():
+                return ".".join(name_atoms[i:])
+        return qualified_name
+
+    trace_module_map = {
+        _m: torch._C._jit_onnx_create_full_scope_name(torch.typename(type(_m)), _unqualified_variable_name(_n))
+        for _n, _m in model.named_modules()
+    }
+    torch.jit._trace._trace_module_map = trace_module_map
+
+    if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
+        module_typenames = {torch.typename(type(module)) for module in trace_module_map}
+    elif isinstance(export_modules_as_functions, set) and export_modules_as_functions:
+
+        def _find_typename(v):
+            if isinstance(v, type):
+                return torch.typename(v)
+            else:
+                raise RuntimeError(
+                    "Only type of the `nn.Module` should be "
+                    "passed in the set for argument `export_modules_as_functions`. "
+                    f"Got `{type(v).__name__}`."
+                )
+
+        module_typenames = {_find_typename(v) for v in export_modules_as_functions}
+    else:
+        module_typenames = set()
+
+    if module_typenames:
+        __register_attribute_hook()
+
+    return module_typenames
+
+
+def _get_module_attributes(module):
+    """Helper function to get module attributes safely."""
+    import typing
+
+    annotations = typing.get_type_hints(type(module))
+    base_m_annotations = typing.get_type_hints(torch.nn.Module)
+    for k in base_m_annotations:
+        annotations.pop(k, None)
+
+    attrs = {}
+    for k in annotations:
+        try:
+            attrs[k] = getattr(module, k)
+        except AttributeError:
+            _C._jit_onnx_log(f"Skipping module attribute '{k}'")
+            continue
+    return attrs
+
+
+def apply_torch_patches():
+    """Apply all necessary torch patches for ONNX export."""
+    # Monkey patch the trace-module-map setup used by torch.onnx.export
+    onnx_utils._setup_trace_module_map = _setup_trace_module_map_patched
+
+    if hasattr(onnx_utils, "_get_module_attributes"):
+        onnx_utils._get_module_attributes = _get_module_attributes
+
+    logger.info("Applied torch ONNX export patches for export_modules_as_functions compatibility")
+
+
+def is_patched():
+    """Check if patches have been applied."""
+    return onnx_utils._setup_trace_module_map is _setup_trace_module_map_patched
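
Usage sketch (not part of the patch): how the new transform registry composes outside `_export`. The ONNX path below is a hypothetical placeholder; the registration call mirrors the one added in `modeling_qeff.py`.

    import onnx

    from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform
    from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc

    # Same registration that modeling_qeff.py performs before torch.onnx.export
    CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)

    model = onnx.load("exported_model.onnx")  # hypothetical path

    # Registers the symbolic with torch.onnx and appends the op's function proto,
    # but only when the op is actually referenced in the graph
    model, ops_added = CustomOpTransform.apply(model, opset_version=17)

    # Rewrites past_<kv>.<i>_InternalRetainedState function outputs to the
    # past_<kv>.<i>_RetainedState names the runtime expects
    model, renamed = RenameFunctionOutputsTransform.apply(model)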
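Another sketch: the sentinel behavior of `_get_invalid_idx_value()`. The toggle is read from the environment on each call (via `os.environ.get`), so it can be flipped per export; the assertion holds because the function returns 0 whenever ONNX tracing is not active.

    import os

    from QEfficient.transformers.cache_utils import _get_invalid_idx_value

    # Outside torch.onnx.export, the sentinel is always 0
    assert _get_invalid_idx_value() == 0

    # During export, the value depends on the new toggle:
    #   QEFF_USE_ONNX_FUNCTIONS=true -> 0 (avoids INT32_MAX constants that break
    #                                     function inlining at runtime)
    #   unset or any other value     -> torch.iinfo(torch.int32).max (legacy export)
    os.environ["QEFF_USE_ONNX_FUNCTIONS"] = "true"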
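Finally, a sketch of the patch lifecycle: `QEfficient/__init__.py` calls `apply_torch_patches()` at import time, so downstream code should normally only need `is_patched()` as a sanity check.

    import QEfficient  # apply_torch_patches() runs during package import

    from QEfficient import apply_torch_patches, is_patched

    assert is_patched()
    # Re-applying is harmless: it re-assigns the same patched functions
    apply_torch_patches()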