WIP: Feat: Add ONNX Sub Functions Export Feature #613
base: main
Conversation
- Auto-detect decoder layers for export_modules_as_functions based on model type (see the sketch below)
- Add CustomOpTransform to dynamically register and include custom ops (CustomRMSNorm, CtxGather, CtxScatter)
- Fix invalid INT32_MAX indices in ONNX runtime by replacing them with 0
- Support ONNX functions export via the QEFF_USE_ONNX_FUNCTIONS env var
- Handle rope_scaling None values gracefully for Gemma3

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
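A minimal sketch of how the first change (auto-detecting decoder layers to pass to export_modules_as_functions) could work; the helper name and the class-name-suffix heuristic are assumptions for illustration, not this PR's exact implementation. For a Llama-style model this would pick up LlamaDecoderLayer, for example.

```python
import torch.nn as nn


def collect_decoder_layer_classes(model: nn.Module) -> set:
    # Hypothetical helper: gather the decoder-layer classes of a model so they
    # can be handed to torch.onnx.export(..., export_modules_as_functions=...).
    return {
        type(module)
        for module in model.modules()
        if type(module).__name__.endswith("DecoderLayer")
    }
```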
| """ | ||
| transformed = False | ||
| onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False) | ||
| temp_onnx_path = kwargs.get("temp_onnx_path", None) |
Can we make temp_onnx_path a mandatory argument? Also, onnx_base_dir is unused here.
:param temp_onnx_path: Path to save the slimmed ONNX model.
"""
transformed = False
onnx_slim_transform = True  # kwargs.get("enable_onnx_slim_transform", False)
If OnnxSlimTransform is called, do we really need another onnx_slim_transform = True flag that is then checked on line 130? The expectation should be that the transform is simply applied, right?
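A sketch of what both review suggestions could look like together: temp_onnx_path becomes a required argument, the unused onnx_base_dir and the hard-coded onnx_slim_transform flag are dropped, and the slimming is applied unconditionally. The class shape, the return convention, and the onnxslim.slim call are assumptions about how the transform is wired, not the PR's exact code.

```python
import onnx
import onnxslim


class OnnxSlimTransform:
    @classmethod
    def apply(cls, model: onnx.ModelProto, temp_onnx_path: str, **kwargs):
        # temp_onnx_path is now mandatory; no enable flag, the transform is
        # always applied when it is listed in the transform pipeline.
        slimmed = onnxslim.slim(model)  # assumed onnxslim entry point
        onnx.save(slimmed, temp_onnx_path)
        return slimmed, True  # (model, transformed)
```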
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

- if hasattr(config, "rope_scaling") and "factor" in config.rope_scaling:
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:
Is this change part of the ONNX Sub Functions feature?
example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
- output_names.append(f"past_{kv}.{i}_RetainedState")
+ output_names.append(f"past_{kv}.{i}_InternalRetainedState")
Why are we renaming it? If we rename _RetainedState to _InternalRetainedState, wouldn't the change also need to be made in text_generation_inference and the other places where we skip these buffers? Even if the sub-function feature is not enabled, this would impact regular execution.
ONNX_EXPORT_EXAMPLE_FBS = 4
ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
- ONNX_EXPORT_OPSET = 13
+ ONNX_EXPORT_OPSET = 17
Some tests on opset 17 are still ongoing. @quic-hemagnih, are we good to merge the opset 17 change?
# Apply patches
# TODO: Find a better way to do this, this is temp. fix.
apply_torch_patches()
If the sub-function feature is not enabled, do we still need to do the monkey patching?
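One way to address this would be to gate the patching on the same environment variable so the default export path stays untouched; a minimal, self-contained sketch, where apply_torch_patches stands in for the helper shown in the diff above and the truthy parsing of the flag is an assumption:

```python
import os


def onnx_functions_enabled() -> bool:
    # Opt-in flag introduced by this PR; exact truthy parsing is an assumption.
    return os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() in ("1", "true")


def apply_torch_patches() -> None:
    # Stand-in for the patch helper shown in the diff above.
    pass


# Monkey-patch torch only when sub-function export is requested, so regular
# (non-subfunction) export is unaffected.
if onnx_functions_enabled():
    apply_torch_patches()
```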
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
export_modules_as_functions=decoder_layer_classes,
do_constant_folding=True,
Do we need do_constant_folding=True and export_modules_as_functions by default, given that the feature is enabled via an environment variable?
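A sketch of one way to gate both options on the flag rather than enabling them unconditionally; the surrounding argument names follow the diff above, and the flag parsing is an assumption:

```python
import os

import torch


def export_onnx(model, example_inputs, onnx_path, input_names, output_names,
                dynamic_axes, decoder_layer_classes):
    # Opt-in flag from this PR.
    use_onnx_functions = os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "").lower() in ("1", "true")

    torch.onnx.export(
        model,
        (example_inputs,),
        str(onnx_path),
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=17,
        # Per the question above: enable sub-function export and constant
        # folding only when the opt-in flag is set.
        export_modules_as_functions=decoder_layer_classes if use_onnx_functions else False,
        do_constant_folding=use_onnx_functions,
    )
```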
- _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+ _onnx_transforms = [
+     FP16ClipTransform,
+     CustomOpTransform,
Do we need to apply the CustomOpTransform again after export?
This PR introduces support for exporting ONNX modules as functions, enabling more efficient model compilation and execution on hardware.
Changes
- QEFF_USE_ONNX_FUNCTIONS environment variable to control ONNX function export behavior

Enable ONNX Functions Export
Set the environment variable before running inference:
export QEFF_USE_ONNX_FUNCTIONS=true

Export and Execute with ONNX Functions
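Below is a hedged usage sketch following QEfficient's usual high-level flow with QEFFAutoModelForCausalLM; the model card and compile options are placeholders, and the exact shape of the exported graph depends on this PR.

```python
import os

from QEfficient import QEFFAutoModelForCausalLM

# Opt in to ONNX sub-function export before exporting.
os.environ["QEFF_USE_ONNX_FUNCTIONS"] = "true"

# Placeholder model card; any supported causal LM follows the same flow.
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

onnx_path = model.export()   # decoder layers are exported as ONNX functions
model.compile(num_cores=16)  # compile for the target device as usual
```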
Backward Compatibility
This feature is opt-in and requires explicitly setting the environment variable. Existing workflows remain unaffected when the flag is disabled.