Skip to content

Commit 9036fab

Browse files
[Rewriter] Support specifying node name in rewrites (#2474)
Allows passing a node name when defining a rewrite. fixes #2435 --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent d0fb218 commit 9036fab

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

onnxscript/ir/_tape.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,17 @@
1717

1818

1919
class Builder(tape.Tape):
20-
"""An extension of the tape that provides a more convenient API for constructing the IR."""
20+
"""An extension of the tape that provides a more convenient API for constructing the IR.
21+
22+
Example:
23+
>>> from onnxscript import ir
24+
>>> from onnxscript.ir import _tape
25+
>>> op = _tape.Builder()
26+
>>> input = ir.Value(name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)))
27+
>>> relu_val = op.Relu(input, _name="relu_node", _domain="", _version=18, _outputs=["relu_out"])
28+
29+
Note: When passing `_name`, ensure it is unique to avoid duplicate node names.
30+
"""
2131

2232
def __getattr__(self, op_type: str) -> Any:
2333
return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)
@@ -26,6 +36,8 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
2636
domain = kwargs.pop("_domain", "")
2737
version = kwargs.pop("_version", None)
2838
outputs = kwargs.pop("_outputs", 1)
39+
name = kwargs.pop("_name", None)
40+
2941
if isinstance(outputs, Sequence):
3042
num_outputs = len(outputs)
3143
else:
@@ -34,7 +46,12 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
3446

3547
if num_outputs == 1:
3648
value = super().op(
37-
op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version
49+
op_type,
50+
inputs=inputs,
51+
attributes=kwargs,
52+
domain=domain,
53+
version=version,
54+
name=name,
3855
)
3956
if isinstance(outputs, Sequence):
4057
value.name = outputs[0]
@@ -45,6 +62,7 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
4562
attributes=kwargs,
4663
domain=domain,
4764
version=version,
65+
name=name,
4866
num_outputs=num_outputs,
4967
)
5068
if isinstance(outputs, Sequence):

onnxscript/ir/_tape_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import unittest
66

77
from onnxscript import ir
8+
from onnxscript.ir import _tape
89

910

1011
class TestTape(unittest.TestCase):
@@ -72,5 +73,32 @@ def test_op_multi_out(self):
7273
self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"])
7374

7475

76+
class TestBuilder(unittest.TestCase):
77+
def test_op_name(self):
78+
op = _tape.Builder()
79+
80+
input_a = ir.Value(
81+
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
82+
)
83+
input_b = ir.Value(
84+
name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
85+
)
86+
87+
add = op.Add(input_a, input_b, _name="add_node")
88+
_ = op.Relu(add, _name="relu_node")
89+
self.assertEqual(op.nodes[0].name, "add_node")
90+
self.assertEqual(op.nodes[1].name, "relu_node")
91+
92+
def test_op_name_multi_out(self):
93+
op = _tape.Builder()
94+
95+
input_a = ir.Value(
96+
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
97+
)
98+
99+
_ = op.CustomOp(input_a, _name="custom_node", _outputs=3)
100+
self.assertEqual(op.nodes[0].name, "custom_node")
101+
102+
75103
if __name__ == "__main__":
76104
unittest.main()

0 commit comments

Comments
 (0)