Skip to content

Commit 3a26097

Browse files
authored
Merge output shape with input shape instead of override (#2578)
`_constant_folding.cast` override `output.shape` with `input.shape`, that may make a static shape to dynamic shape. Here should use `_merge_shapes` instead.
1 parent 94fb24f commit 3a26097

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,7 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
501501
# should handle this. Only the optimization to eliminate redundant Cast ops
502502
# should be needed here.
503503

504-
input_shape = input.shape
505-
if input_shape is not None:
506-
output.shape = input_shape.copy()
504+
output.shape = _merge_shapes(output.shape, input.shape)
507505

508506
input_dtype = _get_input_element_type(node, 0)
509507
output_dtype = _get_int_attribute(node, "to", None)

0 commit comments

Comments
 (0)