Skip to content

Commit f44b314

Browse files
authored
Allow opset_version to be set explicitly when exporting (#2615)
I think it would be nice to explicitly set opset_version when exporting, particularly when a custom/particular Opset is being used and the default opset can't be inferred. Example: ```py from onnxscript import script from onnxscript import opset15 as op from onnxscript.values import Opset import numpy as np from onnxscript import STRING from onnxruntime import InferenceSession ai_onnx = Opset("ai.onnx.ml", version=2) @script(ai_onnx, default_opset = op) def label_encoder(X: STRING["D"]): Y = ai_onnx.LabelEncoder(X, keys_strings=["a", "b", "c"], values_int64s=[0, 1, 2], default_int64=42) # Y = Y + 0.0 # to force opset version downgrade return Y print(label_encoder(np.array(["a", "b", "c"]))) session = InferenceSession(label_encoder.to_model_proto(ir_version=10).SerializeToString()) for key, value in {"a": 0, "b": 1, "c": 2}.items(): assert label_encoder(np.array([key]))[0] == value assert session.run(None, {"X": np.array([key])})[0] == value ``` This currently errors with ```sh Traceback (most recent call last): File "/Users/XXX/Development/projects/jet/test_onnxscript_label.py", line 25, in <module> session = InferenceSession(label_encoder.to_model_proto(ir_version=10).SerializeToString()) File "/Users/XXX/Development/projects/jet/.venv/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 472, in __init__ self._create_inference_session(providers, provider_options, disabled_optimizers) File "/Users/XXX/Development/projects/jet/.venv/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 552, in _create_inference_session sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model) onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : /Users/runner/work/1/s/onnxruntime/core/graph/model_load_utils.h:56 void onnxruntime::model_load_utils::ValidateOpsetForDomain(const std::unordered_map<std::string, int> &, const logging::Logger &, bool, const std::string &, int) ONNX Runtime only *guarantees* support for models stamped with official released onnx opset versions. Opset 23 is under development and support for this is limited. The operator schemas and or other functionality may change before next ONNX release and in this case ONNX Runtime will not guarantee backward compatibility. Current official support for domain ai.onnx is till opset 22. ``` To force it to work in the current state, one would have to do: ```py @script(ai_onnx, default_opset = op) def label_encoder(X: STRING["D"]): Y = ai_onnx.LabelEncoder(X, keys_strings=["a", "b", "c"], values_int64s=[0, 1, 2], default_int64=42) Y = Y + 0.0 # to force opset version downgrade/inserted from the to_model_proto call return Y ``` To force the opset to be downgraded, since the `default_opset` is never called. Happy to be challenged if there is a better way. I can imagine something weird/unintended might occur if the user sets `default_opset` to something other than what is defined in `@script(..., default_opset=<op>)` but that generally shouldn't be a problem?
1 parent dd14682 commit f44b314

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

onnxscript/irbuilder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def to_model_proto(
321321
input_types: Optional[Sequence[ONNXType]] = None,
322322
output_types: Optional[Sequence[ONNXType]] = None,
323323
value_infos: dict[str, ONNXType] | None = None,
324+
opset_version: int | None = None,
324325
**kwargs,
325326
) -> onnx.ModelProto:
326327
"""Converts this instance into a `onnx.ModelProto`.
@@ -336,6 +337,8 @@ def to_model_proto(
336337
are set to be of the corresponding type in this list.
337338
value_infos: A dictionary mapping intermediate variable names to ONNX types.
338339
Used to set value_info for intermediate variables.
340+
opset_version: The standard opset version to use for the model if it
341+
cannot be inferred. Otherwise defaults to the current opset version.
339342
kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.
340343
341344
Returns:
@@ -393,8 +396,8 @@ def to_proto(f):
393396

394397
if "" not in opsets:
395398
# No operator is using the standard opset.
396-
# A default value is given.
397-
opsets[""] = onnx_opset_version()
399+
# Use the specified version if provided or the default value.
400+
opsets[""] = opset_version if opset_version is not None else onnx_opset_version()
398401

399402
if "ir_version" not in kwargs:
400403
kwargs["ir_version"] = select_ir_version(opsets[""])

0 commit comments

Comments
 (0)