Skip to content

Commit 071ff1e

Browse files
justinchuby and Copilot authored
Create helper for comparing semantic equivalence of shapes (#2620)
This pull request introduces new utility functions for comparing shapes and dimensions in the intermediate representation (IR) utilities, and refactors existing rewrite rules to use these new utilities. The goal is to improve semantic correctness and code clarity when checking shape and dimension equality, especially in the presence of symbolic or unknown values.

Key changes:

**New IR utility functions:**

* Added `same_shape` and `same_dim` functions to `_ir_utils.py` for more robust and semantically correct comparison of shapes and dimensions, accounting for unknown or symbolic values.

**Refactoring of rewrite rules to use new utilities:**

* Updated `_collapse_slices.py` and `_redundant_scatter_nd.py` to use `_ir_utils.same_shape` and `_ir_utils.same_dim` instead of direct equality checks or previous logic, ensuring that shape and dimension comparisons are handled consistently and correctly. [[1]](diffhunk://#diff-bd2dba53e1a4b4fb79975f7bceacf4b1c5b0b38a10d953af1e18a0b7af6c1050L85-R88) [[2]](diffhunk://#diff-47bc4cbfc2fee996791be5a58bf9447dd44dd833e540139b5cd18b807757be4aL57-R57) [[3]](diffhunk://#diff-47bc4cbfc2fee996791be5a58bf9447dd44dd833e540139b5cd18b807757be4aL90-R90)

**Code consistency improvements:**

* Standardized imports in affected files to use `_ir_utils` consistently, replacing previous aliasing or direct imports. [[1]](diffhunk://#diff-bd2dba53e1a4b4fb79975f7bceacf4b1c5b0b38a10d953af1e18a0b7af6c1050L8-R8) [[2]](diffhunk://#diff-47bc4cbfc2fee996791be5a58bf9447dd44dd833e540139b5cd18b807757be4aL23-R23) [[3]](diffhunk://#diff-47bc4cbfc2fee996791be5a58bf9447dd44dd833e540139b5cd18b807757be4aL44-R44)

---------

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 28a8f56 commit 071ff1e

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

onnxscript/rewriter/_ir_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,27 @@ def get_dim(value: ir.Value | None, dim: int) -> ir.SymbolicDim | int | None:
152152
if dim < 0 or dim >= shape.rank():
153153
return None
154154
return shape[dim]
155+
156+
157+
def same_shape(shape1: ir.Shape | None, shape2: ir.Shape | None) -> bool:
    """Return whether two shapes can be considered semantically identical.

    Args:
        shape1: First shape to compare, or ``None`` when it is not available.
        shape2: Second shape to compare, or ``None`` when it is not available.

    Returns:
        ``False`` if either shape is ``None`` or reports an unknown dimension
        via ``has_unknown_dim()``; otherwise the result of ``shape1 == shape2``.
    """
    if shape1 is not None and shape2 is not None:
        # An unknown dim cannot be proven equal to anything, so only shapes
        # whose dims are all known are eligible for the equality comparison.
        any_unknown = shape1.has_unknown_dim() or shape2.has_unknown_dim()
        if not any_unknown:
            return shape1 == shape2
    return False
167+
168+
169+
def same_dim(dim1: ir.SymbolicDim | int | None, dim2: ir.SymbolicDim | int | None) -> bool:
    """Check if two dimensions are semantically the same.

    Args:
        dim1: First dimension: a static ``int``, an ``ir.SymbolicDim``, or
            ``None`` when the dimension is not available (``get_dim`` in this
            module is annotated as possibly returning ``None``).
        dim2: Second dimension, same conventions as ``dim1``.

    Returns:
        ``True`` only when both dimensions have the same type and either are
        equal integers or are symbolic dims whose ``value`` is not ``None`` and
        compares equal. ``False`` in every other case, including when either
        argument is ``None``.
    """
    # A missing dimension can never be proven equal to anything. Without this
    # guard, two ``None`` arguments pass the type check below and trip the
    # narrowing assert instead of yielding a clean ``False``.
    if dim1 is None or dim2 is None:
        return False
    # Exact type match on purpose: an int is never the same as a symbolic dim.
    if type(dim1) is not type(dim2):
        return False
    if isinstance(dim1, int) and isinstance(dim2, int):
        return dim1 == dim2
    # Type-narrowing for static checkers; both values share a type at this point.
    assert isinstance(dim1, ir.SymbolicDim) and isinstance(dim2, ir.SymbolicDim)
    # A symbolic dim whose value is None is treated as unknown, hence not the same.
    if dim1.value is None or dim2.value is None:
        return False
    return dim1.value == dim2.value

onnxscript/rewriter/rules/common/_collapse_slices.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66

77
from onnxscript import ir
8-
from onnxscript.rewriter._ir_utils import is_singleton_value
8+
from onnxscript.rewriter import _ir_utils
99
from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet
1010

1111
logger = logging.getLogger(__name__)
@@ -82,14 +82,10 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_
8282
if data.shape is None or slice_output.shape is None:
8383
return False
8484

85-
if not is_singleton_value(steps, 1):
85+
if not _ir_utils.is_singleton_value(steps, 1):
8686
return False
8787

88-
# If any dim is unknown, the shapes are not the same
89-
if data.shape.has_unknown_dim() or slice_output.shape.has_unknown_dim():
90-
return False
91-
92-
return data.shape == slice_output.shape
88+
return _ir_utils.same_shape(data.shape, slice_output.shape)
9389

9490

9591
# Register the rewrite rules

onnxscript/rewriter/rules/common/_redundant_scatter_nd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import onnx_ir as ir
2121

2222
import onnxscript.rewriter
23-
from onnxscript.rewriter import _ir_utils as ir_utils
23+
from onnxscript.rewriter import _ir_utils
2424
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
2525

2626

@@ -41,7 +41,7 @@ def check(self, context, data, axis, transposed_data, **_):
4141
# Check that updated-indices represent the full range of the first dimension of the transposed data.
4242
# That is: check that the data.shape[axis] matches transposed_data.shape[0].
4343
result = onnxscript.rewriter.MatchResult()
44-
axis_value = ir_utils.get_singleton_value(axis)
44+
axis_value = _ir_utils.get_singleton_value(axis)
4545
if not isinstance(axis_value, int):
4646
return result.fail("Axis value must be a constant integer.", axis)
4747
shape: ir.Shape | None = data.shape
@@ -54,7 +54,7 @@ def check(self, context, data, axis, transposed_data, **_):
5454
"Transposed data shape is not statically known.", transposed_data
5555
)
5656
actual_dim_value = transposed_data_shape[0]
57-
if updated_dim_value != actual_dim_value:
57+
if not _ir_utils.same_dim(updated_dim_value, actual_dim_value):
5858
# The first dimension of the transposed data does not match the updated dimension,
5959
# so we cannot apply this rule.
6060
return result.fail(
@@ -87,7 +87,7 @@ def check(self, context, data, indices, updates, **_):
8787
return result.fail("The value 'data' shape is not statically known.", data)
8888
if updates.shape is None:
8989
return result.fail("The value 'updates' shape is not statically known.", updates)
90-
if data.shape != updates.shape:
90+
if not _ir_utils.same_shape(data.shape, updates.shape):
9191
return result.fail(
9292
"The shape of 'data' and 'updates' are different.", [data, updates]
9393
)

0 commit comments

Comments
 (0)