Skip to content

Commit 0766199

Browse files
Improve symbolic dim tracking (#2520)
A few improvements to symbolic dim tracking (for better fusion in Gemma3). * Track symbolic dimension additions * Propagate symbolic dims through Reshapes/Squeeze which show up when converting them to and from 0d or 1d tensors. * Enables elimination of superfluous "Abs" applies to symbolic shapes (before an "Expand") --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent bf1c139 commit 0766199

File tree

1 file changed

+53
-3
lines changed

1 file changed

+53
-3
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,33 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
353353
return default
354354

355355

356+
@register("Add")
def add(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Propagate symbolic dim values through an Add node.

    When both inputs have a tracked shape value consisting of exactly one
    dimension (a concrete int or a named symbolic dim), record the sum of
    the two dims as the symbolic value of the output. The node itself is
    never rewritten; only the optimizer state is updated.
    """

    def _single_dim(index: int):
        # Return the lone dim of the input's tracked shape value as an int
        # or a symbolic-dim name; None when untracked, not rank-1, or the
        # symbolic dim is anonymous (dim.value is None in that case).
        operand = _get_input(node, index)
        if operand is None:
            return None
        tracked: ir.Shape | None = state.get_shape_value(operand)
        if tracked is None or len(tracked) != 1:
            return None
        dim: int | ir.SymbolicDim = tracked[0]
        if isinstance(dim, int):
            return dim
        return dim.value

    lhs = _single_dim(0)
    rhs = _single_dim(1)
    if lhs is None or rhs is None:
        return None
    if isinstance(lhs, int) and isinstance(rhs, int):
        summed: int | ir.SymbolicDim = lhs + rhs
    else:
        # At least one side is symbolic: synthesize a named symbolic dim
        # describing the addition (e.g. "seq+1").
        summed = ir.SymbolicDim(f"{lhs}+{rhs}")
    output = _get_output(node, 0)
    if output is not None:
        state.set_sym_value(output, ir.Shape([summed]))
    return None
381+
382+
356383
@register("Abs")
357384
def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
358385
"""Replace an Abs node by Identity when applicable.
@@ -401,9 +428,26 @@ def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
401428
return None
402429

403430

431+
def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Propagate the symbolic shape value of input 0 to output 0.

    Intended for ops like Reshape/Squeeze/Unsqueeze, where the shape of the
    tensor may change but the element values are unchanged, so any symbolic
    shape value tracked for the input remains valid for the output.

    Always returns None: the node is left in place and only the optimizer
    state is updated.
    """
    source = _get_input(node, 0)
    propagated = state.get_shape_value(source)
    target = _get_output(node, 0)
    if propagated is not None and target is not None:
        state.set_sym_value(target, propagated)
    return None
443+
444+
404445
@register("Reshape")
405446
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
406-
"""Replace a Reshape node by Identity when applicable."""
447+
"""Replace a Reshape node by Identity when applicable.
448+
449+
Also propagate symbolic shape values.
450+
"""
407451
input = _get_input(node, 0)
408452
shape = _get_input(node, 1)
409453
if input is None or shape is None:
@@ -413,12 +457,18 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
413457
shape_value = state.get_shape_value(shape)
414458

415459
if shape_value is None or input_shape is None:
416-
return None
460+
return _propagate_shape_value(node, op, state)
417461

418462
# No need to check for special values like -1, 0, etc. here
419463
if _same_shape(input_shape, shape_value):
420464
return op.Identity(input)
421-
return None
465+
return _propagate_shape_value(node, op, state)
466+
467+
468+
@register("Squeeze")
def squeeze(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Forward the symbolic shape value of input 0 to output 0.

    Squeeze does not alter the tensor's element values, so the shared
    propagation helper applies directly.
    """
    # Delegate to the common helper used by shape-preserving-value ops.
    return _propagate_shape_value(node, op, state)
422472

423473

424474
@register("Cast")

0 commit comments

Comments
 (0)