88import mock_keras2onnx
99from mock_keras2onnx .proto import keras , is_keras_older_than
1010from mock_keras2onnx .proto .tfcompat import is_tf2
11+ from packaging .version import Version
12+ from tf2onnx .keras2onnx_api import convert_keras
1113import time
1214import json
1315import urllib
1416
17+
18+ # Mapping opset to ONNXRuntime version.
19+ ORT_OPSET_VERSION = {
20+ "1.6.0" : 13 , "1.7.0" : 13 , "1.8.0" : 14 , "1.9.0" : 15 , "1.10.0" : 15 , "1.11.0" : 16 , "1.12.0" : 17
21+ }
22+
1523working_path = os .path .abspath (os .path .dirname (__file__ ))
1624tmp_path = os .path .join (working_path , 'temp' )
1725test_level_0 = True
@@ -299,3 +307,26 @@ def is_bloburl_access(url):
299307 return response .getcode () == 200
300308 except urllib .error .URLError :
301309 return False
310+
311+
312+ def get_max_opset_supported_by_ort ():
313+ try :
314+ import onnxruntime as ort
315+ ort_ver = Version (ort .__version__ ).base_version
316+
317+ if ort_ver in ORT_OPSET_VERSION .keys ():
318+ return ORT_OPSET_VERSION [ort_ver ]
319+ else :
320+ print ("Given onnxruntime version doesn't exist in ORT_OPSET_VERSION: {}" .format (ort_ver ))
321+ return None
322+ except ImportError :
323+ return None
324+
325+
326+ def convert_keras_for_test (model , name = None , target_opset = None , ** kwargs ):
327+ if target_opset is None :
328+ target_opset = get_max_opset_supported_by_ort ()
329+
330+ print ("Trying to run test with opset version: {}" .format (target_opset ))
331+
332+ return convert_keras (model = model , name = name , target_opset = target_opset , ** kwargs )
0 commit comments