10 changes: 5 additions & 5 deletions examples/onnx_ptq/README.md
@@ -13,7 +13,7 @@ Model Optimizer enables highly performant quantization formats including NVFP4,
| Pre-Requisites | Required & optional packages to use this technique | [Link](#pre-requisites) | |
| Getting Started | Learn how to optimize your models using PTQ to reduce precision and improve inference efficiency | [Link](#getting-started) | [docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/_onnx_quantization.html) |
| Support Matrix | View the ONNX export supported LLM models | [Link](#onnx-export-supported-llm-models) | |
| PyTorch to ONNX | Example scripts demonstrating how to quantize with PyTorch and then convert to ONNX | [Link](#torch-quantization-to-onnx-example-for-mxfp8-int4-or-nvfp4-precision) | |
| PyTorch to ONNX | Example scripts demonstrating how to quantize with PyTorch and then convert to ONNX | [Link](#torch-quantization-to-onnx-export-example) | |
| Advanced Features | Examples demonstrating use of advanced ONNX quantization features | [Link](#advanced-features) | |
| Pre-Quantized Checkpoints | Ready to deploy Hugging Face pre-quantized checkpoints | [Link](#pre-quantized-checkpoints) | |
| Resources | Extra links to relevant resources | [Link](#resources) | |
@@ -80,7 +80,7 @@ python image_prep.py \

The model can be quantized as an FP8, INT8 or INT4 model using either the CLI or Python API. For FP8 and INT8 quantization, you have a choice between `max` and `entropy` calibration algorithms. For INT4 quantization, [awq_clip](https://arxiv.org/abs/2306.00978) or [rtn_dq](https://ar5iv.labs.arxiv.org/html/2301.12017) algorithms can be chosen.

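For reference, here is a minimal Python-API sketch of the same quantization flow. The exact parameter names (`quantize_mode`, `calibration_data`, `calibration_method`, `output_path`) and file paths are assumptions; consult the ONNX quantization docs linked above for the authoritative interface.

```python
# Hedged sketch of the Python API path (parameter names and paths are assumptions --
# verify against the ONNX quantization docs linked above).
import numpy as np
from modelopt.onnx.quantization import quantize

quantize(
    onnx_path="vit_base_patch16_224.onnx",    # placeholder input model
    quantize_mode="int8",                     # "fp8" or "int4_awq" are other options
    calibration_data=np.load("calib.npy"),    # placeholder path to data from image_prep.py
    calibration_method="entropy",             # or "max" for INT8/FP8
    output_path="vit_base_patch16_224.quant.onnx",
)
```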
> *For NVFP4 and MXFP8 ONNX, see the [PyTorch to ONNX section](#torch-quantization-to-onnx-example-for-mxfp8-int4-or-nvfp4-precision).*
> *For NVFP4 and MXFP8 ONNX, see the [PyTorch to ONNX section](#torch-quantization-to-onnx-export-example).*

> *Minimum opset requirements: int8 (13+), fp8 (21+), int4 (21+). ModelOpt will automatically upgrade lower opset versions to meet these requirements.*

@@ -129,9 +129,9 @@ The top5 accuracy of the model is <accuracy score between 0-100%>
Inference latency of the model is <X> ms
```

## Torch quantization to ONNX example for MXFP8, INT4 or NVFP4 precision
## Torch quantization to ONNX export example

This example demonstrates how to quantize a [timm](https://github.com/huggingface/pytorch-image-models) vision model using MXFP8, INT4 or NVFP4 precision formats, and then export it to ONNX. The script leverages the ModelOpt toolkit for both quantization and ONNX export.
This example demonstrates how to quantize a [timm](https://github.com/huggingface/pytorch-image-models) vision model in various precision formats and then export it to ONNX (a rough sketch of the underlying torch flow follows the command below). The script leverages the ModelOpt toolkit for both quantization and ONNX export.

> *Opset 20 is used to export the torch models to ONNX.*

@@ -148,7 +148,7 @@ This example demonstrates how to quantize a [timm](https://github.com/huggingfac
```bash
python torch_quant_to_onnx.py \
--timm_model_name=vit_base_patch16_224 \
--quantize_mode=<mxfp8|nvfp4|int4_awq> \
--quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
--onnx_save_path=<path to save the exported ONNX model>
```
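For orientation, here is a rough, hedged sketch of the torch-side flow the script automates. The config names and the plain `torch.onnx.export` call are assumptions and may differ from the ModelOpt export path `torch_quant_to_onnx.py` actually uses; treat it as conceptual only.

```python
# Conceptual sketch only -- config names may vary by ModelOpt version, and the real
# script exports through ModelOpt's own ONNX utilities rather than torch.onnx.export.
import timm
import torch
import modelopt.torch.quantization as mtq

model = timm.create_model("vit_base_patch16_224", pretrained=True).eval()
dummy = torch.randn(1, 3, 224, 224)

def forward_loop(m):
    # Calibration pass; a real run should iterate over representative images.
    m(dummy)

model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)  # or an MXFP8/INT4-AWQ config
torch.onnx.export(model, dummy, "vit_nvfp4.onnx", opset_version=20)
```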

3 changes: 2 additions & 1 deletion examples/onnx_ptq/evaluation.py
@@ -152,8 +152,9 @@ def evaluate_accuracy(

# Calculate accuracy
outputs = outputs[0] if isinstance(outputs, list) else outputs.data

labels_size = labels.size(0)
outputs = outputs[:labels_size]

total += labels_size

labels = labels.to(outputs.device)
2 changes: 1 addition & 1 deletion examples/onnx_ptq/torch_quant_to_onnx.py
Expand Up @@ -323,7 +323,7 @@ def main():
)
print(f"Quantized Model - Top-1 Accuracy: {top1:.2f}%, Top-5 Accuracy: {top5:.2f}%")

if args.quantize_mode in ["fp8", "int8", "auto"]:
if args.quantize_mode in ["auto"]:
print(
f"The selected quantization mode {args.quantize_mode} is not supported for ONNX export yet."
)
15 changes: 15 additions & 0 deletions modelopt/onnx/quantization/qdq_utils.py
@@ -1037,6 +1037,21 @@ def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    return onnx_model


def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Replace zero scale values with smallest nonzero fp16 value in the ONNX model."""
Collaborator:
Could you document in what conditions we need to call this method here?

Contributor Author:
Sorry, I had already set the MR to auto merge. Will add this in a follow-up MR.

    graph = onnx_model.graph
    fp16_smallest_nonzero = np.float16(6e-08)
    scale_nodes = [node.input[1] for node in graph.node if node.op_type == "QuantizeLinear"]
    for node in graph.node:
        if node.op_type == "Constant" and node.output[0] in scale_nodes:
            for attr in node.attribute:
                if attr.name == "value":
                    tensor = numpy_helper.to_array(attr.t)
                    new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor)
                    attr.t.CopyFrom(numpy_helper.from_array(new_tensor, attr.t.name))
    return onnx_model

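For context, a standalone usage sketch of this helper (file paths are placeholders); inside ModelOpt it is invoked during ONNX export in `torch_onnx.py`, as shown further down in this diff.

```python
# Standalone usage sketch; paths are placeholders.
import onnx

from modelopt.onnx.quantization.qdq_utils import replace_zero_scale_with_smallest_nonzero

model = onnx.load("quantized_model.onnx")
model = replace_zero_scale_with_smallest_nonzero(model)  # clamp zero QuantizeLinear scales
onnx.save(model, "quantized_model.positive_scales.onnx")
```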

def _cast_initializer_to_dtype(
    node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
):
30 changes: 30 additions & 0 deletions modelopt/torch/_deploy/utils/torch_onnx.py
@@ -37,6 +37,7 @@
    qdq_to_dq,
    quantize_weights_to_int4,
    quantize_weights_to_mxfp8,
    replace_zero_scale_with_smallest_nonzero,
)
from modelopt.onnx.utils import (
    get_input_names,
@@ -336,6 +337,32 @@ def is_mxfp8_quantized(model: nn.Module) -> bool:
    return False


def is_int8_quantized(model: nn.Module) -> bool:
    """Check if the model is quantized in INT8 mode."""
    for _, module in model.named_modules():
        if (
            hasattr(module, "weight_quantizer")
            and hasattr(module, "input_quantizer")
            and module.weight_quantizer._num_bits == 8
            and module.input_quantizer._num_bits == 8
        ):
            return True
    return False


def is_fp8_quantized(model: nn.Module) -> bool:
    """Check if the model is quantized in FP8 mode."""
    for _, module in model.named_modules():
        if (
            hasattr(module, "weight_quantizer")
            and hasattr(module, "input_quantizer")
            and module.weight_quantizer._num_bits == (4, 3)
            and module.input_quantizer._num_bits == (4, 3)
        ):
            return True
    return False
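A note on the two checks above: in ModelOpt quantizer modules, `_num_bits == 8` denotes INT8, while the tuple `(4, 3)` encodes the FP8 E4M3 element type (4 exponent bits, 3 mantissa bits). A hypothetical helper illustrates the convention:

```python
# Hypothetical illustration of the _num_bits convention used by the checks above.
def describe_weight_precision(module) -> str:
    bits = getattr(module.weight_quantizer, "_num_bits", None)
    if bits == 8:
        return "INT8"
    if bits == (4, 3):
        return "FP8 (E4M3)"  # 4 exponent bits, 3 mantissa bits
    return f"other: {bits!r}"
```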


def get_onnx_bytes_and_metadata(
    model: nn.Module,
    dummy_input: Any | tuple,
@@ -510,6 +537,9 @@ def get_onnx_bytes_and_metadata(
onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
)

# TensorRT expects all scales to be positive
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)

# If the onnx model contains external data, store the external tensors in one file and save the onnx model
if has_external_data(onnx_save_path):
tensor_paths = get_external_tensor_paths(onnx_path)