|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | 8 | import logging |
| 9 | +import tempfile |
9 | 10 | from pathlib import Path |
10 | 11 |
|
11 | 12 | import onnx |
12 | 13 |
|
13 | | -from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed |
| 14 | +from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed, optimize_model |
14 | 15 | from ....tools.remove_initializer_from_input import remove_initializer_from_input |
15 | 16 | from ...fusions import FusionGelu, FusionLayerNormalization |
16 | 17 | from ...onnx_model import ONNXModel |
17 | | -from ...quant_utils import save_and_reload_model_with_shape_infer |
18 | 18 | from .fusion_lpnorm import FusionLpNormalization |
19 | 19 | from .fusion_spacetodepth import FusionSpaceToDepth |
20 | 20 |
|
@@ -93,7 +93,7 @@ def qnn_preprocess_model( |
93 | 93 | """ |
94 | 94 | modified = False |
95 | 95 | model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input) |
96 | | - model = save_and_reload_model_with_shape_infer(model) |
| 96 | + model = save_and_reload_optimize_model(model, shape_infer=True) |
97 | 97 | onnx_model = ONNXModel(model) |
98 | 98 |
|
99 | 99 |     # Optionally, fix the dynamic input shapes. |
@@ -178,6 +178,24 @@ def qnn_preprocess_model( |
178 | 178 |     return modified |
179 | 179 |
|
180 | 180 |
|
| 181 | +def save_and_reload_optimize_model(model: onnx.ModelProto, shape_infer: bool) -> onnx.ModelProto: |
| 182 | +    with tempfile.TemporaryDirectory(prefix="ort.qnn_preproc.") as qnn_preproc_tmp_dir: |
| 183 | +        model_in_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_input.onnx") |
| 184 | +        onnx.save_model(model, model_in_path, save_as_external_data=True) |
| 185 | +        if shape_infer: |
| 186 | +            model_infer_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_infer.onnx") |
| 187 | +            onnx.shape_inference.infer_shapes_path(str(model_in_path), str(model_infer_path)) |
| 188 | +            model_in_path = model_infer_path |
| 189 | +        model_out_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_output.onnx") |
| 190 | +        optimize_model(model_in_path, model_out_path) |
| 191 | +        ret_model = onnx.load_model(model_out_path) |
| 192 | +        ret_metaprops = {"onnx.infer": "onnxruntime.tools.qnn.preprocess"} |
| 193 | +        if ret_model.metadata_props: |
| 194 | +            ret_metaprops.update({prop.key: prop.value for prop in ret_model.metadata_props}) |
| 195 | +        onnx.helper.set_model_props(ret_model, ret_metaprops) |
| 196 | +        return ret_model |
| 197 | + |
| 198 | + |
181 | 199 | class InputOutputNameMap: |
182 | 200 |     def __init__( |
183 | 201 |         self, |
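For reference, a minimal usage sketch of the preprocessing entry point that now routes through `save_and_reload_optimize_model`. This is not part of the commit itself: the model paths are placeholders, and the import path assumes the packaged `onnxruntime.quantization` module layout.

```python
# Hypothetical usage sketch; "model.onnx" and "model.qnn_pp.onnx" are placeholder paths.
from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

# With this change, qnn_preprocess_model() first saves the model to a temporary
# directory, runs shape inference, and applies optimize_model() before the QNN
# fusion passes. It returns True if any preprocessing step modified the model.
modified = qnn_preprocess_model("model.onnx", "model.qnn_pp.onnx")
print("graph modified:", modified)
```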