Skip to content

Commit 9b699ae

Browse files
authored
Improve constant folding error messages and allow Identity to skip shape merging (#2670)
When Identity fails to merge shapes, allow the constant folder to proceed by ignoring the conflicting shape. Also improve error message to show node information if constant folding fails. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 8a7de40 commit 9b699ae

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,16 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
602602
output = node.outputs[0]
603603
if input is not None and output is not None:
604604
# NOTE: backward shape inference
605-
input.shape = _merge_shapes(input.shape, output.shape)
605+
try:
606+
input.shape = _merge_shapes(input.shape, output.shape)
607+
except Exception as e:
608+
logger.warning(
609+
"[Constant folder] Cannot merge shapes on Identity node '%s' "
610+
"(folded from: %s) because of error: %s",
611+
node.name,
612+
input.meta.get(FOLDED_FROM_KEY, set()),
613+
e,
614+
)
606615
if input.type is None:
607616
input.type = output.type
608617
state.set_sym_value(output, input)
@@ -919,7 +928,9 @@ def merge_dims(dim1, dim2):
919928
if other_shape is None:
920929
return preferred_shape
921930
if len(preferred_shape) != len(other_shape):
922-
raise ValueError("Shapes must have the same rank.")
931+
raise ValueError(
932+
f"Shapes must have the same rank, got preferred_shape={preferred_shape}, other_shape={other_shape}"
933+
)
923934
return ir.Shape(
924935
[merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)]
925936
)
@@ -1035,7 +1046,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
10351046
except Exception as e:
10361047
logger.debug(
10371048
"Skipping shape inference for node %r due to exception: %s",
1038-
node.name,
1049+
node,
10391050
e,
10401051
)
10411052

@@ -1124,7 +1135,12 @@ def process_node(self, node: ir.Node) -> Replacement | None:
11241135
for optimizer in op_optimizers:
11251136
assert optimizer
11261137
context = RewriterContext()
1127-
output = optimizer(node, context, self._state)
1138+
try:
1139+
output = optimizer(node, context, self._state)
1140+
except Exception as e:
1141+
raise RuntimeError(
1142+
f"Error during constant folding for node {node.name!r} ({node.domain}::{node.op_type})"
1143+
) from e
11281144
if output is not None:
11291145
if isinstance(output, Replacement):
11301146
return output

0 commit comments

Comments
 (0)