diff --git a/onnxscript/utils/model_proto_to_function_proto.py b/onnxscript/utils/model_proto_to_function_proto.py
new file mode 100644
index 0000000000..4583c6e8e1
--- /dev/null
+++ b/onnxscript/utils/model_proto_to_function_proto.py
@@ -0,0 +1,62 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import onnx
+import onnx_ir
+from onnx import helper
+
+
+def _initializers_to_constants(model: onnx.ModelProto) -> onnx.ModelProto:
+    graph = model.graph
+    new_nodes = []
+
+    # Track initializer names so they can be removed from the graph inputs
+    init_names = {init.name for init in graph.initializer}
+
+    for init in graph.initializer:
+        # Convert each initializer into a Constant node
+        const_node = helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[init.name],
+            value=init,  # use the TensorProto directly as the attribute value
+        )
+        new_nodes.append(const_node)
+
+    # Filter out initializer names from the graph inputs
+    filtered_inputs = [i for i in graph.input if i.name not in init_names]
+    graph.ClearField("input")
+    graph.input.extend(filtered_inputs)
+
+    # Add the new Constant nodes at the beginning
+    all_nodes = new_nodes + list(graph.node)
+    graph.ClearField("node")
+    graph.node.extend(all_nodes)
+
+    # Clear the initializers, since they have been replaced by Constant nodes
+    graph.ClearField("initializer")
+
+    return model
+
+
+def convert_model_proto_to_function_proto(
+    model: onnx.ModelProto, domain: str, name: str
+) -> onnx.FunctionProto:
+    """Converts an arbitrary ModelProto to a FunctionProto.
+
+    FunctionProtos do not support initializers (they do not make sense in the
+    context of a function), so any initializers are converted to Constant nodes first.
+    """
+    model = _initializers_to_constants(model)
+    model_ir = onnx_ir.serde.deserialize_model(model)
+    doc_string = model.doc_string if model.doc_string else ""
+    graph = model_ir.graph
+
+    function_ir = onnx_ir.Function(
+        domain=domain, name=name, graph=graph, attributes={}
+    )
+
+    # Set metadata
+    function_ir.doc_string = doc_string
+    function_ir.metadata_props.update(graph.metadata_props)  # copy graph metadata by modifying the mapping in place
+    # function_ir.value_infos(graph.value_infos)  # there's no setter defined
+    return onnx_ir.to_proto(function_ir)
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 0000000000..59e481eb93
--- /dev/null
+++ b/tests/utils/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
diff --git a/tests/utils/model_proto_to_function_proto_test.py b/tests/utils/model_proto_to_function_proto_test.py
new file mode 100644
index 0000000000..de627974c7
--- /dev/null
+++ b/tests/utils/model_proto_to_function_proto_test.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import unittest
+
+import numpy as np
+import onnxruntime as ort
+
+from onnxscript import script
+from onnxscript.onnx_opset import opset15 as op
+from onnxscript.onnx_types import FLOAT
+from onnxscript.utils.model_proto_to_function_proto import (
+    convert_model_proto_to_function_proto,
+)
+from onnxscript.values import Opset
+
+
+class TestModelProtoToFunctionProto(unittest.TestCase):
+    def setUp(self):
+        """Set up test fixtures."""
+        # Create a fresh custom opset for each test
+        self.local = Opset("local", 1)
+
+        # Define test functions
+        @script(self.local, default_opset=op)
+        def diff_square(x, y):
+            diff = x - y
+            return diff * diff
+
+        @script(self.local)
+        def sum_func(z):
+            return op.ReduceSum(z, keepdims=1)
+
+        @script()
+        def l2norm_with_functions(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]:  # noqa: F821
+            return op.Sqrt(sum_func(diff_square(x, y)))
+
+        self.diff_square = diff_square
+        self.sum_func = sum_func
+        self.l2norm_with_functions = l2norm_with_functions
+
+    def test_multiple_functions_in_model_proto(self):
+        """Test that multiple functions can be included in a single model proto."""
+        # Add sum function to opset
+        sum_model = self.sum_func.to_model_proto()
+        sum_function_proto = convert_model_proto_to_function_proto(
+            sum_model, "local", "sum_func"
+        )
+
+        model = self.l2norm_with_functions.to_model_proto(
+            functions=[sum_function_proto, self.diff_square]
+        )
+
+        # Test execution
+        session = ort.InferenceSession(model.SerializeToString())
+        result = session.run(
+            None,
+            {
+                "x": np.array([1.0, 2.0, 3.0]).astype(np.float32),
+                "y": np.array([4.0, 5.0, 6.0]).astype(np.float32),
+            },
+        )
+
+        # Verify result
+        self.assertEqual(len(result), 1)
+        self.assertAlmostEqual(np.sqrt(27.0), result[0][0], places=5)  # L2 norm of [3, 3, 3]
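A minimal usage sketch of the new helper, for context: the `square` and `caller` functions and the "local" domain below are illustrative names only, and the call pattern mirrors the test above. A scripted sub-function is exported to a ModelProto, converted to a FunctionProto, and then embedded in the calling model via `to_model_proto(functions=...)`.

from onnxscript import script
from onnxscript.onnx_opset import opset15 as op
from onnxscript.onnx_types import FLOAT
from onnxscript.utils.model_proto_to_function_proto import (
    convert_model_proto_to_function_proto,
)
from onnxscript.values import Opset

local = Opset("local", 1)


@script(local, default_opset=op)
def square(x):
    # hypothetical helper, registered under the custom "local" opset
    return x * x


@script()
def caller(x: FLOAT[3]) -> FLOAT[3]:
    return square(x)


# Convert the scripted helper's ModelProto into a FunctionProto under the
# "local" domain, then embed it so the resulting model carries the definition.
square_proto = convert_model_proto_to_function_proto(
    square.to_model_proto(), "local", "square"
)
model = caller.to_model_proto(functions=[square_proto])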