|
10 | 10 | import onnx.parser |
11 | 11 |
|
12 | 12 | import onnxscript.optimizer |
| 13 | +import onnxscript.rewriter |
13 | 14 | from onnxscript import FLOAT, ir, script |
14 | 15 | from onnxscript import opset17 as op |
15 | 16 | from onnxscript.rewriter import pattern |
@@ -936,6 +937,44 @@ def add_pattern(op, x, y): |
936 | 937 | match_result = rule_pattern.match(model, model.graph, add_nodes[2]) |
937 | 938 | self.assertFalse(bool(match_result)) |
938 | 939 |
|
| 940 | + def test_rule_name_metadata(self): |
| 941 | + """Test that RewriteRule carries name metadata.""" |
| 942 | + |
| 943 | + class ReciprocalMulRule(pattern.RewriteRuleClassBase): |
| 944 | + def __init__(self, name: str | None = None): |
| 945 | + super().__init__(name) |
| 946 | + |
| 947 | + def pattern(self, op, x, y): |
| 948 | + return (1 / x) * y |
| 949 | + |
| 950 | + def rewrite(self, op, x, y): |
| 951 | + return op.Div(y, x) |
| 952 | + |
| 953 | + @script() |
| 954 | + def test_script(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]: |
| 955 | + return op.Mul(op.Div(op.Constant(value_float=1.0), x), y) |
| 956 | + |
| 957 | + rule = ReciprocalMulRule.rule(name="ReciprocalMulToDiv") |
| 958 | + model_proto = test_script.to_model_proto() |
| 959 | + model = ir.serde.deserialize_model(model_proto) |
| 960 | + count = rule.apply_to_model(model) |
| 961 | + self.assertEqual(count, 1) |
| 962 | + for node in model.graph: |
| 963 | + if node.op_type == "Div": |
| 964 | + tag = onnxscript.rewriter.RULE_NAME_TAG |
| 965 | + self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulToDiv") |
| 966 | + |
| 967 | + # By default, the rule name is the class name (if not provided) |
| 968 | + rule = ReciprocalMulRule.rule() |
| 969 | + model_proto = test_script.to_model_proto() |
| 970 | + model = ir.serde.deserialize_model(model_proto) |
| 971 | + count = rule.apply_to_model(model) |
| 972 | + self.assertEqual(count, 1) |
| 973 | + for node in model.graph: |
| 974 | + if node.op_type == "Div": |
| 975 | + tag = onnxscript.rewriter.RULE_NAME_TAG |
| 976 | + self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulRule") |
| 977 | + |
939 | 978 |
|
940 | 979 | class PatternBuilderTest(unittest.TestCase): |
941 | 980 | def test_pattern_builder_context(self): |
|
0 commit comments