8 changes: 7 additions & 1 deletion QEfficient/__init__.py
@@ -11,17 +11,21 @@
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.patches import apply_torch_patches, is_patched

# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Placeholder for all non-transformer models registered in QEfficient


# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning

# Apply patches
# TODO: Find a better way to do this, this is temp. fix.
apply_torch_patches()

[Reviewer comment] If we are not enabling sub-functions, do we need to do the monkey patching?



def check_qaic_sdk():
"""Check if QAIC SDK is installed"""
@@ -70,6 +74,8 @@ def check_qaic_sdk():
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"apply_torch_patches",
"is_patched",
]

else:
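One way to address the reviewer's question above is to gate the patching behind the same environment variable that cache_utils.py reads later in this diff. This is only a sketch of that idea, not part of the PR; QEFF_USE_ONNX_FUNCTIONS and apply_torch_patches are the only names taken from the changes themselves.

    import os

    from QEfficient.utils.patches import apply_torch_patches

    # Hypothetical gating: only monkey-patch torch when sub-function export is
    # requested, mirroring the env var used in QEfficient/transformers/cache_utils.py.
    if os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() == "true":
        apply_torch_patches()
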
13 changes: 12 additions & 1 deletion QEfficient/base/modeling_qeff.py
@@ -18,10 +18,13 @@
import onnx
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
@@ -243,6 +246,11 @@ def _export(
input_names.append(param)

try:
# Initialize the registry with your custom ops
CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter)
CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)
decoder_layer_classes = get_decoder_layer_classes_for_export(self.model)
export_kwargs = {} if export_kwargs is None else export_kwargs
torch.onnx.export(
self.model,
@@ -252,6 +260,8 @@
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
export_modules_as_functions=decoder_layer_classes,
do_constant_folding=True,

[Reviewer comment] Do we need do_constant_folding=True and export_modules_as_functions by default if we are enabling this via an env variable?

**export_kwargs,
)
logger.info("PyTorch export successful")
@@ -261,6 +271,7 @@
model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
"onnx_base_dir": str(tmp_onnx_dir),
"temp_onnx_path": tmp_onnx_path,
"model_name": self.model_name,
}
if onnx_transform_kwargs is not None:
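For context on the two new export arguments, here is a minimal, self-contained sketch of what export_modules_as_functions does; the toy module and output path are illustrative, while export_modules_as_functions and do_constant_folding are standard torch.onnx.export parameters (sub-function export requires opset 15 or newer).

    import torch

    class ToyDecoderLayer(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList(ToyDecoderLayer() for _ in range(2))

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    # Each ToyDecoderLayer instance is exported as a reusable ONNX function
    # instead of being inlined into the main graph.
    torch.onnx.export(
        ToyModel(),
        (torch.randn(1, 4),),
        "toy_model.onnx",
        opset_version=17,
        do_constant_folding=True,
        export_modules_as_functions={ToyDecoderLayer},
    )
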
166 changes: 165 additions & 1 deletion QEfficient/base/onnx_transforms.py
@@ -5,9 +5,12 @@
#
# ----------------------------------------------------------------------------

from typing import Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import onnx
import onnxslim
import torch
from onnx import ModelProto, external_data_helper, numpy_helper


@@ -99,3 +102,164 @@ def apply(
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed


class OnnxSlimTransform(OnnxTransform):
"""
Applies onnx-slim transformations on the given ONNX graph.
"""

@classmethod
def apply(
cls,
model: ModelProto,
*,
onnx_base_dir: Optional[str] = None,
**kwargs,
) -> Tuple[ModelProto, bool]:
"""
:param enable_onnx_slim_transform: If True, applies onnx-slim transformations.
:param temp_onnx_path: Path to save the slimmed ONNX model.
"""
transformed = False
onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False)

[Reviewer comment] If OnnxSlimTransform is called, do we need another onnx_slim_transform = True flag that is then checked on line 130? The expectation should be to just apply the onnx-slim transform, right?

temp_onnx_path = kwargs.get("temp_onnx_path", None)

[Reviewer comment] Can we make it a mandatory argument? Also, onnx_base_dir is unused here.

if not temp_onnx_path:
err_str = "temp_onnx_path is required for onnx-slim transform."
raise RuntimeError(err_str)
if onnx_slim_transform:
transformed = True
slimmed_model = onnxslim.slim(model)
onnx.save(slimmed_model, temp_onnx_path)
return slimmed_model, transformed
return model, transformed
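
A minimal sketch of invoking this transform directly (file paths are placeholders; onnxslim.slim is the entry point the transform wraps):

    import onnx

    from QEfficient.base.onnx_transforms import OnnxSlimTransform

    model = onnx.load("model.onnx", load_external_data=False)
    slimmed, applied = OnnxSlimTransform.apply(model, temp_onnx_path="model_slim.onnx")
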


class CustomOpTransform(OnnxTransform):
"""
Transform to register custom operations and add their function protos to the ONNX model.
"""

# Registry of custom operations
_custom_ops: Dict[str, Tuple[Any, Any]] = {} # op_name -> (func_class, onnxscript_func)

@classmethod
def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
"""Register a custom operation."""
cls._custom_ops[op_name] = (func_class, onnxscript_func)

@classmethod
def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
"""
Apply custom op registration and add function protos to the model.

:param model: The ONNX model to transform
:param opset_version: ONNX opset version for symbolic registration
:returns: Transformed model and success flag
"""
transformed = False

# Register all custom op symbolic functions with torch.onnx
for op_name, (func_class, _) in cls._custom_ops.items():
if hasattr(func_class, "symbolic"):
torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)

# Add function protos for custom ops that are used in the model
used_protos = cls._get_function_protos_for_model(model)

for proto in used_protos:
# Check if proto already exists to avoid duplicates
proto_name = proto.name
if not any(func.name == proto_name for func in model.functions):
model.functions.append(proto)
transformed = True

return model, transformed

@classmethod
def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
"""Get function protos for custom ops that are actually used in the model."""
used_protos = []

# Get all node op_types in the model
used_op_types = set()
for node in model.graph.node:
used_op_types.add(node.op_type)

# Also check function calls
for func in model.functions:
for node in func.node:
used_op_types.add(node.op_type)

# Check which custom ops are actually used
for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
# Check if the custom op is referenced in the model
if cls._is_custom_op_used(model, op_name, used_op_types):
proto = onnxscript_func.to_function_proto()
used_protos.append(proto)

return used_protos

@classmethod
def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
"""Check if a custom op is used in the model."""
# Check if the op_name appears in node op_types
if op_name in used_op_types:
return True

# Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
if custom_op_pattern in used_op_types:
return True

# Heuristic checks based on op type
if "RMSNorm" in op_name:
# Check if any RMSNorm-related ops are present
return any("RMSNorm" in op_type for op_type in used_op_types)

if "Ctx" in op_name:
# Check if Gather/Scatter operations are present (indicating KV cache usage)
return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)

return False
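
A minimal usage sketch, reusing the registration this PR adds in modeling_qeff.py (the model path is a placeholder):

    import onnx

    from QEfficient.base.onnx_transforms import CustomOpTransform
    from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc

    CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
    model = onnx.load("model.onnx", load_external_data=False)
    # Registers the symbolic with torch.onnx and appends the onnxscript function
    # proto only when a matching op is actually present in the graph.
    model, applied = CustomOpTransform.apply(model, opset_version=17)
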


class RenameFunctionOutputsTransform(OnnxTransform):
"""
Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns.
"""

@classmethod
def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
"""
Rename function outputs in decoder layer nodes.

:param model: The ONNX model to transform
:returns: Transformed model and boolean indicating whether transform was applied
"""
graph = model.graph
op_type_to_func_map = {func.name: func for func in model.functions}
decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
transformed = False
model_graph_outputs = [val.name for val in model.graph.output]
layer_index = 0
for node in graph.node:
if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
func = op_type_to_func_map.get(node.op_type)
if func is None:
continue

for i, out_name in enumerate(func.output):
if "_InternalRetainedState" in out_name:
transformed = True
tmp = node.output[i]
if "key" in out_name:
new_name = f"past_key.{layer_index}_RetainedState"
elif "value" in out_name:
new_name = f"past_value.{layer_index}_RetainedState"
node.output[i] = new_name
# Update graph output name if it exists
if tmp in model_graph_outputs:
model.graph.output[model_graph_outputs.index(tmp)].name = new_name
layer_index = layer_index + 1
return model, transformed
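
An illustration of the intended renaming, assuming the exported names already carry the matching layer index (as they do in modeling_auto.py below):

    exported = ["past_key.0_InternalRetainedState", "past_value.0_InternalRetainedState"]
    renamed = [name.replace("_InternalRetainedState", "_RetainedState") for name in exported]
    assert renamed == ["past_key.0_RetainedState", "past_value.0_RetainedState"]
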
41 changes: 29 additions & 12 deletions QEfficient/transformers/cache_utils.py
@@ -6,6 +6,7 @@
# -----------------------------------------------------------------------------


import os
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple

@@ -24,6 +25,30 @@
)


def _get_invalid_idx_value():
"""
Get the appropriate invalid index value for CtxGather operations.

For ONNX export with functions, we use 0 to avoid INT32_MAX constants
that cause issues when functions are inlined at runtime.

Returns:
int: Invalid index value (0 for ONNX functions, INT32_MAX otherwise)
"""
if torch.onnx.is_in_onnx_export():
# Check if ONNX functions are being used
use_onnx_functions = os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() == "true"
if use_onnx_functions:
# For ONNX functions: use 0 to avoid function inlining issues
return 0
else:
# For regular ONNX export: use INT32_MAX as before
return torch.iinfo(torch.int32).max
else:
# For runtime: use 0
return 0
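
A small, self-contained illustration of how the invalid index is consumed a few lines below (shapes simplified; the value 0 corresponds to the runtime and ONNX-functions cases):

    import torch

    ctx_indices = torch.arange(8)[None, None, :]
    gather_limit = torch.tensor([[[3]]])
    invalid_mask = ctx_indices > gather_limit
    # Out-of-range positions are redirected to index 0 so the gather reads a
    # harmless slot instead of stale cache content.
    ctx_indices = torch.where(invalid_mask, 0, ctx_indices)
    # ctx_indices is now tensor([[[0, 1, 2, 3, 0, 0, 0, 0]]])
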


class QEffDynamicLayer(DynamicLayer):
def read_only(self, cache_kwargs):
"""
@@ -45,10 +70,7 @@ def read_only(self, cache_kwargs):
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = _get_invalid_idx_value()

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

@@ -142,10 +164,7 @@ def update(
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = _get_invalid_idx_value()

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
if batch_index is not None:
@@ -418,10 +437,8 @@ def update(
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = _get_invalid_idx_value()
print(f"value of INVALID IDX VALUE is {invalid_idx_value}")
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -80,7 +80,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

if hasattr(config, "rope_scaling") and "factor" in config.rope_scaling:
if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:

[Reviewer comment] Is this change part of ONNX sub-functions?

factor = config.rope_scaling["factor"]
inv_freq /= factor
self.register_buffer("inv_freq", inv_freq, persistent=False)
20 changes: 16 additions & 4 deletions QEfficient/transformers/models/modeling_auto.py
@@ -27,7 +27,13 @@

import QEfficient
from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
from QEfficient.base.onnx_transforms import (
CustomOpTransform,
FP16ClipTransform,
OnnxSlimTransform,
RenameFunctionOutputsTransform,
SplitTensorsTransform,
)
from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.generation.text_generation_inference import (
@@ -2111,7 +2117,13 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
SplitGateUpWeightsTransform,
KVCacheExternalModuleMapperTransform,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
_onnx_transforms = [
FP16ClipTransform,
CustomOpTransform,

[Reviewer comment] Do we need to apply the CustomOpTransform again after export?

RenameFunctionOutputsTransform,
OnnxSlimTransform,
SplitTensorsTransform,
]

def __init__(
self,
@@ -2359,7 +2371,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
output_names.append(f"past_{kv}.{i}_RetainedState")
output_names.append(f"past_{kv}.{i}_InternalRetainedState")

[Reviewer comment] Why are we renaming this? If we rename _RetainedState to _InternalRetainedState, wouldn't the changes also need to be made in text_generation_inference and the other places where we skip these buffers? Even if we are not enabling sub-functions, this would impact regular execution.


else:
# HACK: create common function for this including above if condition code
Expand All @@ -2377,7 +2389,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
output_names.append(f"past_{kv}.{i}_RetainedState")
output_names.append(f"past_{kv}.{i}_InternalRetainedState")

if self.continuous_batching:
example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
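Putting the pieces together, a sketch of how a user might drive an export with this PR applied. The model name is a placeholder, and note that QEFF_USE_ONNX_FUNCTIONS currently only changes the invalid-index value in cache_utils.py; it does not gate the sub-function export itself, which is one of the reviewers' open questions.

    import os

    os.environ["QEFF_USE_ONNX_FUNCTIONS"] = "true"

    from QEfficient import QEFFAutoModelForCausalLM

    model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
    onnx_path = model.export()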