@@ -28,6 +28,7 @@ def onnx_export(
2828 dynamic_size : bool = False ,
2929 aten_fallback : bool = False ,
3030 keep_initializers : Optional [bool ] = None ,
31+ use_dynamo : bool = False ,
3132 input_names : List [str ] = None ,
3233 output_names : List [str ] = None ,
3334):
@@ -53,7 +54,8 @@ def onnx_export(
5354 # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
5455 # issues in the tracing of the dynamic padding or errors attempting to export the model after jit
5556 # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
56- original_out = model (example_input )
57+ with torch .no_grad ():
58+ original_out = model (example_input )
5759
5860 input_names = input_names or ["input0" ]
5961 output_names = output_names or ["output0" ]
@@ -68,28 +70,40 @@ def onnx_export(
6870 else :
6971 export_type = torch .onnx .OperatorExportTypes .ONNX
7072
71- torch_out = torch .onnx ._export (
72- model ,
73- example_input ,
74- output_file ,
75- training = training_mode ,
76- export_params = True ,
77- verbose = verbose ,
78- input_names = input_names ,
79- output_names = output_names ,
80- keep_initializers_as_inputs = keep_initializers ,
81- dynamic_axes = dynamic_axes ,
82- opset_version = opset ,
83- operator_export_type = export_type
84- )
73+ if use_dynamo :
74+ export_options = torch .onnx .ExportOptions (dynamic_shapes = dynamic_size )
75+ export_output = torch .onnx .dynamo_export (
76+ model ,
77+ example_input ,
78+ export_options = export_options ,
79+ )
80+ export_output .save (output_file )
81+ torch_out = None
82+ else :
83+ torch_out = torch .onnx ._export (
84+ model ,
85+ example_input ,
86+ output_file ,
87+ training = training_mode ,
88+ export_params = True ,
89+ verbose = verbose ,
90+ input_names = input_names ,
91+ output_names = output_names ,
92+ keep_initializers_as_inputs = keep_initializers ,
93+ dynamic_axes = dynamic_axes ,
94+ opset_version = opset ,
95+ operator_export_type = export_type
96+ )
8597
8698 if check :
8799 onnx_model = onnx .load (output_file )
88100 onnx .checker .check_model (onnx_model , full_check = True ) # assuming throw on error
89101 if check_forward and not training :
90102 import numpy as np
91103 onnx_out = onnx_forward (output_file , example_input )
92- np .testing .assert_almost_equal (torch_out .data .numpy (), onnx_out , decimal = 3 )
93- np .testing .assert_almost_equal (original_out .data .numpy (), torch_out .data .numpy (), decimal = 5 )
94-
104+ if torch_out is not None :
105+ np .testing .assert_almost_equal (torch_out .numpy (), onnx_out , decimal = 3 )
106+ np .testing .assert_almost_equal (original_out .numpy (), torch_out .numpy (), decimal = 5 )
107+ else :
108+ np .testing .assert_almost_equal (original_out .numpy (), onnx_out , decimal = 3 )
95109
0 commit comments