diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index c9ade02a00..e309c9f485 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -77,6 +77,7 @@
     register_infer_shape,
     switch,
     tensor_copy,
+    tile,
     zeros,
     zeros_like,
 )
@@ -910,6 +911,55 @@ def local_join_make_vector(fgraph, node):
     return [ret]
 
 
+@register_canonicalize
+@node_rewriter([Join])
+def local_join_to_tile(fgraph, node):
+    """Join(axis, x, x, x, ...) -> tile(x, reps)
+
+    When the same tensor is concatenated multiple times along an axis,
+    replace the Join with a single tile operation, which is more efficient.
+
+    Examples
+    --------
+    join(0, x, x, x) -> tile(x, (3, 1, 1, ...))
+    join(1, x, x) -> tile(x, (1, 2, 1, ...))
+    """
+    # Extract the axis and the tensors being joined
+    axis, *tensors = node.inputs
+
+    # The rewrite only applies when the axis is a constant
+    if not isinstance(axis, Constant):
+        return None
+
+    # Need at least 2 tensors for the rewrite to be worthwhile
+    if len(tensors) <= 1:
+        return None
+
+    # Check that all joined tensors are the same variable
+    if not all(t == tensors[0] for t in tensors[1:]):
+        return None
+
+    n_reps = len(tensors)
+    first_tensor = tensors[0]
+    ndim = first_tensor.ndim
+
+    # Extract the Python integer from the constant, normalizing a negative axis
+    axis_val = int(axis.data)
+    if axis_val < 0:
+        axis_val += ndim
+
+    # Build the reps tuple to repeat only along the join axis
+    # For shape (a, b, c) joining at axis 1: reps = (1, n_reps, 1)
+    # This directly concatenates n_reps copies along axis_val
+    reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))
+
+    result = tile(first_tensor, reps)
+
+    # Preserve debugging information
+    copy_stack_trace(node.outputs[0], result)
+    return [result]
+
+
 @register_specialize
 @register_canonicalize
 @register_useless
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index d9eb2ad7ad..d5851c6319 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1237,33 +1237,138 @@ def test_local_join_1():
     assert len([n for n in e if isinstance(n.op, Join)]) == 0
     assert f.maker.fgraph.outputs[0].dtype == config.floatX
 
-    # test we don't apply when their is 2 inputs
+    # test that join with 2 identical inputs now gets optimized to tile
     s = join(1, a, a)
     f = function([a], s, mode=rewrite_mode)
     val = f([[1]])
-    assert np.all(val == [[1]])
+    assert np.all(val == [[1, 1]])  # joined along axis 1
    e = f.maker.fgraph.toposort()
-    assert len([n for n in e if isinstance(n.op, Join)]) == 1
+    assert len([n for n in e if isinstance(n.op, Join)]) == 0  # optimized away
     assert f.maker.fgraph.outputs[0].dtype == config.floatX
 
 
+def test_local_join_to_tile():
+    """Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.
+
+    This optimization applies whenever we concatenate the *same* tensor multiple
+    times along a given axis (no need for size-1 dims / ExpandDims). It replaces
+    the Join/concatenate with a single Tile op.
+ """ + + # Helpers to inspect the graph + def count_join_ops(ops): + return sum(1 for n in ops if isinstance(n.op, Join)) + + def has_no_join(fgraph_ops): + return count_join_ops(fgraph_ops) == 0 + + # ---- Case 1: vector expanded to (1, n), concat along axis 0 ---- + x = vector("x") + x_expanded = x[None] # (1, n) + s = join(0, x_expanded, x_expanded, x_expanded) # (3, n) + f = function([x], s, mode=rewrite_mode) + + test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX) + result = f(test_val) + expected = np.array( + [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX + ) + assert np.allclose(result, expected) + + ops = f.maker.fgraph.toposort() + assert has_no_join(ops) + # Note: Tile may be further optimized to Alloc, so we don't check for it + + # ---- Case 2: matrix, concat along new leading axis ---- + a = matrix("a") # (m, n) + a_expanded = a[None, :, :] # (1, m, n) + s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # (4, m, n) + f = function([a], s, mode=rewrite_mode) + + test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + result = f(test_mat) + expected = np.array([test_mat, test_mat, test_mat, test_mat], dtype=config.floatX) + assert np.allclose(result, expected) + + ops = f.maker.fgraph.toposort() + assert has_no_join(ops) + + # ---- Case 3: matrix, expand along axis 1 then concat ---- + a_expanded_ax1 = a[:, None, :] # (m, 1, n) + s = join(1, a_expanded_ax1, a_expanded_ax1) # (m, 2, n) + f = function([a], s, mode=rewrite_mode) + + result = f(test_mat) + expected = np.array( + [[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]], + dtype=config.floatX, + ) + assert np.allclose(result, expected) + + ops = f.maker.fgraph.toposort() + assert has_no_join(ops) + + # ---- Case 4: different tensors -> should NOT optimize ---- + y = vector("y") + s = join(0, x[None], y[None]) # inputs differ + f = function([x, y], s, mode=rewrite_mode) + + test_vec1 = np.array([1.0, 2.0], dtype=config.floatX) + test_vec2 = np.array([3.0, 4.0], dtype=config.floatX) + result = f(test_vec1, test_vec2) + expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + assert np.allclose(result, expected) + + ops = f.maker.fgraph.toposort() + # Join should still be present since inputs aren't identical + assert count_join_ops(ops) == 1 + + # ---- Case 5: plain concat without ExpandDims should now optimize ---- + s = join(0, x, x, x) # (3n,) + f = function([x], s, mode=rewrite_mode) + + test_val = np.array([1.0, 2.0], dtype=config.floatX) + result = f(test_val) + expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX) + assert np.allclose(result, expected) + + ops = f.maker.fgraph.toposort() + assert has_no_join(ops) + + # ---- Case 6: larger repetition count ---- + s = join(0, x[None], x[None], x[None], x[None], x[None]) # (5, n) + f = function([x], s, mode=rewrite_mode) + + test_val = np.array([1.0, 2.0], dtype=config.floatX) + result = f(test_val) + expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX) + assert np.allclose(result, expected) + + ops = f.maker.fgraph.toposort() + assert has_no_join(ops) + + def test_local_join_empty(): - # Vector case + # Vector case - empty tensors should be removed empty_vec = np.asarray([], dtype=config.floatX) vec = vector("vec") s = pt.join(0, vec, vec, empty_vec) new_s = rewrite_graph(s) - assert equal_computations([new_s], [join(0, vec, vec)]) assert new_s.dtype == s.dtype + # Compare to the expected form (also rewritten, since join itself gets optimized) + expected = 
rewrite_graph(pt.join(0, vec, vec)) + assert equal_computations([new_s], [expected]) - # Matrix case + # Matrix case - empty tensors should be removed empty_mat = np.zeros((2, 0), dtype=config.floatX) empty_sym_mat = matrix("m", shape=(2, 0)) mat = matrix("mat", shape=(2, 10)) s = join(1, empty_mat, mat, empty_sym_mat, mat, mat) new_s = rewrite_graph(s) - assert equal_computations([new_s], [join(1, mat, mat, mat)]) assert new_s.dtype == s.dtype + # Compare to the expected form (also rewritten, since join itself gets optimized) + expected = rewrite_graph(join(1, mat, mat, mat)) + assert equal_computations([new_s], [expected]) # Join can be completely removed, but casting and specify_shape are propagated int_mat = matrix("int_mat", dtype=int)
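As a quick sanity check outside the test suite, a small standalone script like the sketch below compiles a join of the same variable and verifies that no Join node survives rewriting. It is illustrative only, not part of the patch, and assumes a default-mode compile in which the canonicalization pass (and hence this rewrite) runs.

# Minimal sketch (not part of the patch): demonstrates the Join -> Tile rewrite.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Join

# Build join(0, x, x, x): the same variable concatenated three times
x = pt.vector("x")
s = pt.join(0, x, x, x)

# Compile with the default rewrites enabled
f = pytensor.function([x], s)

# After rewriting, no Join op should remain in the compiled graph
ops = f.maker.fgraph.toposort()
assert not any(isinstance(node.op, Join) for node in ops)

# Numerically, the result still matches plain concatenation
val = np.arange(3).astype(pytensor.config.floatX)
np.testing.assert_allclose(f(val), np.concatenate([val, val, val]))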