Skip to content

Commit 710d597

Browse files
authored
Fix rewriter and CI tests for the latest onnx-ir version (#2554)
Fix rewriter CI tests for the latest onnx-ir version (currently in main). Since the latest onnx-ir is now returning tuples for repeated attributes, we need to update the comparison logic to account for that. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 821015a commit 710d597

File tree

12 files changed

+28
-21
lines changed

12 files changed

+28
-21
lines changed

onnxscript/rewriter/_fusion_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Dim = Union[int, ir.SymbolicDim]
1414

1515

16-
def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
16+
def check_shape_bool(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
1717
if val.shape is None:
1818
return False
1919
if val.shape.rank() != len(shape):

onnxscript/rewriter/_pattern_ir.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,14 @@ def __init__(self, value: SupportedAttrTypes):
126126
self._value = value
127127

128128
def matches(self, attr: ir.Attr) -> bool:
129-
return isinstance(attr, ir.Attr) and attr.value == self._value
129+
if attr.type in {
130+
ir.AttributeType.INTS,
131+
ir.AttributeType.FLOATS,
132+
ir.AttributeType.STRINGS,
133+
}:
134+
# Since the type of attr.value is Sequence, we need to convert to the same type for comparison.
135+
return tuple(attr.value) == tuple(self._value)
136+
return attr.value == self._value
130137

131138
def __str__(self) -> str:
132139
return str(self._value)

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
392392
if perm.is_ref():
393393
return False
394394
if perm.type == ir.AttributeType.INTS:
395-
if perm.as_ints() == list(range(len(perm.as_ints()))):
395+
if list(perm.as_ints()) == list(range(len(perm.as_ints()))):
396396
return True
397397
return False
398398
"""
@@ -463,7 +463,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
463463
if perm.is_ref():
464464
return False
465465
if perm.type == ir.AttributeType.INTS:
466-
if perm.as_ints() == list(range(len(perm.as_ints()))):
466+
if list(perm.as_ints()) == list(range(len(perm.as_ints()))):
467467
return True
468468
return False
469469

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def check(
160160
self.bindings: dict[str, Dim] = {}
161161

162162
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
163-
return not _fusion_utils._check_shape(self.bindings, val, dims)
163+
return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
164164

165165
if no_match(input, ["B", "S", "D"]):
166166
return check_result.fail(

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def check(
7979
# Check that last two dimensions are swapped
8080
expected_perm = list(range(len(perm)))
8181
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
82-
if perm != expected_perm:
82+
if list(perm) != expected_perm:
8383
return check_result.fail("Permutation values for Transpose are not correct.")
8484
elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or (
8585
self._pos == 2 and not _ir_utils.has_rank(y, 2)
@@ -188,7 +188,7 @@ def check(
188188
trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
189189
trans_batch = fused_node.attributes.get_int(trans_batch_property, 0)
190190
transposed_node = _get_node(transposed, "Transpose")
191-
perm = transposed_node.attributes["perm"].as_ints()
191+
perm = list(transposed_node.attributes["perm"].as_ints())
192192
if not perm:
193193
return check_result.fail("Permutation values for Transpose are not correct.")
194194

@@ -296,7 +296,7 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
296296
if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2):
297297
if perm:
298298
# Check that the two dimensions are swapped
299-
if perm != [1, 0]:
299+
if tuple(perm) != (1, 0):
300300
return check_result.fail(
301301
"Permutation values for Transpose are not correct."
302302
)

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def _check_model(
284284
opt = onnx.reference.ReferenceEvaluator(optimized_model, new_ops=[FusedMatMul])
285285
expected = ref.run(None, feeds)
286286
got = opt.run(None, feeds)
287-
self.assertEqual(len(expected), len(got))
287+
self.assertEqual(len(got), len(expected))
288288
for a, b in zip(expected, got):
289289
np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
290290

@@ -319,7 +319,7 @@ def test_fused_matmul_div_models(self, name, script_func, input_types, output_ty
319319
rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets()
320320
rule_set.apply_to_model(ir_model)
321321
rewritten_model = ir.serde.serialize_model(ir_model)
322-
self.assertEqual(["Constant", "FusedMatMul"], [n.op_type for n in ir_model.graph])
322+
self.assertEqual([n.op_type for n in ir_model.graph], ["Constant", "FusedMatMul"])
323323
self._check_model(model_proto, rewritten_model, atol=1e-6)
324324

325325
@parameterized.parameterized.expand(
@@ -354,7 +354,7 @@ def test_fused_matmul_with_transpose(self, _, script_func):
354354
ir_model = ir.serde.deserialize_model(model_proto)
355355
self._apply_fusion_rules(ir_model)
356356
rewritten_model = ir.serde.serialize_model(ir_model)
357-
self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph])
357+
self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"])
358358
self._check_model(model_proto, rewritten_model, atol=1e-6)
359359

360360
@parameterized.parameterized.expand([("should_not_match", _should_not_match)])
@@ -366,8 +366,8 @@ def test_should_not_match(self, _, script_func):
366366
self._apply_fusion_rules(ir_model)
367367
rewritten_model = ir.serde.serialize_model(ir_model)
368368
self.assertEqual(
369-
["Transpose", "MatMul", "Transpose"],
370369
[n.op_type for n in ir_model.graph],
370+
["Transpose", "MatMul", "Transpose"],
371371
)
372372
self._check_model(model_proto, rewritten_model, atol=1e-6)
373373

@@ -391,7 +391,7 @@ def test_fused_matmul_with_other_node_in_middle(self, _, script_func):
391391
common_passes.ShapeInferencePass()(ir_model)
392392
self._apply_fusion_rules(ir_model)
393393
rewritten_model = ir.serde.serialize_model(ir_model)
394-
self.assertEqual(["Identity", "FusedMatMul"], [n.op_type for n in ir_model.graph])
394+
self.assertEqual([n.op_type for n in ir_model.graph], ["Identity", "FusedMatMul"])
395395
self._check_model(model_proto, rewritten_model, atol=1e-6)
396396

397397
@parameterized.parameterized.expand(
@@ -440,7 +440,7 @@ def test_transpose_fused_matmul_with_batch(self, _, script_func):
440440
ir_model = ir.serde.deserialize_model(model_proto)
441441
self._apply_fusion_rules(ir_model)
442442
rewritten_model = ir.serde.serialize_model(ir_model)
443-
self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph])
443+
self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"])
444444
self._check_model(model_proto, rewritten_model, atol=1e-6)
445445

446446

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def check(
247247
bindings: dict[str, Dim] = {}
248248

249249
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
250-
return not _fusion_utils._check_shape(bindings, val, dims)
250+
return not _fusion_utils.check_shape_bool(bindings, val, dims)
251251

252252
if no_match(query_BSD, ["B", "S", "D"]):
253253
return False

onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def check(
8484
self.bindings: dict[str, Dim] = {}
8585

8686
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
87-
return not _fusion_utils._check_shape(self.bindings, val, dims)
87+
return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
8888

8989
# Check that if x is being split into q, k, v correctly
9090
# based on hidden sizes

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def check(
157157
bindings: dict[str, Dim] = {}
158158

159159
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
160-
return not _fusion_utils._check_shape(bindings, val, dims)
160+
return not _fusion_utils.check_shape_bool(bindings, val, dims)
161161

162162
if no_match(query_BSD, ["B", "S", "D"]):
163163
return check_result.fail(

onnxscript/rewriter/ort_fusions/mha_bias.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def check(
7878
self.bindings: dict[str, Dim] = {}
7979

8080
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
81-
return not _fusion_utils._check_shape(self.bindings, val, dims)
81+
return not _fusion_utils.check_shape_bool(self.bindings, val, dims)
8282

8383
if query_matmul.dtype not in valid_float_types:
8484
return check_result.fail("Query is not a float or float16 type.", query_matmul)

0 commit comments

Comments
 (0)