From d7344fa5c1de56331f55a2d2fc14977cea2e727f Mon Sep 17 00:00:00 2001
From: Eby Elanjikal
Date: Thu, 16 Oct 2025 19:23:18 +0530
Subject: [PATCH 1/3] WIP: Add rewrite to fuse nested BlockDiag Ops

From 786b7939851fd06a9cd3c0727a4fce7413a7e27e Mon Sep 17 00:00:00 2001
From: Eby Elanjikal
Date: Thu, 16 Oct 2025 22:27:10 +0530
Subject: [PATCH 2/3] Add fuse_blockdiagonal rewrite and corresponding test
 for nested BlockDiagonal

---
 pytensor/tensor/rewriting/linalg.py   | 25 ++++++++++++++++
 tests/tensor/rewriting/test_linalg.py | 43 +++++++++++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 17a3ce9165..9b51f0593d 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -60,11 +60,36 @@
     solve_triangular,
 )
 
+from pytensor.tensor.slinalg import BlockDiagonal
 
 logger = logging.getLogger(__name__)
 
 MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
 
+from pytensor.tensor.slinalg import BlockDiagonal
+from pytensor.graph import Apply
+
+def fuse_blockdiagonal(node):
+    # Only process if this node is a BlockDiagonal
+    if not isinstance(node.owner.op, BlockDiagonal):
+        return node
+
+    new_inputs = []
+    changed = False
+    for inp in node.owner.inputs:
+        # If input is itself a BlockDiagonal, flatten its inputs
+        if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
+            new_inputs.extend(inp.owner.inputs)
+            changed = True
+        else:
+            new_inputs.append(inp)
+
+    if changed:
+        # Return a new fused BlockDiagonal with all inputs
+        return BlockDiagonal(len(new_inputs))(*new_inputs)
+    return node
+
+
 def is_matrix_transpose(x: TensorVariable) -> bool:
     """Check if a variable corresponds to a transpose of the last two axes"""
     node = x.owner
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 515120e446..d426f1a039 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -43,7 +43,50 @@
 from tests import unittest_tools as utt
 from tests.test_rop import break_op
 
+from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal
+
+def test_nested_blockdiag_fusion():
+    # Create matrix variables
+    x = pt.matrix("x")
+    y = pt.matrix("y")
+    z = pt.matrix("z")
+
+    # Nested BlockDiagonal
+    inner = BlockDiagonal(2)(x, y)
+    outer = BlockDiagonal(2)(inner, z)
+
+    # Count number of BlockDiagonal ops before fusion
+    nodes_before = ancestors([outer])
+    initial_count = sum(
+        1 for node in nodes_before
+        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
+    )
+    assert initial_count > 1, "Setup failed: should have nested BlockDiagonal"
+
+    # Apply the rewrite
+    fused = fuse_blockdiagonal(outer)
+
+    # Count number of BlockDiagonal ops after fusion
+    nodes_after = ancestors([fused])
+    fused_count = sum(
+        1 for node in nodes_after
+        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
+    )
+    assert fused_count == 1, "Nested BlockDiagonal ops were not fused"
+
+    # Check that all original inputs are preserved
+    fused_inputs = [
+        inp
+        for node in ancestors([fused])
+        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
+        for inp in node.owner.inputs
+    ]
+    assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused"
+
+
+
+
 def test_matrix_inverse_rop_lop():
     rtol = 1e-7 if config.floatX == "float64" else 1e-5
     mx = matrix("mx")
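The rewrite in PATCH 2/3 rests on the associativity of block-diagonal construction: block_diag(block_diag(A, B), C) is the same matrix as block_diag(A, B, C), so nested BlockDiagonal ops can always be flattened into a single op over all the leaf inputs. A minimal NumPy/SciPy sketch of that identity (independent of PyTensor; the variable names are illustrative):

    import numpy as np
    from scipy.linalg import block_diag

    rng = np.random.default_rng(0)
    a, b, c = (rng.normal(size=(2, 2)) for _ in range(3))

    # Nested and flat constructions yield the same block-diagonal matrix;
    # this is the identity that licenses fusing nested BlockDiagonal ops.
    nested = block_diag(block_diag(a, b), c)
    flat = block_diag(a, b, c)
    np.testing.assert_allclose(nested, flat)

PATCH 3/3 below reworks the WIP helper into a proper graph rewrite: it takes (fgraph, node) instead of a variable, returns a replacement list or None per the node_rewriter contract, and is registered so it runs during canonicalization.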
From bec4bd368c301d8a5ded5c669b5ee7ad62c1d663 Mon Sep 17 00:00:00 2001
From: Eby Elanjikal
Date: Tue, 4 Nov 2025 22:52:26 +0530
Subject: [PATCH 3/3] linalg: fuse nested BlockDiagonal ops and add
 corresponding tests

---
 pytensor/tensor/rewriting/linalg.py   | 25 ++++-----
 tests/tensor/rewriting/test_linalg.py | 78 +++++++++++++++++----------
 2 files changed, 63 insertions(+), 40 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 9b51f0593d..3960a396cf 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -60,24 +60,23 @@
     solve_triangular,
 )
 
-from pytensor.tensor.slinalg import BlockDiagonal
 
 logger = logging.getLogger(__name__)
 
 MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
 
-from pytensor.tensor.slinalg import BlockDiagonal
-from pytensor.graph import Apply
+@register_canonicalize
+@node_rewriter([BlockDiagonal])
+def fuse_blockdiagonal(fgraph, node):
+    """Fuse nested BlockDiagonal ops into a single BlockDiagonal."""
 
-def fuse_blockdiagonal(node):
-    # Only process if this node is a BlockDiagonal
-    if not isinstance(node.owner.op, BlockDiagonal):
-        return node
+    if not isinstance(node.op, BlockDiagonal):
+        return None
 
     new_inputs = []
     changed = False
-    for inp in node.owner.inputs:
-        # If input is itself a BlockDiagonal, flatten its inputs
+
+    for inp in node.inputs:
         if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
             new_inputs.extend(inp.owner.inputs)
             changed = True
@@ -85,9 +84,11 @@ def fuse_blockdiagonal(node):
             new_inputs.append(inp)
 
     if changed:
-        # Return a new fused BlockDiagonal with all inputs
-        return BlockDiagonal(len(new_inputs))(*new_inputs)
-    return node
+        fused_op = BlockDiagonal(len(new_inputs))
+        new_output = fused_op(*new_inputs)
+        return [new_output]
+
+    return None
 
 
 def is_matrix_transpose(x: TensorVariable) -> bool:
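Because register_canonicalize adds fuse_blockdiagonal to the canonicalization database, the rewrite fires automatically whenever a function is compiled in the default mode; nothing calls it by hand. A sketch of how that can be observed outside the test suite, assuming the rewrite is registered under its function name (which register_canonicalize does by default):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.mode import get_default_mode
    from pytensor.tensor.slinalg import BlockDiagonal

    x = pt.tensor("x", shape=(2, 2))
    y = pt.tensor("y", shape=(2, 2))
    z = pt.tensor("z", shape=(2, 2))
    out = BlockDiagonal(2)(BlockDiagonal(2)(x, y), z)

    def n_blockdiag(f):
        # Count BlockDiagonal nodes left in the compiled graph
        return sum(
            isinstance(node.op, BlockDiagonal) for node in f.maker.fgraph.apply_nodes
        )

    f_fused = pytensor.function([x, y, z], out)
    f_raw = pytensor.function(
        [x, y, z], out, mode=get_default_mode().excluding("fuse_blockdiagonal")
    )
    assert n_blockdiag(f_fused) == 1  # nested ops collapsed into one
    assert n_blockdiag(f_raw) == 2    # rewrite disabled, both ops remain

The corresponding test changes below move from calling the helper directly on a variable to compiling with pytensor.function and inspecting f.maker.fgraph, which is what actually exercises the registered rewrite.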
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index d426f1a039..cd098bed25 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -43,50 +43,72 @@
 from tests import unittest_tools as utt
 from tests.test_rop import break_op
 
-from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal
-
 def test_nested_blockdiag_fusion():
-    # Create matrix variables
-    x = pt.matrix("x")
-    y = pt.matrix("y")
-    z = pt.matrix("z")
+    x = pt.tensor("x", shape=(3, 3))
+    y = pt.tensor("y", shape=(3, 3))
+    z = pt.tensor("z", shape=(3, 3))
 
-    # Nested BlockDiagonal
-    inner = BlockDiagonal(2)(x, y)
+    inner = BlockDiagonal(2)(x, y)
     outer = BlockDiagonal(2)(inner, z)
 
-    # Count number of BlockDiagonal ops before fusion
     nodes_before = ancestors([outer])
     initial_count = sum(
-        1 for node in nodes_before
+        1
+        for node in nodes_before
         if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
     )
-    assert initial_count > 1, "Setup failed: should have nested BlockDiagonal"
+    assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"
 
-    # Apply the rewrite
-    fused = fuse_blockdiagonal(outer)
+    f = pytensor.function([x, y, z], outer)
+    fgraph = f.maker.fgraph
 
-    # Count number of BlockDiagonal ops after fusion
-    nodes_after = ancestors([fused])
-    fused_count = sum(
-        1 for node in nodes_after
-        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
-    )
-    assert fused_count == 1, "Nested BlockDiagonal ops were not fused"
+    nodes_after = fgraph.apply_nodes
+    fused_nodes = [node for node in nodes_after if isinstance(node.op, BlockDiagonal)]
+    assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"
 
-    # Check that all original inputs are preserved
-    fused_inputs = [
-        inp
-        for node in ancestors([fused])
-        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
-        for inp in node.owner.inputs
+    fused_op = fused_nodes[0].op
+
+    assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"
+
+    out_shape = fgraph.outputs[0].type.shape
+    assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"
+
+
+def test_deeply_nested_blockdiag_fusion():
+    x = pt.tensor("x", shape=(3, 3))
+    y = pt.tensor("y", shape=(3, 3))
+    z = pt.tensor("z", shape=(3, 3))
+    w = pt.tensor("w", shape=(3, 3))
+
+    inner1 = BlockDiagonal(2)(x, y)
+    inner2 = BlockDiagonal(2)(inner1, z)
+    outer = BlockDiagonal(2)(inner2, w)
+
+    f = pytensor.function([x, y, z, w], outer)
+    fgraph = f.maker.fgraph
+
+    fused_nodes = [
+        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
     ]
-    assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused"
+    assert len(fused_nodes) == 1, (
+        f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
+    )
+
+    fused_op = fused_nodes[0].op
+
+    assert fused_op.n_inputs == 4, (
+        f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
+    )
 
+    out_shape = fgraph.outputs[0].type.shape
+    expected_shape = (12, 12)  # 4 blocks of (3x3)
+    assert out_shape == expected_shape, (
+        f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
+    )
-
-
 
 
 def test_matrix_inverse_rop_lop():
     rtol = 1e-7 if config.floatX == "float64" else 1e-5
     mx = matrix("mx")
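One property worth checking beyond the structural assertions above: fusion must not change the computed values. A minimal end-to-end sketch against SciPy's reference block_diag (the names and the use of scipy.linalg.block_diag are illustrative, not part of the patch):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from scipy.linalg import block_diag
    from pytensor.tensor.slinalg import BlockDiagonal

    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))
    outer = BlockDiagonal(2)(BlockDiagonal(2)(x, y), z)

    # Compilation applies the canonicalization, so this runs the fused graph.
    f = pytensor.function([x, y, z], outer)

    rng = np.random.default_rng(42)
    xv, yv, zv = (
        rng.normal(size=(3, 3)).astype(pytensor.config.floatX) for _ in range(3)
    )

    # The fused graph must still match the reference block-diagonal result.
    np.testing.assert_allclose(f(xv, yv, zv), block_diag(xv, yv, zv))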