@@ -353,6 +353,33 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
353353 return default
354354
355355
356+ @register ("Add" )
357+ def add (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
358+ """Propagate symbolic dim values."""
359+
360+ def get_dim_value (input_index ):
361+ input = _get_input (node , input_index )
362+ if input is None :
363+ return None
364+ shape_value : ir .Shape | None = state .get_shape_value (input )
365+ if shape_value is None or len (shape_value ) != 1 :
366+ return None
367+ dim : int | ir .SymbolicDim = shape_value [0 ]
368+ return dim if isinstance (dim , int ) else dim .value
369+
370+ dim0 = get_dim_value (0 )
371+ dim1 = get_dim_value (1 )
372+ if dim0 is None or dim1 is None :
373+ return None
374+ if isinstance (dim0 , int ) and isinstance (dim1 , int ):
375+ result_dim_value : int | ir .SymbolicDim = dim0 + dim1
376+ else :
377+ result_dim_value = ir .SymbolicDim (f"{ dim0 } +{ dim1 } " )
378+ output = _get_output (node , 0 )
379+ if output is not None :
380+ state .set_sym_value (output , ir .Shape ([result_dim_value ]))
381+
382+
356383@register ("Abs" )
357384def abs (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
358385 """Replace an Abs node by Identity when applicable.
@@ -401,9 +428,26 @@ def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
401428 return None
402429
403430
431+ def _propagate_shape_value (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
432+ """Propagates symbolic shape value of input 0 to output 0.
433+
434+ Applies to ops like Reshape/Squeeze/Unsqueeze where the shape of the tensor may change
435+ but the values in the tensor remain the same.
436+ """
437+ input = _get_input (node , 0 )
438+ input_shape_value = state .get_shape_value (input )
439+ output = _get_output (node , 0 )
440+ if output is not None and input_shape_value is not None :
441+ state .set_sym_value (output , input_shape_value )
442+ return None
443+
444+
404445@register ("Reshape" )
405446def reshape (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
406- """Replace a Reshape node by Identity when applicable."""
447+ """Replace a Reshape node by Identity when applicable.
448+
449+ Also propagate symbolic shape values.
450+ """
407451 input = _get_input (node , 0 )
408452 shape = _get_input (node , 1 )
409453 if input is None or shape is None :
@@ -413,12 +457,18 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
413457 shape_value = state .get_shape_value (shape )
414458
415459 if shape_value is None or input_shape is None :
416- return None
460+ return _propagate_shape_value ( node , op , state )
417461
418462 # No need to check for special values like -1, 0, etc. here
419463 if _same_shape (input_shape , shape_value ):
420464 return op .Identity (input )
421- return None
465+ return _propagate_shape_value (node , op , state )
466+
467+
468+ @register ("Squeeze" )
469+ def squeeze (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
470+ """Propagate symbolic shape values."""
471+ return _propagate_shape_value (node , op , state )
422472
423473
424474@register ("Cast" )
0 commit comments