1111
1212from typing import ClassVar , Sequence
1313
14+ import numpy as np
15+
1416from onnxscript import ir
1517from onnxscript .rewriter import _ir_utils as ir_utils
1618from onnxscript .rewriter ._basics import MatchResult
@@ -123,16 +125,37 @@ def pattern(self, op, x, shape_ignored, shape):
123125 return op .Reshape (op .Reshape (x , shape_ignored ), shape )
124126
125127 def rewrite (self , op , x : ir .Value , shape_ignored : ir .Value , shape : ir .Value ):
126- return op .Reshape (x , shape )
128+ new_shape = op .initializer (ir .Tensor (self ._new_shape , name = shape .name ))
129+ return op .Reshape (x , new_shape , allowzero = self ._allowzero )
127130
128131 def check (self , context , x , shape_ignored , shape ) -> MatchResult :
129132 check_result = MatchResult ()
130- if shape_ignored . const_value is None :
131- return check_result . fail ( " Shape ignored is not a constant." )
132- if shape . const_value is None :
133+
134+ # Shape must be a constant.
135+ if ( np_shape := ir_utils . get_numpy_value ( shape )) is None :
133136 return check_result .fail ("Shape is not a constant." )
134- if shape .const_value .numpy ().min () <= 0 :
135- return check_result .fail ("Shape has non-positive values." )
137+ # Convert to array to support assignment destination.
138+ self ._new_shape = np .array (np_shape , np_shape .dtype )
139+
140+ # Try to replace {0,-1} values in shape if reshape output is known.
141+ if (reshape_output := context .output_values [0 ].shape ) is not None :
142+ for i , dim in enumerate (reshape_output ):
143+ if isinstance (dim , int ) and dim > 0 :
144+ self ._new_shape [i ] = dim
145+
146+ # Constraints for shape.
147+ self ._allowzero = context .nodes [0 ].attributes .get_int ("allowzero" , 0 )
148+ if self ._allowzero == 1 and any (self ._new_shape == 0 ):
149+ return check_result
150+ if any (self ._new_shape == 0 ) and any (self ._new_shape < 0 ):
151+ return check_result .fail ("Shape cannot contain both 0 and -1 dimensions." )
152+ elif np .count_nonzero (self ._new_shape == 0 ) > 1 :
153+ return check_result .fail ("Shape cannot contain more than one 0 dimension." )
154+
155+ # At this point, we can safely replace '0' with '-1'.
156+ # Note allowzero is removed since at this point it does not have any effect.
157+ self ._allowzero = None
158+ self ._new_shape = np .where (self ._new_shape == 0 , - 1 , self ._new_shape )
136159 return check_result
137160
138161
@@ -279,6 +302,55 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
279302 return check_result
280303
281304
305+ class Flatten2Reshape (RewriteRuleClassBase ):
306+ """Convert ``Flatten(x)`` to Reshape."""
307+
308+ def pattern (self , op , x : ir .Value ):
309+ return op .Flatten (x )
310+
311+ def rewrite (self , op , x : ir .Value ):
312+ new_shape = op .initializer (ir .Tensor (self ._new_shape , name = f"{ x .name } /shape" ))
313+ return op .Reshape (x , new_shape )
314+
315+ def check (self , context , x : ir .Value ) -> MatchResult :
316+ check_result = MatchResult ()
317+ self ._new_shape = np .array ([- 1 , - 1 ], "int64" )
318+
319+ # Convert axis in a positive value if possible.
320+ axis = context .root .attributes .get_int ("axis" , 1 )
321+ input_rank = None
322+ if (input_shape := x .shape ) is not None :
323+ input_rank = len (input_shape )
324+ if axis < 0 :
325+ axis += input_rank
326+
327+ # Compute reshape shape following axis attribute.
328+ if axis == 0 :
329+ self ._new_shape [0 ] = 1
330+ elif axis == 1 :
331+ self ._new_shape [0 ] = 0
332+ elif axis == input_rank :
333+ self ._new_shape [1 ] = 1
334+
335+ # Try to update shape if output is known.
336+ if (output_shape := context .output_values [0 ].shape ) is not None :
337+ for i , dim in enumerate (output_shape ):
338+ if isinstance (dim , int ):
339+ self ._new_shape [i ] = dim
340+
341+ # Try to update shape if input is known.
342+ if input_shape is not None :
343+ if all (isinstance (dim , int ) for dim in input_shape [:axis ]):
344+ self ._new_shape [0 ] = np .prod (input_shape [:axis ])
345+ if all (isinstance (dim , int ) for dim in input_shape [axis :]):
346+ self ._new_shape [1 ] = np .prod (input_shape [axis :])
347+
348+ # Verify if it is possible to apply rule.
349+ if np .count_nonzero (self ._new_shape == - 1 ) > 1 :
350+ return check_result .fail ("Impossible to compute new shape." )
351+ return check_result
352+
353+
282354# Create rule instances
283355cast_cast_rule = CastCast .rule ()
284356no_op_cast_rule = CastIdentity .rule ()
@@ -289,6 +361,7 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
289361transpose_transpose_rule = TransposeTranspose .rule ()
290362unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze .rule ()
291363squeeze_reshape_1d_rule = SqueezeReshape .rule ()
364+ flatten_to_reshape_rule = Flatten2Reshape .rule ()
292365
293366
294367def basic_optimization_rules () -> RewriteRuleSet :
@@ -311,6 +384,8 @@ def basic_optimization_rules() -> RewriteRuleSet:
311384 cast_cast_rule ,
312385 no_op_cast_rule ,
313386 no_op_expand_rule ,
387+ # flatten_to_reshape_rule is order sensitive to reshape_reshape_rule
388+ flatten_to_reshape_rule ,
314389 reshape_reshape_rule ,
315390 slice_split_rule ,
316391 no_op_transpose_rule ,
0 commit comments