Skip to content

Commit e76bfe0

Browse files
authored
[Reland] Update SplitToSequence in constant folding (#2544)
Split input (SymbolicTensor) could have no const_value, but its static shape can still tell us how many outputs an op.Split should return.
1 parent 1934901 commit e76bfe0

File tree

2 files changed

+83
-11
lines changed

2 files changed

+83
-11
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -801,27 +801,45 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
801801
axis = axis + rank
802802
if axis < 0 or axis >= rank:
803803
return None
804-
split_dimension_size = shape[axis]
805-
if not isinstance(split_dimension_size, int):
806-
return None
807804

805+
# NOTE: Split needs to either be a scalar or a 1-D tensor. We need to
806+
# calculate the number of outputs for Split.
807+
# If split is a scalar, we split into chunks of size 'split' if possible.
808+
# * the split dimension size and split_value has to be known.
809+
# If split is a 1-D tensor, we split into 'size(split)' chunks
810+
# * Get the size from split_value if it's numpy array.
811+
# * Get the size from symbolic shape if split_value is not available.
808812
split_value = _get_numpy_value(split)
809-
if split_value is None:
813+
split_shape = (
814+
split.shape.numpy() if split.shape is not None and split.shape.is_static() else None
815+
)
816+
817+
# No information about split value or shape.
818+
if split_value is None and split_shape is None:
810819
return None
811-
assert isinstance(split_value, np.ndarray)
812820

813-
if split_value.ndim == 0:
814-
# split into chunks all of size 'split' if possible.
815-
num_outputs = math.ceil(split_dimension_size / split_value.item())
821+
if isinstance(split_shape, tuple) and len(split_shape) == 1:
822+
# If split_shape is known, we can use it to determine the number of outputs.
823+
split_dimension_size = split_shape[0]
824+
assert isinstance(split_dimension_size, int)
825+
num_outputs = split_dimension_size
816826
split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
817-
split_values = op.Split(
818-
input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
819-
)
827+
split_values = op.Split(input, split, axis=axis, _outputs=split_outputs)
820828
elif split_value.ndim == 1:
821829
# split into 'size(split)' chunks
822830
num_outputs = split_value.size
823831
split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
824832
split_values = op.Split(input, split, axis=axis, _outputs=split_outputs)
833+
elif split_value.ndim == 0:
834+
# split into chunks all of size 'split' if possible.
835+
split_dimension_size = shape[axis]
836+
if not isinstance(split_dimension_size, int):
837+
return None
838+
num_outputs = math.ceil(split_dimension_size / split_value.item())
839+
split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
840+
split_values = op.Split(
841+
input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
842+
)
825843
else:
826844
return None
827845

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,60 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
346346
self.assertEqual(len(optimized.graph), 7)
347347
self.assertEqual(optimized.graph[6].op_type, "Concat")
348348

349+
def test_dynamic_split_to_sequence_list_shape_rewrite(self):
350+
# split is a graph input with known 1-D static shape [4]; values unknown (not constant)
351+
# Ensures the branch: if isinstance(split_shape, tuple) and len(split_shape) == 1
352+
model = """
353+
<
354+
ir_version: 8,
355+
opset_import: ["" : 18]
356+
>
357+
func (float[2,N] x, int64[4] split) => (float[2,N] return_val) {
358+
splits = SplitToSequence <axis: int = 1> (x, split)
359+
i0 = Constant <value: tensor = int64 i0 {0}> ()
360+
s0 = SequenceAt (splits, i0)
361+
i1 = Constant <value: tensor = int64 i1 {1}> ()
362+
s1 = SequenceAt (splits, i1)
363+
i2 = Constant <value: tensor = int64 i2 {2}> ()
364+
s2 = SequenceAt (splits, i2)
365+
i3 = Constant <value: tensor = int64 i3 {3}> ()
366+
s3 = SequenceAt (splits, i3)
367+
return_val = Concat <axis: int = 1> (s0, s1, s2, s3)
368+
}"""
369+
optimized = self._fold(model)
370+
# Expect: Split + Concat (index constants & SequenceAt removed)
371+
split_nodes = [n for n in optimized.graph if n.op_type == "Split"]
372+
self.assertEqual(len(split_nodes), 1)
373+
self.assertEqual(len(split_nodes[0].outputs), 4)
374+
self.assertEqual(split_nodes[0].op_type, "Split")
375+
self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph))
376+
377+
def test_dynamic_split_to_sequence_list_shape_no_keepdims(self):
378+
# keepdims=0 path with dynamic (non-constant) splits input; triggers squeeze logic.
379+
model = """
380+
<
381+
ir_version: 8,
382+
opset_import: ["" : 18]
383+
>
384+
func (float[1,M] x, int64[3] split) => (float[1,M] return_val) {
385+
splits = SplitToSequence <axis: int = 1, keepdims: int = 0> (x, split)
386+
i0 = Constant <value: tensor = int64 i0 {0}> ()
387+
s0 = SequenceAt (splits, i0)
388+
i1 = Constant <value: tensor = int64 i1 {1}> ()
389+
s1 = SequenceAt (splits, i1)
390+
i2 = Constant <value: tensor = int64 i2 {2}> ()
391+
s2 = SequenceAt (splits, i2)
392+
return_val = Concat <axis: int = 1> (s0, s1, s2)
393+
}"""
394+
optimized = self._fold(model)
395+
split_nodes = [n for n in optimized.graph if n.op_type == "Split"]
396+
self.assertEqual(len(split_nodes), 1)
397+
self.assertEqual(len(split_nodes[0].outputs), 3)
398+
self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph))
399+
# Each split output should have a corresponding Squeeze (keepdims=0 branch)
400+
squeeze_nodes = [n for n in optimized.graph if n.op_type == "Squeeze"]
401+
self.assertEqual(len(squeeze_nodes), 3)
402+
349403
def test_initializer_input_not_folded(self):
350404
model_text = """
351405
<ir_version: 7, opset_import: [ "" : 18]>

0 commit comments

Comments
 (0)