11import unittest
22from textwrap import dedent
33import numpy as np
4+ import onnx .helper as oh
45from onnx import ModelProto , TensorProto
56from onnx .checker import check_model
67from onnx .defs import onnx_opset_version
@@ -29,8 +30,9 @@ def test_exp(self):
2930 self .assertEqualArray (np .exp (a ), got )
3031
3132 code = translate (onx , api = "builder" )
32- expected = dedent (
33- """
33+ expected = (
34+ dedent (
35+ """
3436 def light_api(
3537 op: "GraphBuilder",
3638 X: "FLOAT[]",
@@ -42,10 +44,13 @@ def light_api(
4244 g = GraphBuilder({'': 19}, ir_version=10)
4345 g.make_tensor_input("X", TensorProto.FLOAT, ())
4446 light_api(g.op, "X")
45- g.make_tensor_output("Y", TensorProto.FLOAT, ())
47+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
4648 model = g.to_onnx()
4749 """
48- ).strip ("\n " )
50+ )
51+ .strip ("\n " )
52+ .replace ("__SUFFIX__" , ", is_dimension=False, indexed=False" )
53+ )
4954 self .assertEqual (expected , code .strip ("\n " ))
5055
5156 def light_api (
@@ -59,7 +64,9 @@ def light_api(
5964 g2 = GraphBuilder ({"" : 19 })
6065 g2 .make_tensor_input ("X" , TensorProto .FLOAT , ("A" ,))
6166 light_api (g2 .op , "X" )
62- g2 .make_tensor_output ("Y" , TensorProto .FLOAT , ("A" ,))
67+ g2 .make_tensor_output (
68+ "Y" , TensorProto .FLOAT , ("A" ,), is_dimension = False , indexed = False
69+ )
6370 onx2 = g2 .to_onnx ()
6471
6572 ref = ReferenceEvaluator (onx2 )
@@ -78,8 +85,9 @@ def test_zdoc(self):
7885 .to_onnx ()
7986 )
8087 code = translate (onx , api = "builder" )
81- expected = dedent (
82- """
88+ expected = (
89+ dedent (
90+ """
8391 def light_api(
8492 op: "GraphBuilder",
8593 X: "FLOAT[]",
@@ -93,10 +101,13 @@ def light_api(
93101 g = GraphBuilder({'': 19}, ir_version=10)
94102 g.make_tensor_input("X", TensorProto.FLOAT, ())
95103 light_api(g.op, "X")
96- g.make_tensor_output("Y", TensorProto.FLOAT, ())
104+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
97105 model = g.to_onnx()
98106 """
99- ).strip ("\n " )
107+ )
108+ .strip ("\n " )
109+ .replace ("__SUFFIX__" , ", is_dimension=False, indexed=False" )
110+ )
100111 self .maxDiff = None
101112 self .assertEqual (expected , code .strip ("\n " ))
102113
@@ -130,8 +141,9 @@ def test_exp_f(self):
130141 tr = Translater (onx , emitter = BuilderEmitter ("mm" ))
131142 code = tr .export (as_str = True )
132143
133- expected = dedent (
134- """
144+ expected = (
145+ dedent (
146+ """
135147 def light_api(
136148 op: "GraphBuilder",
137149 X: "FLOAT[]",
@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
145157 g = GraphBuilder({'': 19}, ir_version=10)
146158 g.make_tensor_input("X", TensorProto.FLOAT, ())
147159 light_api(g.op, "X")
148- g.make_tensor_output("Y", TensorProto.FLOAT, ())
160+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
149161 model = g.to_onnx()
150162 return model
151163
152164
153165 model = mm()
154166 """
155- ).strip ("\n " )
167+ )
168+ .strip ("\n " )
169+ .replace ("__SUFFIX__" , ", is_dimension=False, indexed=False" )
170+ )
156171 self .assertEqual (expected , code .strip ("\n " ))
157172
158173 def light_api (
@@ -166,14 +181,104 @@ def light_api(
166181 g2 = GraphBuilder ({"" : 19 })
167182 g2 .make_tensor_input ("X" , TensorProto .FLOAT , ("A" ,))
168183 light_api (g2 .op , "X" )
169- g2 .make_tensor_output ("Y" , TensorProto .FLOAT , ("A" ,))
184+ g2 .make_tensor_output (
185+ "Y" , TensorProto .FLOAT , ("A" ,), is_dimension = False , indexed = False
186+ )
170187 onx2 = g2 .to_onnx ()
171188
172189 ref = ReferenceEvaluator (onx2 )
173190 a = np .arange (10 ).astype (np .float32 )
174191 got = ref .run (None , {"X" : a })[0 ]
175192 self .assertEqualArray (np .exp (a ), got )
176193
194+ def test_local_function (self ):
195+ new_domain = "custom"
196+
197+ linear_regression = oh .make_function (
198+ new_domain ,
199+ "LinearRegression" ,
200+ ["x" , "a" , "b" ],
201+ ["y" ],
202+ [
203+ oh .make_node ("MatMul" , ["x" , "a" ], ["xa" ]),
204+ oh .make_node ("Add" , ["xa" , "b" ], ["y" ]),
205+ ],
206+ [oh .make_opsetid ("" , 14 )],
207+ [],
208+ )
209+
210+ graph = oh .make_graph (
211+ [
212+ oh .make_node (
213+ "LinearRegression" , ["X" , "A" , "B" ], ["Y1" ], domain = new_domain
214+ ),
215+ oh .make_node ("Abs" , ["Y1" ], ["Y" ]),
216+ ],
217+ "example" ,
218+ [
219+ oh .make_tensor_value_info ("X" , TensorProto .FLOAT , [None , None ]),
220+ oh .make_tensor_value_info ("A" , TensorProto .FLOAT , [None , None ]),
221+ oh .make_tensor_value_info ("B" , TensorProto .FLOAT , [None , None ]),
222+ ],
223+ [oh .make_tensor_value_info ("Y" , TensorProto .FLOAT , None )],
224+ )
225+
226+ onnx_model = oh .make_model (
227+ graph ,
228+ opset_imports = [oh .make_opsetid ("" , 14 ), oh .make_opsetid (new_domain , 1 )],
229+ functions = [linear_regression ],
230+ )
231+ tr = Translater (onnx_model , emitter = BuilderEmitter ("mm" ))
232+ code = tr .export (as_str = True )
233+
234+ expected = (
235+ dedent (
236+ """
237+ def example(
238+ op: "GraphBuilder",
239+ X: "FLOAT[, ]",
240+ A: "FLOAT[, ]",
241+ B: "FLOAT[, ]",
242+ ):
243+ Y1 = op.LinearRegression(X, A, B, domain='custom')
244+ Y = op.Abs(Y1)
245+ op.Identity(Y, outputs=["Y"])
246+ return Y
247+
248+
249+ def make_custom_LinearRegression(g: "GraphBuilder"):
250+ gr = GraphBuilder({'': 14}, as_function=True)
251+ x = gr.make_tensor_input('x')
252+ a = gr.make_tensor_input('a')
253+ b = gr.make_tensor_input('b')
254+ op = gr.op
255+ xa = op.MatMul(x, a)
256+ y = op.Add(xa, b)
257+ gr.make_tensor_output(y)
258+ g.add_function(builder=gr)
259+ return gr
260+
261+
262+ def mm() -> "ModelProto":
263+ g = GraphBuilder({'': 14, 'custom': 1}, ir_version=11)
264+ g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
265+ g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
266+ g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
267+ example(g.op, "X", "A", "B")
268+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
269+ make_custom_LinearRegression(g)
270+ model = g.to_onnx()
271+ return model
272+
273+
274+ model = mm()
275+ """
276+ )
277+ .strip ("\n " )
278+ .replace ("__SUFFIX__" , ", is_dimension=False, indexed=False" )
279+ )
280+ self .assertEqual (expected , code .strip ("\n " ))
281+
177282
178283if __name__ == "__main__" :
179284 unittest .main (verbosity = 2 )
0 commit comments