9 changes: 9 additions & 0 deletions QEfficient/__init__.py
@@ -11,6 +11,7 @@
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.patches import apply_torch_patches, is_patched

# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
@@ -22,6 +23,10 @@
# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning

# Apply patches
# TODO: Find a better way to do this; this is a temporary fix.
apply_torch_patches()


def check_qaic_sdk():
"""Check if QAIC SDK is installed"""
@@ -70,6 +75,10 @@ def check_qaic_sdk():
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"apply_torch_patches",
"is_patched",
"apply_torch_patches",
"is_patched",
]

else:
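With the patches now applied at import time, callers mostly just need to consult the patch state. A minimal usage sketch, assuming is_patched() returns a bool and apply_torch_patches() takes no arguments (as in the call above):

from QEfficient.utils.patches import apply_torch_patches, is_patched

if not is_patched():
    # Re-apply only if the import-time call did not run for some reason.
    apply_torch_patches()
assert is_patched()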
17 changes: 14 additions & 3 deletions QEfficient/base/modeling_qeff.py
@@ -18,10 +18,13 @@
import onnx
import torch

from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
@@ -243,7 +246,13 @@ def _export(
input_names.append(param)

try:
# Register the custom ops needed for export
CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter)
CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)
decoder_layer_classes = get_decoder_layer_classes_for_export(self.model)
export_kwargs = {} if export_kwargs is None else export_kwargs

torch.onnx.export(
self.model,
(example_inputs,),
@@ -252,15 +261,18 @@
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
export_modules_as_functions=decoder_layer_classes,
do_constant_folding=True,
**export_kwargs,
)
logger.info("PyTorch export successful")

_ = self._offload_model_weights(offload_pt_weights)

model = onnx.load(tmp_onnx_path, load_external_data=False)

transform_kwargs = {
"onnx_base_dir": str(tmp_onnx_dir),
"temp_onnx_path": tmp_onnx_path,
"model_name": self.model_name,
}
if onnx_transform_kwargs is not None:
@@ -273,7 +285,6 @@
onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names()))
)
logger.info("ONNX transforms applied")

onnx.save(model, onnx_path)
logger.info("Transformed ONNX saved")

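The decoder-layer classes passed to export_modules_as_functions keep each layer as a single ONNX function call instead of inlining it. A self-contained sketch of that flag with a toy module (DecoderBlock and TinyModel are illustrative stand-ins, not the repository's decoder classes):

import torch
import torch.nn as nn


class DecoderBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(DecoderBlock() for _ in range(2))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


# Each DecoderBlock instance becomes one ONNX function call in the exported graph.
torch.onnx.export(
    TinyModel(),
    (torch.randn(1, 8),),
    "tiny_model.onnx",
    opset_version=17,
    export_modules_as_functions={DecoderBlock},
)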
167 changes: 166 additions & 1 deletion QEfficient/base/onnx_transforms.py
@@ -5,9 +5,12 @@
#
# ----------------------------------------------------------------------------

from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import onnx
import onnxslim
import torch
from onnx import ModelProto, external_data_helper, numpy_helper


@@ -99,3 +102,165 @@ def apply(
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed


class OnnxSlimTransform(OnnxTransform):
"""
Applies onnx-slim transformations on the given ONNX graph.
"""

@classmethod
def apply(
cls,
model: ModelProto,
*,
onnx_base_dir: Optional[str] = None,
**kwargs,
) -> Tuple[ModelProto, bool]:
"""
:param enable_onnx_slim_transform: If True, applies onnx-slim transformations.
:param temp_onnx_path: Path to save the slimmed ONNX model.
"""
transformed = False
onnx_slim_transform = kwargs.get("enable_onnx_slim_transform", True)
temp_onnx_path = kwargs.get("temp_onnx_path", None)
if not temp_onnx_path:
err_str = "temp_onnx_path is required for onnx-slim transform."
raise RuntimeError(err_str)
if onnx_slim_transform:
transformed = True
slimmed_model = onnxslim.slim(model)
onnx.save(slimmed_model, temp_onnx_path)
return slimmed_model, transformed
return model, transformed

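A hedged usage sketch of the new transform; "model.onnx" and "model_slim.onnx" are placeholder paths, and the call assumes the graph fits in memory without external data:

import onnx

from QEfficient.base.onnx_transforms import OnnxSlimTransform

model = onnx.load("model.onnx", load_external_data=False)
model, applied = OnnxSlimTransform.apply(model, temp_onnx_path="model_slim.onnx")
print("onnx-slim transform applied:", applied)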

class CustomOpTransform(OnnxTransform):
"""
Transform to register custom operations and add their function protos to the ONNX model.
"""

# Registry of custom operations
_custom_ops: Dict[str, Tuple[Any, Any]] = {} # op_name -> (func_class, onnxscript_func)

@classmethod
def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
"""Register a custom operation."""
cls._custom_ops[op_name] = (func_class, onnxscript_func)

@classmethod
def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
"""
Apply custom op registration and add function protos to the model.

:param model: The ONNX model to transform
:param opset_version: ONNX opset version for symbolic registration
:returns: Transformed model and success flag
"""
transformed = False

# Register all custom op symbolic functions with torch.onnx
for op_name, (func_class, _) in cls._custom_ops.items():
if hasattr(func_class, "symbolic"):
torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)

# Add function protos for custom ops that are used in the model
used_protos = cls._get_function_protos_for_model(model)

for proto in used_protos:
# Check if proto already exists to avoid duplicates
proto_name = proto.name
if not any(func.name == proto_name for func in model.functions):
model.functions.append(proto)
transformed = True

return model, transformed

@classmethod
def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
"""Get function protos for custom ops that are actually used in the model."""
used_protos = []

# Get all node op_types in the model
used_op_types = set()
for node in model.graph.node:
used_op_types.add(node.op_type)

# Also check function calls
for func in model.functions:
for node in func.node:
used_op_types.add(node.op_type)

# Check which custom ops are actually used
for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
# Check if the custom op is referenced in the model
if cls._is_custom_op_used(model, op_name, used_op_types):
proto = onnxscript_func.to_function_proto()
used_protos.append(proto)

return used_protos

@classmethod
def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
"""Check if a custom op is used in the model."""
# Check if the op_name appears in node op_types
if op_name in used_op_types:
return True

# Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
if custom_op_pattern in used_op_types:
return True

# Heuristic checks based on op type
if "RMSNorm" in op_name:
# Check if any RMSNorm-related ops are present
return any("RMSNorm" in op_type for op_type in used_op_types)

if "Ctx" in op_name:
# Check if Gather/Scatter operations are present (indicating KV cache usage)
return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)

return False

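The registry is meant to be driven the same way _export does above; a sketch using the RMSNorm op from this PR ("exported_model.onnx" is a placeholder path):

import onnx

from QEfficient.base.onnx_transforms import CustomOpTransform
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc

CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)

model = onnx.load("exported_model.onnx", load_external_data=False)
model, added = CustomOpTransform.apply(model, opset_version=17)
print("custom-op function protos added:", added)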

class RenameFunctionOutputsTransform(OnnxTransform):
"""
Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns.
"""

@classmethod
def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
"""
Rename function outputs in decoder layer nodes.

:param model: The ONNX model to transform
:returns: Transformed model and boolean indicating whether transform was applied
"""
graph = model.graph
op_type_to_func_map = {func.name: func for func in model.functions}
decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
transformed = False
model_graph_outputs = [val.name for val in model.graph.output]
layer_index = 0
for node in graph.node:
if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
func = op_type_to_func_map.get(node.op_type)
if func is None:
continue

for i, out_name in enumerate(func.output):
if "_InternalRetainedState" in out_name:
transformed = True
tmp = node.output[i]
if "key" in out_name:
new_name = f"past_key.{layer_index}_RetainedState"
elif "value" in out_name:
new_name = f"past_value.{layer_index}_RetainedState"
# new_name = func.output[i].replace("Internal", "")
node.output[i] = new_name
# Update graph output name if it exists
if tmp in model_graph_outputs:
model.graph.output[model_graph_outputs.index(tmp)].name = new_name
layer_index = layer_index + 1
return model, transformed
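A hedged sketch of the renaming; the mapped names below are illustrative only, and "decoder_model.onnx" is a placeholder path:

# Per decoder-layer function call, outputs containing "_InternalRetainedState"
# are rewritten, e.g. (illustrative strings):
#   "...key_InternalRetainedState"   -> "past_key.0_RetainedState"
#   "...value_InternalRetainedState" -> "past_value.0_RetainedState"
import onnx

from QEfficient.base.onnx_transforms import RenameFunctionOutputsTransform

model = onnx.load("decoder_model.onnx", load_external_data=False)
model, renamed = RenameFunctionOutputsTransform.apply(model)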
78 changes: 63 additions & 15 deletions QEfficient/base/pytorch_transforms.py
@@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:

class SplitGateUpWeightsTransform(PytorchTransform):
"""
Split fused Gate+Up weights and copy into the model.
Handles both standard MoE models and GptOss models.

For every transformer layer inside `model`:
• expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
• copies halves into
<PREFIX>.experts.gate_proj <-- Gate [E,H,I]
<PREFIX>.experts.up_proj <-- Up [E,H,I]

Handles both interleaved weights (GptOss) and concatenated weights (standard MoE).
Also handles bias terms when present.
"""

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
model_class = model.__class__.__name__

if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
return model, transformed

model_tmp = model.language_model if hasattr(model, "language_model") else model

num_layers = len(model_tmp.model.layers)
delete_fused_key = True
sd = model_tmp.state_dict()

for layer_idx in range(num_layers):
# Determine if this is a GptOss model or standard MoE model
is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp")

# ---- build the textual prefix once per layer ----------
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
if is_gpt_oss:
prefix = f"model.layers.{layer_idx}.mlp.experts."
experts = model_tmp.model.layers[layer_idx].mlp.experts
else:
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
experts = model_tmp.model.layers[layer_idx].feed_forward.experts

fused_key = prefix + "gate_up_proj"
gate_key = prefix + "gate_proj"
up_key = prefix + "up_proj"

# Check if we have bias terms (GptOss case)
has_bias = fused_key + "_bias" in sd
if has_bias:
fused_bias_key = fused_key + "_bias"
gate_bias_key = gate_key + "_bias"
up_bias_key = up_key + "_bias"

# ---- split weights based on model type ----------------------
fused = sd[fused_key] # [E, H, 2I]
E, H, two_I = fused.shape

if is_gpt_oss:
# For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...]
gate = fused[..., ::2] # [E, H, I] - even indices
up = fused[..., 1::2] # [E, H, I] - odd indices
else:
# For standard MoE, gate/up are concatenated: [gate, up]
ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1) # views – no copy

# Copy weights to model
experts.gate_proj.data.copy_(gate)
experts.up_proj.data.copy_(up)

# Handle bias if present
if has_bias:
fused_bias = sd[fused_bias_key] # [E, 2I]

if is_gpt_oss:
gate_bias = fused_bias[..., ::2] # [E, I] - even indices
up_bias = fused_bias[..., 1::2] # [E, I] - odd indices
else:
ffn_dim = fused_bias.shape[-1] // 2
gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1)

experts.gate_proj_bias.data.copy_(gate_bias)
experts.up_proj_bias.data.copy_(up_bias)

# ---- update the state-dict so load_state_dict sees the right keys
sd[gate_key] = gate
sd[up_key] = up

if has_bias:
sd[gate_bias_key] = gate_bias
sd[up_bias_key] = up_bias

# Delete fused keys
if delete_fused_key:
del sd[fused_key]
if has_bias:
del sd[fused_bias_key]

logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
transformed = True

if hasattr(model, "language_model"):
model.language_model = model_tmp
else:
model = model_tmp

return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
# Keep the existing list of supported models
VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"}