diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 55fb8759d4..375bca02f5 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -547,6 +547,46 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     return None
 
 
+@register("Split")
+def split(node: ir.Node, op, _):
+    """Replaces a Split node whose inputs are all constant by a list of Constant nodes."""
+    # Replace a single-output split by Identity(x)
+    if len(node.outputs) == 1:
+        return op.Identity(node.inputs[0])
+
+    # Skip non-constant inputs
+    if (x := ir.convenience.get_const_tensor(node.inputs[0])) is None:
+        return None
+
+    split_ = None
+
+    # Option A: sizes per split
+    if len(node.inputs) == 2:
+        # Skip non-constant splits
+        if (split_ := ir.convenience.get_const_tensor(node.inputs[1])) is None:
+            return None
+        # Numpy expects the splits as the starting index of each section
+        split_ = np.cumsum(split_.numpy()[:-1])
+
+    # Option B: number of (even) splits
+    elif (num_outputs := node.attributes.get("num_outputs")) is not None:
+        # Numpy also accepts a single integer number of (even) splits
+        split_ = num_outputs.as_int()
+
+    # Unable to determine the split configuration, skip optimization
+    if split_ is None:
+        return None
+
+    # The default split axis is 0, according to the ONNX operators reference:
+    # https://onnx.ai/onnx/operators/onnx__Split.html
+    if (axis := node.attributes.get("axis")) is None:
+        axis = ir.Attr("axis", ir.AttributeType.INT, 0)
+
+    # Split the constant tensor and wrap each piece in a Constant operator
+    splits = np.array_split(x.numpy(), split_, axis.as_int())
+    return [op.Constant(value=ir.tensor(x)) for x in splits]
+
+
 @register("Concat")
 def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     """Replace a Concat node with a single input by Identity"""
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index e58ee0ba19..93f2bfb0b0 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -403,6 +403,98 @@ def test_dropout_identity_mask(self, dropout_node: str):
         ops = [node.op_type for node in nodes]
         self.assertEqual(ops, ["Identity", "Shape", "ConstantOfShape"])
 
+    def test_split_identity_num_outputs(self):
+        model = """
+
+            agraph (float[N] x) => (float[N] z)
+            {
+                z = Split (x)
+            }
+        """
+
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertEqual(len(optimized.graph[-1].outputs), 1)
+        self.assertEqual(optimized.graph[-1].op_type, "Identity")
+
+    def test_split_identity_splits(self):
+        model = """
+
+            agraph (float[N] x, float[1] split) => (float[N] z)
+            {
+                z = Split (x, split)
+            }
+        """
+
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertEqual(len(optimized.graph[-1].outputs), 1)
+        self.assertEqual(optimized.graph[-1].op_type, "Identity")
+
+
+    def test_split_constant_num_outputs_even(self):
+        model = """
+
+            agraph () => (float[N] z1, float[N] z2)
+            {
+                x = Constant ()
+                z1, z2 = Split (x)
+            }
+        """
+
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph), 2)
+        self.assertEqual(len(optimized.graph[-2].outputs), 1)
+        self.assertEqual(len(optimized.graph[-1].outputs), 1)
+        self.assertEqual(optimized.graph[-2].outputs[0].shape, [3])
+        self.assertEqual(optimized.graph[-1].outputs[0].shape, [3])
+        self.assertEqual(optimized.graph[-2].op_type, "Constant")
+        self.assertEqual(optimized.graph[-1].op_type, "Constant")
+
+    def test_split_constant_num_outputs_odd(self):
+        model = """
+
+            agraph () => (float[N] z1, float[M] z2)
+            {
+                x = Constant ()
+                z1, z2 = Split (x)
+            }
+        """
+
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph), 2)
+        self.assertEqual(len(optimized.graph[-2].outputs), 1)
+        self.assertEqual(len(optimized.graph[-1].outputs), 1)
+        self.assertEqual(optimized.graph[-2].outputs[0].shape, [4])
+        self.assertEqual(optimized.graph[-1].outputs[0].shape, [3])
+        self.assertEqual(optimized.graph[-2].op_type, "Constant")
+        self.assertEqual(optimized.graph[-1].op_type, "Constant")
+
+    def test_split_constant_splits(self):
+        model = """
+
+            agraph () => (float[N] z1, float[M] z2, float[L] z3, float[K] z4)
+            {
+                x = Constant ()
+                split = Constant ()
+                z1, z2, z3, z4 = Split (x, split)
+            }
+        """
+
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph), 4)
+        self.assertEqual(len(optimized.graph[-3].outputs), 1)
+        self.assertEqual(len(optimized.graph[-2].outputs), 1)
+        self.assertEqual(len(optimized.graph[-1].outputs), 1)
+        self.assertEqual(optimized.graph[-4].outputs[0].shape, [2])
+        self.assertEqual(optimized.graph[-3].outputs[0].shape, [3])
+        self.assertEqual(optimized.graph[-2].outputs[0].shape, [1])
+        self.assertEqual(optimized.graph[-1].outputs[0].shape, [1])
+        self.assertEqual(optimized.graph[-4].op_type, "Constant")
+        self.assertEqual(optimized.graph[-3].op_type, "Constant")
+        self.assertEqual(optimized.graph[-2].op_type, "Constant")
+        self.assertEqual(optimized.graph[-1].op_type, "Constant")
+
     def test_concat_identity(self):
         model = """