WIP: Feat: Add ONNX Sub Functions Export Feature #613
base: main
Changes from all commits: dfabf37, 2cb1708, c16a9eb, 02eaaa8, 1fce1d6, 7f1d431
@@ -18,10 +18,13 @@
 import onnx
 import torch

-from QEfficient.base.onnx_transforms import OnnxTransform
+from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
+from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
+from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
 from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
 from QEfficient.utils import (
     constants,
     create_json,
@@ -243,6 +246,11 @@ def _export(
             input_names.append(param)

+        try:
+            # Initialize the registry with your custom ops
+            CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
+            CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter)
+            CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)
+            decoder_layer_classes = get_decoder_layer_classes_for_export(self.model)
             export_kwargs = {} if export_kwargs is None else export_kwargs
             torch.onnx.export(
                 self.model,
@@ -252,6 +260,8 @@ def _export(
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
                 opset_version=constants.ONNX_EXPORT_OPSET,
+                export_modules_as_functions=decoder_layer_classes,
+                do_constant_folding=True,
Contributor: Do we need do_constant_folding=True and export_modules_as_functions to be enabled by default if we are enabling this via an env variable?
                 **export_kwargs,
             )
             logger.info("PyTorch export successful")
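For context on the comment above: export_modules_as_functions tells torch.onnx.export to emit the listed nn.Module classes as ONNX local functions instead of inlining their nodes, which is what turns each decoder layer into a reusable sub-function. A self-contained toy sketch (Block, ToyModel, and the output file name are illustrative, not part of this PR; the option needs opset >= 15):

```python
import torch
from torch import nn


class Block(nn.Module):
    """Toy stand-in for a decoder layer."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(2))

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x


torch.onnx.export(
    ToyModel().eval(),
    (torch.randn(1, 8),),
    "toy_subfunctions.onnx",
    input_names=["x"],
    output_names=["y"],
    opset_version=17,
    do_constant_folding=True,
    export_modules_as_functions={Block},  # each Block call becomes an ONNX local function
)
```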
@@ -261,6 +271,7 @@ def _export(
             model = onnx.load(tmp_onnx_path, load_external_data=False)
             transform_kwargs = {
                 "onnx_base_dir": str(tmp_onnx_dir),
+                "temp_onnx_path": tmp_onnx_path,
                 "model_name": self.model_name,
             }
             if onnx_transform_kwargs is not None:
@@ -5,9 +5,12 @@
 #
 # ----------------------------------------------------------------------------

-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
+import onnx
+import onnxslim
+import torch
 from onnx import ModelProto, external_data_helper, numpy_helper
@@ -99,3 +102,164 @@ def apply(
             current_file_size = tsize
             external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
         return model, transformed
+
+
+class OnnxSlimTransform(OnnxTransform):
+    """
+    Applies onnx-slim transformations on the given ONNX graph.
+    """
+
+    @classmethod
+    def apply(
+        cls,
+        model: ModelProto,
+        *,
+        onnx_base_dir: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[ModelProto, bool]:
+        """
+        :param enable_onnx_slim_transform: If True, applies onnx-slim transformations.
+        :param temp_onnx_path: Path to save the slimmed ONNX model.
+        """
+        transformed = False
+        onnx_slim_transform = True  # kwargs.get("enable_onnx_slim_transform", False)
Contributor: If OnnxSlimTransform is called, do you need another onnx_slim_transform = True flag that is then checked on line 130? The expectation should be to apply the onnx-slim transform unconditionally, right?
+        temp_onnx_path = kwargs.get("temp_onnx_path", None)
Contributor: Can we make it a mandatory argument? Also, onnx_base_dir is unused here.
+        if not temp_onnx_path:
+            err_str = "temp_onnx_path is required for onnx-slim transform."
+            raise RuntimeError(err_str)
+        if onnx_slim_transform:
+            transformed = True
+            slimmed_model = onnxslim.slim(model)
+            onnx.save(slimmed_model, temp_onnx_path)
+            return slimmed_model, transformed
+        return model, transformed
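For reviewers trying the slimming step in isolation: onnxslim.slim takes a loaded ModelProto (or a model path) and returns a simplified ModelProto, so the transform above is essentially that call plus a save. A minimal usage sketch, assuming this PR's OnnxSlimTransform is importable; the file paths are placeholders:

```python
import onnx
import onnxslim

from QEfficient.base.onnx_transforms import OnnxSlimTransform  # added by this PR

# Via the transform (temp_onnx_path is required, as the code above enforces)
model = onnx.load("decoder_model.onnx")  # placeholder path
slimmed, applied = OnnxSlimTransform.apply(model, temp_onnx_path="decoder_model_slim.onnx")
print("onnx-slim applied:", applied)

# Equivalent direct call without the wrapper
onnx.save(onnxslim.slim(model), "decoder_model_slim_direct.onnx")
```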
+
+
+class CustomOpTransform(OnnxTransform):
+    """
+    Transform to register custom operations and add their function protos to the ONNX model.
+    """
+
+    # Registry of custom operations
+    _custom_ops: Dict[str, Tuple[Any, Any]] = {}  # op_name -> (func_class, onnxscript_func)
+
+    @classmethod
+    def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
+        """Register a custom operation."""
+        cls._custom_ops[op_name] = (func_class, onnxscript_func)
+
+    @classmethod
+    def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Apply custom op registration and add function protos to the model.
+
+        :param model: The ONNX model to transform
+        :param opset_version: ONNX opset version for symbolic registration
+        :returns: Transformed model and success flag
+        """
+        transformed = False
+
+        # Register all custom op symbolic functions with torch.onnx
+        for op_name, (func_class, _) in cls._custom_ops.items():
+            if hasattr(func_class, "symbolic"):
+                torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)
+
+        # Add function protos for custom ops that are used in the model
+        used_protos = cls._get_function_protos_for_model(model)
+
+        for proto in used_protos:
+            # Check if proto already exists to avoid duplicates
+            proto_name = proto.name
+            if not any(func.name == proto_name for func in model.functions):
+                model.functions.append(proto)
+                transformed = True
+
+        return model, transformed
+
+    @classmethod
+    def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
+        """Get function protos for custom ops that are actually used in the model."""
+        used_protos = []
+
+        # Get all node op_types in the model
+        used_op_types = set()
+        for node in model.graph.node:
+            used_op_types.add(node.op_type)
+
+        # Also check function calls
+        for func in model.functions:
+            for node in func.node:
+                used_op_types.add(node.op_type)
+
+        # Check which custom ops are actually used
+        for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
+            # Check if the custom op is referenced in the model
+            if cls._is_custom_op_used(model, op_name, used_op_types):
+                proto = onnxscript_func.to_function_proto()
+                used_protos.append(proto)
+
+        return used_protos
+
+    @classmethod
+    def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
+        """Check if a custom op is used in the model."""
+        # Check if the op_name appears in node op_types
+        if op_name in used_op_types:
+            return True
+
+        # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
+        custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
+        if custom_op_pattern in used_op_types:
+            return True
+
+        # Heuristic checks based on op type
+        if "RMSNorm" in op_name:
+            # Check if any RMSNorm-related ops are present
+            return any("RMSNorm" in op_type for op_type in used_op_types)
+
+        if "Ctx" in op_name:
+            # Check if Gather/Scatter operations are present (indicating KV cache usage)
+            return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)
+
+        return False
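Putting the registry and apply together, the flow in this PR appears to be: register each custom op once (as done in _export above), let torch.onnx.export run, then apply the transform so the matching onnxscript function protos are attached to the exported model. A hedged sketch of that flow; the loaded model path is a placeholder and the imports assume this PR's modules:

```python
import onnx

from QEfficient.base.onnx_transforms import CustomOpTransform
from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc

# op_name -> (torch autograd Function providing .symbolic, onnxscript function)
CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter)
CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)

# After torch.onnx.export has produced the model, attach the function protos it needs
model = onnx.load("exported_model.onnx")  # placeholder path
model, transformed = CustomOpTransform.apply(model, opset_version=17)
print("custom op function protos added:", transformed)
```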
+
+
+class RenameFunctionOutputsTransform(OnnxTransform):
+    """
+    Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns.
+    """
+
+    @classmethod
+    def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Rename function outputs in decoder layer nodes.
+
+        :param model: The ONNX model to transform
+        :returns: Transformed model and boolean indicating whether transform was applied
+        """
+        graph = model.graph
+        op_type_to_func_map = {func.name: func for func in model.functions}
+        decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
+        transformed = False
+        model_graph_outputs = [val.name for val in model.graph.output]
+        layer_index = 0
+        for node in graph.node:
+            if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
+                func = op_type_to_func_map.get(node.op_type)
+                if func is None:
+                    continue
+
+                for i, out_name in enumerate(func.output):
+                    if "_InternalRetainedState" in out_name:
+                        transformed = True
+                        tmp = node.output[i]
+                        if "key" in out_name:
+                            new_name = f"past_key.{layer_index}_RetainedState"
+                        elif "value" in out_name:
+                            new_name = f"past_value.{layer_index}_RetainedState"
+                        node.output[i] = new_name
+                        # Update graph output name if it exists
+                        if tmp in model_graph_outputs:
+                            model.graph.output[model_graph_outputs.index(tmp)].name = new_name
+                layer_index = layer_index + 1
+        return model, transformed
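To make the rename concrete: with export_modules_as_functions, the per-layer KV outputs come out of the decoder-layer function calls under names containing _InternalRetainedState, and this transform rewrites them back to the past_key.&lt;i&gt;_RetainedState / past_value.&lt;i&gt;_RetainedState names the rest of the stack expects. A tiny illustration of the naming rule only (not the proto manipulation):

```python
def rename_retained_state(out_name: str, layer_index: int) -> str:
    """Mirrors the renaming rule applied by RenameFunctionOutputsTransform."""
    if "_InternalRetainedState" not in out_name:
        return out_name
    if "key" in out_name:
        return f"past_key.{layer_index}_RetainedState"
    if "value" in out_name:
        return f"past_value.{layer_index}_RetainedState"
    return out_name


print(rename_retained_state("past_key.0_InternalRetainedState", 0))
# past_key.0_RetainedState
print(rename_retained_state("past_value.3_InternalRetainedState", 3))
# past_value.3_RetainedState
```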
@@ -80,7 +80,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

-        if hasattr(config, "rope_scaling") and "factor" in config.rope_scaling:
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:
Contributor: Is this change part of the ONNX Sub Functions feature?
             factor = config.rope_scaling["factor"]
             inv_freq /= factor
         self.register_buffer("inv_freq", inv_freq, persistent=False)
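The extra config.rope_scaling is not None check matters because many model configs define the rope_scaling attribute but leave it set to None, and "factor" in None raises instead of evaluating to False. A minimal reproduction of the failure the guard avoids (Cfg is an illustrative stand-in for a model config):

```python
class Cfg:
    rope_scaling = None  # attribute exists but is unset


config = Cfg()

# Old check: hasattr() is True, so the membership test runs against None and raises
try:
    hasattr(config, "rope_scaling") and "factor" in config.rope_scaling
except TypeError as exc:
    print("old check raises:", exc)  # argument of type 'NoneType' is not iterable

# New check short-circuits safely
ok = (
    hasattr(config, "rope_scaling")
    and config.rope_scaling is not None
    and "factor" in config.rope_scaling
)
print("new check:", ok)  # False
```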
@@ -27,7 +27,13 @@
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.base.onnx_transforms import (
+    CustomOpTransform,
+    FP16ClipTransform,
+    OnnxSlimTransform,
+    RenameFunctionOutputsTransform,
+    SplitTensorsTransform,
+)
 from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.generation.text_generation_inference import (
@@ -2111,7 +2117,13 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
         SplitGateUpWeightsTransform,
         KVCacheExternalModuleMapperTransform,
     ]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [
+        FP16ClipTransform,
+        CustomOpTransform,
Contributor: Do we need to apply the CustomOpTransform again after export?
+        RenameFunctionOutputsTransform,
+        OnnxSlimTransform,
+        SplitTensorsTransform,
+    ]

     def __init__(
         self,
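On ordering: the list above implies the ONNX transforms run in sequence after export, with CustomOpTransform and RenameFunctionOutputsTransform ahead of OnnxSlimTransform and SplitTensorsTransform. A minimal sketch of how such a list is typically driven, assuming only the apply(model, **kwargs) classmethod interface shown in the diffs above; the loop and kwargs are illustrative, not the repo's actual driver code:

```python
import onnx

from QEfficient.base.onnx_transforms import (
    CustomOpTransform,
    FP16ClipTransform,
    OnnxSlimTransform,
    RenameFunctionOutputsTransform,
    SplitTensorsTransform,
)

model = onnx.load("exported_model.onnx")  # placeholder path
kwargs = {
    "model_name": "model",                         # used by SplitTensorsTransform
    "temp_onnx_path": "exported_model_slim.onnx",  # required by OnnxSlimTransform
}

for transform in (
    FP16ClipTransform,
    CustomOpTransform,
    RenameFunctionOutputsTransform,
    OnnxSlimTransform,
    SplitTensorsTransform,
):
    model, applied = transform.apply(model, **kwargs)
    print(f"{transform.__name__}: applied={applied}")
```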
@@ -2359,7 +2371,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
             for kv in ["key", "value"]:
                 example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
                 dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
-                output_names.append(f"past_{kv}.{i}_RetainedState")
+                output_names.append(f"past_{kv}.{i}_InternalRetainedState")
Contributor: Why are we renaming this? If we rename _RetainedState to _InternalRetainedState, wouldn't the changes also need to be made in text_generation_inference and the other places where we skip these buffers? Even when the subfunction feature is not enabled, this would impact regular execution.
         else:
             # HACK: create common function for this including above if condition code
@@ -2377,7 +2389,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
             for kv in ["key", "value"]:
                 example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
                 dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
-                output_names.append(f"past_{kv}.{i}_RetainedState")
+                output_names.append(f"past_{kv}.{i}_InternalRetainedState")

         if self.continuous_batching:
             example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
Review comment: If we are not enabling the subfunction feature, do we need to do the monkey patching?