From 307c47d6af19fa7b60930c28b13cdd58039a6688 Mon Sep 17 00:00:00 2001 From: chapman73 <97487899+chapman73@users.noreply.github.com> Date: Mon, 27 Oct 2025 11:33:35 +1100 Subject: [PATCH 1/3] Add utility function for converting model protos to function proto --- .../utils/model_proto_to_function_proto.py | 56 +++++++++++++++ tests/utils/__init__.py | 2 + .../model_proto_to_function_proto_test.py | 70 +++++++++++++++++++ 3 files changed, 128 insertions(+) create mode 100644 onnxscript/utils/model_proto_to_function_proto.py create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/model_proto_to_function_proto_test.py diff --git a/onnxscript/utils/model_proto_to_function_proto.py b/onnxscript/utils/model_proto_to_function_proto.py new file mode 100644 index 0000000000..e396c6d003 --- /dev/null +++ b/onnxscript/utils/model_proto_to_function_proto.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import onnx +import onnx_ir +from onnx import helper + + +def _initializers_to_constants(model: onnx.ModelProto) -> onnx.ModelProto: + graph = model.graph + new_nodes = [] + + # Keep track of names to remove from inputs + init_names = {init.name for init in graph.initializer} + + for init in graph.initializer: + # Convert initializer to Constant node + const_node = helper.make_node( + "Constant", + inputs=[], + outputs=[init.name], + value=init, # Directly use TensorProto + ) + new_nodes.append(const_node) + + # Filter out initializer names from graph inputs + filtered_inputs = [i for i in graph.input if i.name not in init_names] + graph.ClearField("input") + graph.input.extend(filtered_inputs) + + # Add new Constant nodes at the beginning + all_nodes = new_nodes + list(graph.node) + graph.ClearField("node") + graph.node.extend(all_nodes) + + # Clear initializers (since we replaced them) + graph.ClearField("initializer") + + return model + + +def convert_model_proto_to_function_proto( + model: onnx.ModelProto, domain, name +) -> onnx.FunctionProto: + """Converts an arbitrary ModelProto to a FunctionProto. + + Since function protos don't support initializers (or rather it does not make sense in the context of a function) + we need to convert them to constants first. + """ + model = _initializers_to_constants( + model + ) # theres some work to do here...maybe contribute to open source? + model_ir = onnx_ir.serde.deserialize_model(model) + function_ir = onnx_ir.Function( + domain=domain, name=name, graph=model_ir.graph, attributes={} + ) + return onnx_ir.to_proto(function_ir) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/tests/utils/model_proto_to_function_proto_test.py b/tests/utils/model_proto_to_function_proto_test.py new file mode 100644 index 0000000000..7c5b1695ae --- /dev/null +++ b/tests/utils/model_proto_to_function_proto_test.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnxruntime as ort + +from onnxscript import script +from onnxscript.onnx_opset import opset15 as op +from onnxscript.onnx_types import FLOAT +from onnxscript.utils.model_proto_to_function_proto import ( + convert_model_proto_to_function_proto, +) +from onnxscript.values import Opset + + +class TestModelProtoToFunctionProto(unittest.TestCase): + def setUp(self): + """Set up test fixtures.""" + # Create a fresh custom opset for each test + self.local = Opset("local", 1) + + # Define test functions + @script(self.local, default_opset=op) + def diff_square(x, y): + diff = x - y + return diff * diff + + @script(self.local) + def sum_func(z): + return op.ReduceSum(z, keepdims=1) + + @script() + def l2norm(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821 + return op.Sqrt(sum_func(diff_square(x, y))) + + @script() + def l2norm_with_functions(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821 + return op.Sqrt(sum_func(diff_square(x, y))) + + self.diff_square = diff_square + self.sum_func = sum_func + self.l2norm = l2norm + self.l2norm_with_functions = l2norm_with_functions + + def test_multiple_functions_in_model_proto(self): + """Test that multiple functions can be included in a single model proto.""" + # Add sum function to opset + sum_model = self.sum_func.to_model_proto() + sum_function_proto = convert_model_proto_to_function_proto( + sum_model, "local", "sum_func" + ) + + model = self.l2norm_with_functions.to_model_proto( + functions=[sum_function_proto, self.diff_square] + ) + + # Test execution + session = ort.InferenceSession(model.SerializeToString()) + result = session.run( + None, + { + "x": np.array([1.0, 2.0, 3.0]).astype(np.float32), + "y": np.array([4.0, 5.0, 6.0]).astype(np.float32), + }, + ) + + # Verify result + self.assertEqual(len(result), 1) + self.assertAlmostEqual(np.sqrt(27.0), result[0][0], places=5) # L2 norm of [3, 3, 3] From 791cdbc41514673387480cbee9bdc9341c74c684 Mon Sep 17 00:00:00 2001 From: chapman73 <97487899+chapman73@users.noreply.github.com> Date: Wed, 29 Oct 2025 20:03:13 +1100 Subject: [PATCH 2/3] Addr. comments --- onnxscript/utils/model_proto_to_function_proto.py | 6 ++---- tests/utils/model_proto_to_function_proto_test.py | 5 ----- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/onnxscript/utils/model_proto_to_function_proto.py b/onnxscript/utils/model_proto_to_function_proto.py index e396c6d003..79f73b7373 100644 --- a/onnxscript/utils/model_proto_to_function_proto.py +++ b/onnxscript/utils/model_proto_to_function_proto.py @@ -39,16 +39,14 @@ def _initializers_to_constants(model: onnx.ModelProto) -> onnx.ModelProto: def convert_model_proto_to_function_proto( - model: onnx.ModelProto, domain, name + model: onnx.ModelProto, domain: str, name: str ) -> onnx.FunctionProto: """Converts an arbitrary ModelProto to a FunctionProto. Since function protos don't support initializers (or rather it does not make sense in the context of a function) we need to convert them to constants first. """ - model = _initializers_to_constants( - model - ) # theres some work to do here...maybe contribute to open source? + model = _initializers_to_constants(model) model_ir = onnx_ir.serde.deserialize_model(model) function_ir = onnx_ir.Function( domain=domain, name=name, graph=model_ir.graph, attributes={} diff --git a/tests/utils/model_proto_to_function_proto_test.py b/tests/utils/model_proto_to_function_proto_test.py index 7c5b1695ae..de627974c7 100644 --- a/tests/utils/model_proto_to_function_proto_test.py +++ b/tests/utils/model_proto_to_function_proto_test.py @@ -30,17 +30,12 @@ def diff_square(x, y): def sum_func(z): return op.ReduceSum(z, keepdims=1) - @script() - def l2norm(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821 - return op.Sqrt(sum_func(diff_square(x, y))) - @script() def l2norm_with_functions(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821 return op.Sqrt(sum_func(diff_square(x, y))) self.diff_square = diff_square self.sum_func = sum_func - self.l2norm = l2norm self.l2norm_with_functions = l2norm_with_functions def test_multiple_functions_in_model_proto(self): From aadc8815c8028a8524da099ea4fdd5a3591764cf Mon Sep 17 00:00:00 2001 From: chapman73 <97487899+chapman73@users.noreply.github.com> Date: Sun, 2 Nov 2025 15:32:55 +1100 Subject: [PATCH 3/3] Update based on comments --- onnxscript/utils/model_proto_to_function_proto.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/onnxscript/utils/model_proto_to_function_proto.py b/onnxscript/utils/model_proto_to_function_proto.py index 79f73b7373..4583c6e8e1 100644 --- a/onnxscript/utils/model_proto_to_function_proto.py +++ b/onnxscript/utils/model_proto_to_function_proto.py @@ -48,7 +48,15 @@ def convert_model_proto_to_function_proto( """ model = _initializers_to_constants(model) model_ir = onnx_ir.serde.deserialize_model(model) + doc_string = model.doc_string if model.doc_string else "" + graph = model_ir.graph + function_ir = onnx_ir.Function( - domain=domain, name=name, graph=model_ir.graph, attributes={} + domain=domain, name=name, graph=graph, attributes={} ) + + # set metadata + function_ir.doc_string(doc_string) + function_ir.metadata_props = graph.metadata_props # I believe the docs suggest directly modifying the attr? + # function_ir.value_infos(graph.value_infos) # theres no setter defined? return onnx_ir.to_proto(function_ir)