Skip to content

Commit 96d7a1c

Browse files
committed
Add utility function for converting model protos to function proto
1 parent 8a94ad6 commit 96d7a1c

File tree

3 files changed

+128
-0
lines changed

3 files changed

+128
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import onnx
4+
import onnx_ir
5+
from onnx import helper
6+
7+
8+
def _initializers_to_constants(model: onnx.ModelProto) -> onnx.ModelProto:
9+
graph = model.graph
10+
new_nodes = []
11+
12+
# Keep track of names to remove from inputs
13+
init_names = {init.name for init in graph.initializer}
14+
15+
for init in graph.initializer:
16+
# Convert initializer to Constant node
17+
const_node = helper.make_node(
18+
"Constant",
19+
inputs=[],
20+
outputs=[init.name],
21+
value=init, # Directly use TensorProto
22+
)
23+
new_nodes.append(const_node)
24+
25+
# Filter out initializer names from graph inputs
26+
filtered_inputs = [i for i in graph.input if i.name not in init_names]
27+
graph.ClearField("input")
28+
graph.input.extend(filtered_inputs)
29+
30+
# Add new Constant nodes at the beginning
31+
all_nodes = new_nodes + list(graph.node)
32+
graph.ClearField("node")
33+
graph.node.extend(all_nodes)
34+
35+
# Clear initializers (since we replaced them)
36+
graph.ClearField("initializer")
37+
38+
return model
39+
40+
41+
def convert_model_proto_to_function_proto(
42+
model: onnx.ModelProto, domain, name
43+
) -> onnx.FunctionProto:
44+
"""Converts an arbitrary ModelProto to a FunctionProto.
45+
46+
Since function protos don't support initializers (or rather it does not make sense in the context of a function)
47+
we need to convert them to constants first.
48+
"""
49+
model = _initializers_to_constants(
50+
model
51+
) # theres some work to do here...maybe contribute to open source?
52+
model_ir = onnx_ir.serde.deserialize_model(model)
53+
function_ir = onnx_ir.Function(
54+
domain=domain, name=name, graph=model_ir.graph, attributes={}
55+
)
56+
return onnx_ir.to_proto(function_ir)

tests/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import unittest
4+
5+
import numpy as np
6+
import onnxruntime as ort
7+
8+
from onnxscript import script
9+
from onnxscript.onnx_opset import opset15 as op
10+
from onnxscript.onnx_types import FLOAT
11+
from onnxscript.utils.model_proto_to_function_proto import (
12+
convert_model_proto_to_function_proto,
13+
)
14+
from onnxscript.values import Opset
15+
16+
17+
class TestModelProtoToFunctionProto(unittest.TestCase):
18+
def setUp(self):
19+
"""Set up test fixtures."""
20+
# Create a fresh custom opset for each test
21+
self.local = Opset("local", 1)
22+
23+
# Define test functions
24+
@script(self.local, default_opset=op)
25+
def diff_square(x, y):
26+
diff = x - y
27+
return diff * diff
28+
29+
@script(self.local)
30+
def sum_func(z):
31+
return op.ReduceSum(z, keepdims=1)
32+
33+
@script()
34+
def l2norm(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821
35+
return op.Sqrt(sum_func(diff_square(x, y)))
36+
37+
@script()
38+
def l2norm_with_functions(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821
39+
return op.Sqrt(sum_func(diff_square(x, y)))
40+
41+
self.diff_square = diff_square
42+
self.sum_func = sum_func
43+
self.l2norm = l2norm
44+
self.l2norm_with_functions = l2norm_with_functions
45+
46+
def test_multiple_functions_in_model_proto(self):
47+
"""Test that multiple functions can be included in a single model proto."""
48+
# Add sum function to opset
49+
sum_model = self.sum_func.to_model_proto()
50+
sum_function_proto = convert_model_proto_to_function_proto(
51+
sum_model, "local", "sum_func"
52+
)
53+
54+
model = self.l2norm_with_functions.to_model_proto(
55+
functions=[sum_function_proto, self.diff_square]
56+
)
57+
58+
# Test execution
59+
session = ort.InferenceSession(model.SerializeToString())
60+
result = session.run(
61+
None,
62+
{
63+
"x": np.array([1.0, 2.0, 3.0]).astype(np.float32),
64+
"y": np.array([4.0, 5.0, 6.0]).astype(np.float32),
65+
},
66+
)
67+
68+
# Verify result
69+
self.assertEqual(len(result), 1)
70+
self.assertAlmostEqual(np.sqrt(27.0), result[0][0], places=5) # L2 norm of [3, 3, 3]

0 commit comments

Comments
 (0)