Skip to content

Commit 93783ee

Browse files
authored
Capture rewrite rule name as metadata (#2675)
Capture rewrite rule name as metadata to simplify debugging issues from rewrites. This is just a basic version. TODO / Extensions: * Sometimes we apply a sequence of rewrite-rules one after another, to perform complex fusions. This currently records only the last rule applied. * This can be solved when we merge metadata from original nodes into new nodes. (See #2375 ) * May be useful standardize on a single ONNX metadata key for "source" info (that can be used by torchlib/other exporters/rewriter/optimizer etc.) --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 5be9d3b commit 93783ee

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"RewriterContext",
1717
"MatchingTracer",
1818
"MatchStatus",
19+
"RULE_NAME_TAG",
1920
]
2021

2122
import onnx
@@ -25,6 +26,7 @@
2526
from onnxscript.rewriter import pattern
2627
from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus
2728
from onnxscript.rewriter._rewrite_rule import (
29+
RULE_NAME_TAG,
2830
RewriterContext,
2931
RewriteRule,
3032
RewriteRuleClassBase,

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525

2626
RewriterContext = _tape.Builder
2727

28+
# TODO(rama): Standardize metadata property keys. May be worth standardizing at ONNX level for
29+
# source/producer metadata.
30+
31+
RULE_NAME_TAG = "pkg.onnxscript.rewriter.rule_name"
32+
2833

2934
@dataclasses.dataclass
3035
class ReplacementSubgraph:
@@ -719,6 +724,13 @@ def _apply_to_graph_or_function(
719724
_ir_utils.display_nodes(delta.new_nodes)
720725
print("++++End Replacement Nodes++++")
721726

727+
# Capture rewrite rule name as metadata.
728+
# TODO(rama): This is just a basic version. We may wish to compose "source" metadata
729+
# from multiple rules in future.
730+
if rule.name:
731+
for n in delta.new_nodes:
732+
n.metadata_props[RULE_NAME_TAG] = rule.name
733+
722734
convenience.replace_nodes_and_values(
723735
graph_or_function,
724736
node,

onnxscript/rewriter/pattern_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import onnx.parser
1111

1212
import onnxscript.optimizer
13+
import onnxscript.rewriter
1314
from onnxscript import FLOAT, ir, script
1415
from onnxscript import opset17 as op
1516
from onnxscript.rewriter import pattern
@@ -936,6 +937,44 @@ def add_pattern(op, x, y):
936937
match_result = rule_pattern.match(model, model.graph, add_nodes[2])
937938
self.assertFalse(bool(match_result))
938939

940+
def test_rule_name_metadata(self):
941+
"""Test that RewriteRule carries name metadata."""
942+
943+
class ReciprocalMulRule(pattern.RewriteRuleClassBase):
944+
def __init__(self, name: str | None = None):
945+
super().__init__(name)
946+
947+
def pattern(self, op, x, y):
948+
return (1 / x) * y
949+
950+
def rewrite(self, op, x, y):
951+
return op.Div(y, x)
952+
953+
@script()
954+
def test_script(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]:
955+
return op.Mul(op.Div(op.Constant(value_float=1.0), x), y)
956+
957+
rule = ReciprocalMulRule.rule(name="ReciprocalMulToDiv")
958+
model_proto = test_script.to_model_proto()
959+
model = ir.serde.deserialize_model(model_proto)
960+
count = rule.apply_to_model(model)
961+
self.assertEqual(count, 1)
962+
for node in model.graph:
963+
if node.op_type == "Div":
964+
tag = onnxscript.rewriter.RULE_NAME_TAG
965+
self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulToDiv")
966+
967+
# By default, the rule name is the class name (if not provided)
968+
rule = ReciprocalMulRule.rule()
969+
model_proto = test_script.to_model_proto()
970+
model = ir.serde.deserialize_model(model_proto)
971+
count = rule.apply_to_model(model)
972+
self.assertEqual(count, 1)
973+
for node in model.graph:
974+
if node.op_type == "Div":
975+
tag = onnxscript.rewriter.RULE_NAME_TAG
976+
self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulRule")
977+
939978

940979
class PatternBuilderTest(unittest.TestCase):
941980
def test_pattern_builder_context(self):

0 commit comments

Comments
 (0)