From 5a5c0bc043e15725d116a7f66a7f325b6beeb8af Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Tue, 4 Nov 2025 00:45:08 -0800 Subject: [PATCH 1/7] Rewrite concatenate([x, x]) as repeat(x, 2) --- pytensor/tensor/rewriting/basic.py | 36 ++++++++++++- tests/tensor/rewriting/test_basic.py | 79 ++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index c9ade02a00..5970f12da4 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -82,7 +82,7 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.extra_ops import broadcast_arrays +from pytensor.tensor.extra_ops import broadcast_arrays, repeat from pytensor.tensor.math import Sum, add, eq, variadic_add from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.type import DenseTensorType, TensorType @@ -909,6 +909,40 @@ def local_join_make_vector(fgraph, node): copy_stack_trace(node.outputs, ret) return [ret] +@register_specialize +@register_canonicalize +@node_rewriter([Join]) +def local_join_to_repeat(fgraph, node): + """Join(axis, x, x, x, ...) -> repeat(x, n, axis) + + When the same tensor is concatenated multiple times, + replace with a single repeat operation which is more efficient. + + Examples + -------- + concatenate([x, x, x], axis=0) -> repeat(x, 3, axis=0) + """ + if not isinstance(node.op, Join): + return + + # Extract axis and the tensors being joined + axis, *tensors = node.inputs + + # Need at least 2 tensors to consider optimization + if len(tensors) <= 1: + return + + # Check if all tensors are identical + if not all(t == tensors[0] for t in tensors[1:]): + return + + # Replace with repeat operation + result = repeat(tensors[0], len(tensors), axis) + + # Preserve debugging information + copy_stack_trace(node.outputs[0], result) + + return [result] @register_specialize @register_canonicalize diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index d9eb2ad7ad..40ab200a2c 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -35,6 +35,7 @@ tile, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.extra_ops import Repeat from pytensor.tensor.math import ( add, bitwise_and, @@ -1247,6 +1248,84 @@ def test_local_join_1(): assert f.maker.fgraph.outputs[0].dtype == config.floatX +def test_local_join_to_repeat(): + """Test that Join(axis, x, x, ...) 
gets rewritten to repeat(x, n, axis)""" + + # Test with vector - concatenate same vector 3 times along axis 0 + x = vector("x") + s = join(0, x, x, x) + f = function([x], s, mode=rewrite_mode) + + # Check numerical correctness + test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX) + result = f(test_val) + expected = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX) + assert np.allclose(result, expected) + + # Check that Join was replaced with Repeat + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + + # Test with matrix - concatenate same matrix along axis 0 + a = matrix("a") + s = join(0, a, a, a, a) + f = function([a], s, mode=rewrite_mode) + + test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + result = f(test_mat) + expected = np.vstack([test_mat, test_mat, test_mat, test_mat]) + assert np.allclose(result, expected) + + # Check optimization applied + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + + # Test with matrix - concatenate along axis 1 + s = join(1, a, a) + f = function([a], s, mode=rewrite_mode) + + result = f(test_mat) + expected = np.hstack([test_mat, test_mat]) + assert np.allclose(result, expected) + + # Check optimization applied + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + + # Test that it does NOT apply when tensors are different + b = matrix("b") + s = join(0, a, b) + f = function([a, b], s, mode=rewrite_mode) + + test_mat1 = np.array([[1.0, 2.0]], dtype=config.floatX) + test_mat2 = np.array([[3.0, 4.0]], dtype=config.floatX) + result = f(test_mat1, test_mat2) + expected = np.vstack([test_mat1, test_mat2]) + assert np.allclose(result, expected) + + # Join should still be present (not optimized to Repeat) + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 1 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 0 + + # Test with 5 repetitions to ensure it works with larger counts + s = join(0, x, x, x, x, x) + f = function([x], s, mode=rewrite_mode) + + test_val = np.array([1.0, 2.0], dtype=config.floatX) + result = f(test_val) + expected = np.tile(test_val, 5) + assert np.allclose(result, expected) + + # Check optimization applied + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + + def test_local_join_empty(): # Vector case empty_vec = np.asarray([], dtype=config.floatX) From 44aa137d40c3cd6fde8c61b25818b8b3d00a0562 Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Tue, 4 Nov 2025 00:55:59 -0800 Subject: [PATCH 2/7] fixed format --- pytensor/tensor/rewriting/basic.py | 2 ++ tests/tensor/rewriting/test_basic.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 5970f12da4..555ef72464 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -909,6 +909,7 @@ def local_join_make_vector(fgraph, node): copy_stack_trace(node.outputs, ret) return [ret] + @register_specialize @register_canonicalize @node_rewriter([Join]) @@ -944,6 +945,7 @@ def local_join_to_repeat(fgraph, node): return [result] + 
@register_specialize @register_canonicalize @register_useless diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 40ab200a2c..4879b816cf 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1259,7 +1259,9 @@ def test_local_join_to_repeat(): # Check numerical correctness test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX) result = f(test_val) - expected = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX) + expected = np.array( + [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX + ) assert np.allclose(result, expected) # Check that Join was replaced with Repeat From 58dc0276a23aceae5d4edf4b45445abae3e66d58 Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Tue, 4 Nov 2025 01:11:02 -0800 Subject: [PATCH 3/7] remove register_specialize and not instance check --- pytensor/tensor/rewriting/basic.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 555ef72464..4a90af02eb 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -910,7 +910,6 @@ def local_join_make_vector(fgraph, node): return [ret] -@register_specialize @register_canonicalize @node_rewriter([Join]) def local_join_to_repeat(fgraph, node): @@ -923,9 +922,6 @@ def local_join_to_repeat(fgraph, node): -------- concatenate([x, x, x], axis=0) -> repeat(x, 3, axis=0) """ - if not isinstance(node.op, Join): - return - # Extract axis and the tensors being joined axis, *tensors = node.inputs From 9348020f065815c257242d878dabe8649583e83c Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Wed, 5 Nov 2025 01:28:27 -0800 Subject: [PATCH 4/7] Handle symbolic axis and fix test assertions for Alloc --- pytensor/tensor/rewriting/basic.py | 48 +++++++++++++---- tests/tensor/rewriting/test_basic.py | 77 +++++++++++++++++----------- 2 files changed, 85 insertions(+), 40 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 4a90af02eb..0297c04b07 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -30,7 +30,7 @@ from pytensor import compile, config from pytensor.compile.ops import ViewOp from pytensor.graph import FunctionGraph, Op -from pytensor.graph.basic import Constant +from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, @@ -82,7 +82,7 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.extra_ops import broadcast_arrays, repeat +from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.math import Sum, add, eq, variadic_add from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.type import DenseTensorType, TensorType @@ -915,26 +915,52 @@ def local_join_make_vector(fgraph, node): def local_join_to_repeat(fgraph, node): """Join(axis, x, x, x, ...) -> repeat(x, n, axis) - When the same tensor is concatenated multiple times, - replace with a single repeat operation which is more efficient. + When the same tensor is concatenated multiple times along an axis + where it has size 1, replace with a repeat operation which is more efficient. 
Examples -------- - concatenate([x, x, x], axis=0) -> repeat(x, 3, axis=0) + concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0) """ # Extract axis and the tensors being joined - axis, *tensors = node.inputs + axis_sym, *tensors = node.inputs # Need at least 2 tensors to consider optimization if len(tensors) <= 1: - return + return None - # Check if all tensors are identical - if not all(t == tensors[0] for t in tensors[1:]): - return + # Extract (and normalize) axis as Python int + try: + axis_val = int(get_scalar_constant_value(axis_sym, only_process_constants=True)) + except NotScalarConstantError: + return None + + # Get first tensor and check if ndim is known + first = tensors[0] + ndim = first.ndim + if ndim is None: + return None + + # Normalize negative axes (e.g., -1 -> ndim-1) + axis_val = axis_val % ndim + + # All inputs must be structurally the same tensor + # Use equal_computations to check structural equality, not symbolic == + for t in tensors[1:]: + if not equal_computations([t], [first]): + return None + + # Only apply when size along join axis is statically 1 + # (e.g., x[None] has a guaranteed 1 at that axis) + shp = first.type.shape # tuple of ints/None + if shp is None or axis_val >= len(shp) or shp[axis_val] != 1: + return None # Replace with repeat operation - result = repeat(tensors[0], len(tensors), axis) + from pytensor.tensor.extra_ops import repeat + + n = len(tensors) + result = repeat(first, n, axis=axis_val) # Preserve debugging information copy_stack_trace(node.outputs[0], result) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 4879b816cf..110656b9c6 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -35,7 +35,6 @@ tile, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.extra_ops import Repeat from pytensor.tensor.math import ( add, bitwise_and, @@ -1249,83 +1248,103 @@ def test_local_join_1(): def test_local_join_to_repeat(): - """Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)""" + """Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis) - # Test with vector - concatenate same vector 3 times along axis 0 + This optimization applies when joining the same tensor multiple times + along an axis where it has size 1 (e.g., after ExpandDims). 
+ """ + + # Test with vector expanded to (1, n) - concatenate along axis 0 x = vector("x") - s = join(0, x, x, x) + x_expanded = x[None] # Shape: (1, n) + s = join(0, x_expanded, x_expanded, x_expanded) # Shape: (3, n) f = function([x], s, mode=rewrite_mode) # Check numerical correctness test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX) result = f(test_val) expected = np.array( - [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX + [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX ) assert np.allclose(result, expected) - # Check that Join was replaced with Repeat + # Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc) ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 - # Test with matrix - concatenate same matrix along axis 0 - a = matrix("a") - s = join(0, a, a, a, a) + # Test with matrix - add dimension and concatenate along new axis + a = matrix("a") # Shape: (m, n) + a_expanded = a[None, :, :] # Shape: (1, m, n) + s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # Shape: (4, m, n) f = function([a], s, mode=rewrite_mode) test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) result = f(test_mat) - expected = np.vstack([test_mat, test_mat, test_mat, test_mat]) + expected = np.array([test_mat, test_mat, test_mat, test_mat]) assert np.allclose(result, expected) # Check optimization applied ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 - # Test with matrix - concatenate along axis 1 - s = join(1, a, a) + # Test with matrix - expand along axis 1 and concatenate + a_expanded_ax1 = a[:, None, :] # Shape: (m, 1, n) + s = join(1, a_expanded_ax1, a_expanded_ax1) # Shape: (m, 2, n) f = function([a], s, mode=rewrite_mode) result = f(test_mat) - expected = np.hstack([test_mat, test_mat]) + expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]]) assert np.allclose(result, expected) # Check optimization applied ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 # Test that it does NOT apply when tensors are different - b = matrix("b") - s = join(0, a, b) - f = function([a, b], s, mode=rewrite_mode) - - test_mat1 = np.array([[1.0, 2.0]], dtype=config.floatX) - test_mat2 = np.array([[3.0, 4.0]], dtype=config.floatX) - result = f(test_mat1, test_mat2) - expected = np.vstack([test_mat1, test_mat2]) + y = vector("y") + s = join(0, x[None], y[None]) + f = function([x, y], s, mode=rewrite_mode) + + test_vec1 = np.array([1.0, 2.0], dtype=config.floatX) + test_vec2 = np.array([3.0, 4.0], dtype=config.floatX) + result = f(test_vec1, test_vec2) + expected = np.array([[1.0, 2.0], [3.0, 4.0]]) + assert np.allclose(result, expected) + + # Join should still be present (not optimized) + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 1 + + # Test that it does NOT apply when tensor doesn't have size 1 along join axis + # (regular concatenation without ExpandDims) + s = join(0, x, x, x) # Shape: (3n,) not using ExpandDims + f = function([x], s, 
mode=rewrite_mode) + + test_val = np.array([1.0, 2.0], dtype=config.floatX) + result = f(test_val) + expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX) assert np.allclose(result, expected) - # Join should still be present (not optimized to Repeat) + # Join should still be present (optimization doesn't apply) ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 1 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 0 # Test with 5 repetitions to ensure it works with larger counts - s = join(0, x, x, x, x, x) + s = join(0, x[None], x[None], x[None], x[None], x[None]) f = function([x], s, mode=rewrite_mode) test_val = np.array([1.0, 2.0], dtype=config.floatX) result = f(test_val) - expected = np.tile(test_val, 5) + expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX) assert np.allclose(result, expected) # Check optimization applied ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 def test_local_join_empty(): From 63dbcf8a80a355c89628f0564aed9f9445ed6be1 Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Wed, 5 Nov 2025 12:41:44 -0800 Subject: [PATCH 5/7] removed equal computation --- pytensor/tensor/rewriting/basic.py | 46 +++++++++++------------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 0297c04b07..76d1ac8aba 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -30,7 +30,7 @@ from pytensor import compile, config from pytensor.compile.ops import ViewOp from pytensor.graph import FunctionGraph, Op -from pytensor.graph.basic import Constant, equal_computations +from pytensor.graph.basic import Constant from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, @@ -923,44 +923,32 @@ def local_join_to_repeat(fgraph, node): concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0) """ # Extract axis and the tensors being joined - axis_sym, *tensors = node.inputs + axis, *tensors = node.inputs - # Need at least 2 tensors to consider optimization - if len(tensors) <= 1: + # Optimization only applies when axis is constant + if not isinstance(axis, Constant): return None - # Extract (and normalize) axis as Python int - try: - axis_val = int(get_scalar_constant_value(axis_sym, only_process_constants=True)) - except NotScalarConstantError: - return None + # Extract the Python integer from the constant + axis_val = axis.data - # Get first tensor and check if ndim is known - first = tensors[0] - ndim = first.ndim - if ndim is None: - return None - - # Normalize negative axes (e.g., -1 -> ndim-1) - axis_val = axis_val % ndim + # Need at least 2 tensors to consider optimization + if len(tensors) <= 1: + return - # All inputs must be structurally the same tensor - # Use equal_computations to check structural equality, not symbolic == - for t in tensors[1:]: - if not equal_computations([t], [first]): - return None + # Check if all tensors are identical + if not all(t == tensors[0] for t in tensors[1:]): + return - # Only apply when size along join axis is statically 1 - # (e.g., x[None] has a guaranteed 1 at that axis) - shp = first.type.shape # tuple of ints/None - if shp is None or axis_val >= len(shp) or shp[axis_val] != 1: - return None + # Only optimize if the tensor has size 1 
along the join axis + first_tensor = tensors[0] + if first_tensor.type.shape[axis_val] != 1: + return # Replace with repeat operation from pytensor.tensor.extra_ops import repeat - n = len(tensors) - result = repeat(first, n, axis=axis_val) + result = repeat(first_tensor, len(tensors), axis_val) # Preserve debugging information copy_stack_trace(node.outputs[0], result) From e0cf01b32877acd84a83b0f037187676c21ceca0 Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Sat, 8 Nov 2025 02:10:17 -0800 Subject: [PATCH 6/7] replace repeat with tile function --- pytensor/tensor/rewriting/basic.py | 24 +++--- tests/tensor/rewriting/test_basic.py | 121 ++++++++++++++++----------- 2 files changed, 86 insertions(+), 59 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 76d1ac8aba..e309c9f485 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -77,6 +77,7 @@ register_infer_shape, switch, tensor_copy, + tile, zeros, zeros_like, ) @@ -913,14 +914,15 @@ def local_join_make_vector(fgraph, node): @register_canonicalize @node_rewriter([Join]) def local_join_to_repeat(fgraph, node): - """Join(axis, x, x, x, ...) -> repeat(x, n, axis) + """Join(axis, x, x, x, ...) -> tile(x, reps) - When the same tensor is concatenated multiple times along an axis - where it has size 1, replace with a repeat operation which is more efficient. + When the same tensor is concatenated multiple times along an axis, + replace with a single tile operation which is more efficient. Examples -------- - concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0) + join(0, x, x, x) -> tile(x, (3, 1, 1, ...)) + join(1, x, x) -> tile(x, (1, 2, 1, ...)) """ # Extract axis and the tensors being joined axis, *tensors = node.inputs @@ -940,19 +942,19 @@ def local_join_to_repeat(fgraph, node): if not all(t == tensors[0] for t in tensors[1:]): return - # Only optimize if the tensor has size 1 along the join axis + n_reps = len(tensors) first_tensor = tensors[0] - if first_tensor.type.shape[axis_val] != 1: - return + ndim = first_tensor.ndim - # Replace with repeat operation - from pytensor.tensor.extra_ops import repeat + # Build reps tuple to repeat only along the join axis + # For shape (a, b, c) joining at axis 1: reps = (1, n_reps, 1) + # This directly concatenates n_reps copies along axis_val + reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim)) - result = repeat(first_tensor, len(tensors), axis_val) + result = tile(first_tensor, reps) # Preserve debugging information copy_stack_trace(node.outputs[0], result) - return [result] diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 110656b9c6..dd8676cd83 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1237,30 +1237,37 @@ def test_local_join_1(): assert len([n for n in e if isinstance(n.op, Join)]) == 0 assert f.maker.fgraph.outputs[0].dtype == config.floatX - # test we don't apply when their is 2 inputs + # test that join with 2 identical inputs now gets optimized to tile s = join(1, a, a) f = function([a], s, mode=rewrite_mode) val = f([[1]]) - assert np.all(val == [[1]]) + assert np.all(val == [[1, 1]]) # joined along axis 1 e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 + assert len([n for n in e if isinstance(n.op, Join)]) == 0 # optimized away assert f.maker.fgraph.outputs[0].dtype == config.floatX -def test_local_join_to_repeat(): - """Test 
that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis) +def test_local_join_to_tile(): + """Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k. - This optimization applies when joining the same tensor multiple times - along an axis where it has size 1 (e.g., after ExpandDims). + This optimization applies whenever we concatenate the *same* tensor multiple + times along a given axis (no need for size-1 dims / ExpandDims). It replaces + the Join/concatenate with a single Tile op. """ - # Test with vector expanded to (1, n) - concatenate along axis 0 + # Helpers to inspect the graph without depending on concrete Op classes + def count_op(ops, cls_name): + return sum(1 for n in ops if n.op.__class__.__name__ == cls_name) + + def has_no_join(fgraph_ops): + return count_op(fgraph_ops, "Join") == 0 + + # ---- Case 1: vector expanded to (1, n), concat along axis 0 ---- x = vector("x") - x_expanded = x[None] # Shape: (1, n) - s = join(0, x_expanded, x_expanded, x_expanded) # Shape: (3, n) + x_expanded = x[None] # (1, n) + s = join(0, x_expanded, x_expanded, x_expanded) # (3, n) f = function([x], s, mode=rewrite_mode) - # Check numerical correctness test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX) result = f(test_val) expected = np.array( @@ -1268,59 +1275,56 @@ def test_local_join_to_repeat(): ) assert np.allclose(result, expected) - # Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc) ops = f.maker.fgraph.toposort() - assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 + assert has_no_join(ops) + # Note: Tile may be further optimized to Alloc, so we don't check for it - # Test with matrix - add dimension and concatenate along new axis - a = matrix("a") # Shape: (m, n) - a_expanded = a[None, :, :] # Shape: (1, m, n) - s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # Shape: (4, m, n) + # ---- Case 2: matrix, concat along new leading axis ---- + a = matrix("a") # (m, n) + a_expanded = a[None, :, :] # (1, m, n) + s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # (4, m, n) f = function([a], s, mode=rewrite_mode) test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) result = f(test_mat) - expected = np.array([test_mat, test_mat, test_mat, test_mat]) + expected = np.array([test_mat, test_mat, test_mat, test_mat], dtype=config.floatX) assert np.allclose(result, expected) - # Check optimization applied ops = f.maker.fgraph.toposort() - assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 + assert has_no_join(ops) - # Test with matrix - expand along axis 1 and concatenate - a_expanded_ax1 = a[:, None, :] # Shape: (m, 1, n) - s = join(1, a_expanded_ax1, a_expanded_ax1) # Shape: (m, 2, n) + # ---- Case 3: matrix, expand along axis 1 then concat ---- + a_expanded_ax1 = a[:, None, :] # (m, 1, n) + s = join(1, a_expanded_ax1, a_expanded_ax1) # (m, 2, n) f = function([a], s, mode=rewrite_mode) result = f(test_mat) - expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]]) + expected = np.array( + [[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]], + dtype=config.floatX, + ) assert np.allclose(result, expected) - # Check optimization applied ops = f.maker.fgraph.toposort() - assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 + assert has_no_join(ops) - # Test that it does NOT 
apply when tensors are different + # ---- Case 4: different tensors -> should NOT optimize ---- y = vector("y") - s = join(0, x[None], y[None]) + s = join(0, x[None], y[None]) # inputs differ f = function([x, y], s, mode=rewrite_mode) test_vec1 = np.array([1.0, 2.0], dtype=config.floatX) test_vec2 = np.array([3.0, 4.0], dtype=config.floatX) result = f(test_vec1, test_vec2) - expected = np.array([[1.0, 2.0], [3.0, 4.0]]) + expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) assert np.allclose(result, expected) - # Join should still be present (not optimized) ops = f.maker.fgraph.toposort() - assert len([n for n in ops if isinstance(n.op, Join)]) == 1 + # Join should still be present since inputs aren't identical + assert count_op(ops, "Join") == 1 - # Test that it does NOT apply when tensor doesn't have size 1 along join axis - # (regular concatenation without ExpandDims) - s = join(0, x, x, x) # Shape: (3n,) not using ExpandDims + # ---- Case 5: plain concat without ExpandDims should now optimize ---- + s = join(0, x, x, x) # (3n,) f = function([x], s, mode=rewrite_mode) test_val = np.array([1.0, 2.0], dtype=config.floatX) @@ -1328,12 +1332,11 @@ def test_local_join_to_repeat(): expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX) assert np.allclose(result, expected) - # Join should still be present (optimization doesn't apply) ops = f.maker.fgraph.toposort() - assert len([n for n in ops if isinstance(n.op, Join)]) == 1 + assert has_no_join(ops) - # Test with 5 repetitions to ensure it works with larger counts - s = join(0, x[None], x[None], x[None], x[None], x[None]) + # ---- Case 6: larger repetition count ---- + s = join(0, x[None], x[None], x[None], x[None], x[None]) # (5, n) f = function([x], s, mode=rewrite_mode) test_val = np.array([1.0, 2.0], dtype=config.floatX) @@ -1341,29 +1344,51 @@ def test_local_join_to_repeat(): expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX) assert np.allclose(result, expected) - # Check optimization applied ops = f.maker.fgraph.toposort() - assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 + assert has_no_join(ops) def test_local_join_empty(): - # Vector case + # Vector case - empty tensors should be removed and join optimized empty_vec = np.asarray([], dtype=config.floatX) vec = vector("vec") s = pt.join(0, vec, vec, empty_vec) new_s = rewrite_graph(s) - assert equal_computations([new_s], [join(0, vec, vec)]) + # Verify dtype is preserved assert new_s.dtype == s.dtype + # Verify no Join in the optimized graph + f = function([vec], new_s, mode=rewrite_mode) + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + # Verify numerical correctness + test_vec = np.array([1.0, 2.0, 3.0], dtype=config.floatX) + result = f(test_vec) + expected = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX) + assert np.allclose(result, expected) - # Matrix case + # Matrix case - empty tensors should be removed and join optimized empty_mat = np.zeros((2, 0), dtype=config.floatX) empty_sym_mat = matrix("m", shape=(2, 0)) mat = matrix("mat", shape=(2, 10)) s = join(1, empty_mat, mat, empty_sym_mat, mat, mat) new_s = rewrite_graph(s) - assert equal_computations([new_s], [join(1, mat, mat, mat)]) + # Verify dtype is preserved assert new_s.dtype == s.dtype + # Verify no Join in the optimized graph + f = function([mat], new_s, mode=rewrite_mode) + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if 
isinstance(n.op, Join)]) == 0 + # Verify numerical correctness + test_mat = np.array( + [ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + ], + dtype=config.floatX, + ) + result = f(test_mat) + expected = np.concatenate([test_mat, test_mat, test_mat], axis=1) + assert np.allclose(result, expected) # Join can be completely removed, but casting and specify_shape are propagated int_mat = matrix("int_mat", dtype=int) From 37c35c8b06fd9c596c11924c4b2108d107579d4b Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Sat, 8 Nov 2025 16:43:08 -0800 Subject: [PATCH 7/7] used specific instance check and removed local compute --- tests/tensor/rewriting/test_basic.py | 46 ++++++++-------------------- 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index dd8676cd83..d5851c6319 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1255,12 +1255,12 @@ def test_local_join_to_tile(): the Join/concatenate with a single Tile op. """ - # Helpers to inspect the graph without depending on concrete Op classes - def count_op(ops, cls_name): - return sum(1 for n in ops if n.op.__class__.__name__ == cls_name) + # Helpers to inspect the graph + def count_join_ops(ops): + return sum(1 for n in ops if isinstance(n.op, Join)) def has_no_join(fgraph_ops): - return count_op(fgraph_ops, "Join") == 0 + return count_join_ops(fgraph_ops) == 0 # ---- Case 1: vector expanded to (1, n), concat along axis 0 ---- x = vector("x") @@ -1321,7 +1321,7 @@ def has_no_join(fgraph_ops): ops = f.maker.fgraph.toposort() # Join should still be present since inputs aren't identical - assert count_op(ops, "Join") == 1 + assert count_join_ops(ops) == 1 # ---- Case 5: plain concat without ExpandDims should now optimize ---- s = join(0, x, x, x) # (3n,) @@ -1349,46 +1349,26 @@ def has_no_join(fgraph_ops): def test_local_join_empty(): - # Vector case - empty tensors should be removed and join optimized + # Vector case - empty tensors should be removed empty_vec = np.asarray([], dtype=config.floatX) vec = vector("vec") s = pt.join(0, vec, vec, empty_vec) new_s = rewrite_graph(s) - # Verify dtype is preserved assert new_s.dtype == s.dtype - # Verify no Join in the optimized graph - f = function([vec], new_s, mode=rewrite_mode) - ops = f.maker.fgraph.toposort() - assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - # Verify numerical correctness - test_vec = np.array([1.0, 2.0, 3.0], dtype=config.floatX) - result = f(test_vec) - expected = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX) - assert np.allclose(result, expected) + # Compare to the expected form (also rewritten, since join itself gets optimized) + expected = rewrite_graph(pt.join(0, vec, vec)) + assert equal_computations([new_s], [expected]) - # Matrix case - empty tensors should be removed and join optimized + # Matrix case - empty tensors should be removed empty_mat = np.zeros((2, 0), dtype=config.floatX) empty_sym_mat = matrix("m", shape=(2, 0)) mat = matrix("mat", shape=(2, 10)) s = join(1, empty_mat, mat, empty_sym_mat, mat, mat) new_s = rewrite_graph(s) - # Verify dtype is preserved assert new_s.dtype == s.dtype - # Verify no Join in the optimized graph - f = function([mat], new_s, mode=rewrite_mode) - ops = f.maker.fgraph.toposort() - assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - # Verify numerical correctness - test_mat = np.array( - [ - [1.0, 
2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], - ], - dtype=config.floatX, - ) - result = f(test_mat) - expected = np.concatenate([test_mat, test_mat, test_mat], axis=1) - assert np.allclose(result, expected) + # Compare to the expected form (also rewritten, since join itself gets optimized) + expected = rewrite_graph(join(1, mat, mat, mat)) + assert equal_computations([new_s], [expected]) # Join can be completely removed, but casting and specify_shape are propagated int_mat = matrix("int_mat", dtype=int)
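
A minimal standalone sketch (not part of the patches above) of what the final rewrite in this series is expected to do, assuming a PyTensor build with the `local_join_to_repeat` rewrite applied during normal compilation and assuming `Join` is importable from `pytensor.tensor.basic`: joining the same tensor several times should compile to a graph with no Join node, and the result should match plain NumPy tiling.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Join

x = pt.vector("x")
s = pt.join(0, x[None], x[None], x[None])  # same (1, n) tensor joined three times

f = pytensor.function([x], s)  # default mode runs the canonicalize rewrites

# The rewrite should have removed the Join op from the compiled graph ...
assert not any(isinstance(node.op, Join) for node in f.maker.fgraph.toposort())

# ... and the output should match NumPy tiling of the row along axis 0.
val = np.array([1.0, 2.0, 3.0], dtype=pytensor.config.floatX)
np.testing.assert_allclose(f(val), np.tile(val[None], (3, 1)))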