@@ -346,6 +346,60 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
346346 self .assertEqual (len (optimized .graph ), 7 )
347347 self .assertEqual (optimized .graph [6 ].op_type , "Concat" )
348348
349+ def test_dynamic_split_to_sequence_list_shape_rewrite (self ):
350+ # split is a graph input with known 1-D static shape [4]; values unknown (not constant)
351+ # Ensures the branch: if isinstance(split_shape, tuple) and len(split_shape) == 1
352+ model = """
353+ <
354+ ir_version: 8,
355+ opset_import: ["" : 18]
356+ >
357+ func (float[2,N] x, int64[4] split) => (float[2,N] return_val) {
358+ splits = SplitToSequence <axis: int = 1> (x, split)
359+ i0 = Constant <value: tensor = int64 i0 {0}> ()
360+ s0 = SequenceAt (splits, i0)
361+ i1 = Constant <value: tensor = int64 i1 {1}> ()
362+ s1 = SequenceAt (splits, i1)
363+ i2 = Constant <value: tensor = int64 i2 {2}> ()
364+ s2 = SequenceAt (splits, i2)
365+ i3 = Constant <value: tensor = int64 i3 {3}> ()
366+ s3 = SequenceAt (splits, i3)
367+ return_val = Concat <axis: int = 1> (s0, s1, s2, s3)
368+ }"""
369+ optimized = self ._fold (model )
370+ # Expect: Split + Concat (index constants & SequenceAt removed)
371+ split_nodes = [n for n in optimized .graph if n .op_type == "Split" ]
372+ self .assertEqual (len (split_nodes ), 1 )
373+ self .assertEqual (len (split_nodes [0 ].outputs ), 4 )
374+ self .assertEqual (split_nodes [0 ].op_type , "Split" )
375+ self .assertTrue (all (n .op_type != "SequenceAt" for n in optimized .graph ))
376+
377+ def test_dynamic_split_to_sequence_list_shape_no_keepdims (self ):
378+ # keepdims=0 path with dynamic (non-constant) splits input; triggers squeeze logic.
379+ model = """
380+ <
381+ ir_version: 8,
382+ opset_import: ["" : 18]
383+ >
384+ func (float[1,M] x, int64[3] split) => (float[1,M] return_val) {
385+ splits = SplitToSequence <axis: int = 1, keepdims: int = 0> (x, split)
386+ i0 = Constant <value: tensor = int64 i0 {0}> ()
387+ s0 = SequenceAt (splits, i0)
388+ i1 = Constant <value: tensor = int64 i1 {1}> ()
389+ s1 = SequenceAt (splits, i1)
390+ i2 = Constant <value: tensor = int64 i2 {2}> ()
391+ s2 = SequenceAt (splits, i2)
392+ return_val = Concat <axis: int = 1> (s0, s1, s2)
393+ }"""
394+ optimized = self ._fold (model )
395+ split_nodes = [n for n in optimized .graph if n .op_type == "Split" ]
396+ self .assertEqual (len (split_nodes ), 1 )
397+ self .assertEqual (len (split_nodes [0 ].outputs ), 3 )
398+ self .assertTrue (all (n .op_type != "SequenceAt" for n in optimized .graph ))
399+ # Each split output should have a corresponding Squeeze (keepdims=0 branch)
400+ squeeze_nodes = [n for n in optimized .graph if n .op_type == "Squeeze" ]
401+ self .assertEqual (len (squeeze_nodes ), 3 )
402+
349403 def test_initializer_input_not_folded (self ):
350404 model_text = """
351405 <ir_version: 7, opset_import: [ "" : 18]>
0 commit comments