60 changes: 59 additions & 1 deletion pytensor/tensor/rewriting/basic.py
@@ -30,7 +30,7 @@
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph, Op
from pytensor.graph.basic import Constant
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter,
@@ -910,6 +910,64 @@ def local_join_make_vector(fgraph, node):
return [ret]


@register_canonicalize
@node_rewriter([Join])
def local_join_to_repeat(fgraph, node):
"""Join(axis, x, x, x, ...) -> repeat(x, n, axis)

When the same tensor is concatenated multiple times along an axis
where it has size 1, replace the join with a single repeat operation,
which is more efficient.

Examples
--------
concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0)
"""
# Extract axis and the tensors being joined
axis_sym, *tensors = node.inputs

# Need at least 2 tensors to consider optimization
if len(tensors) <= 1:
return None

# Extract (and normalize) axis as Python int
try:
axis_val = int(get_scalar_constant_value(axis_sym, only_process_constants=True))
except NotScalarConstantError:
return None

# Get first tensor and check if ndim is known
first = tensors[0]
ndim = first.ndim
if ndim is None:
return None

# Normalize negative axes (e.g., -1 -> ndim-1)
axis_val = axis_val % ndim

# All inputs must be structurally the same tensor
# Use equal_computations to check structural equality, not symbolic ==
for t in tensors[1:]:
if not equal_computations([t], [first]):
return None

# Only apply when size along join axis is statically 1
# (e.g., x[None] has a guaranteed 1 at that axis)
shp = first.type.shape # tuple of ints/None
if shp is None or axis_val >= len(shp) or shp[axis_val] != 1:
return None

# Replace with repeat operation
from pytensor.tensor.extra_ops import repeat

n = len(tensors)
result = repeat(first, n, axis=axis_val)

# Preserve debugging information
copy_stack_trace(node.outputs[0], result)

return [result]


@register_specialize
@register_canonicalize
@register_useless
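For context, a minimal sketch (not part of this diff) of how the new rewrite is expected to behave on a small graph; it assumes the `canonicalize` registration shown above and uses `rewrite_graph` purely for illustration:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x")
# The same tensor joined three times along an axis where it has size 1.
out = pt.join(0, x[None], x[None], x[None])

# Running only the canonicalize pass should let local_join_to_repeat replace
# the Join with the repeat-based graph (repeat with an int count builds an
# Alloc-backed graph rather than a dedicated Repeat op).
rewritten = rewrite_graph(out, include=("canonicalize",))
pytensor.dprint(rewritten)
```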
100 changes: 100 additions & 0 deletions tests/tensor/rewriting/test_basic.py
@@ -1247,6 +1247,106 @@ def test_local_join_1():
assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_repeat():
"""Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)

This optimization applies when joining the same tensor multiple times
along an axis where it has size 1 (e.g., after ExpandDims).
"""

# Test with vector expanded to (1, n) - concatenate along axis 0
x = vector("x")
x_expanded = x[None] # Shape: (1, n)
s = join(0, x_expanded, x_expanded, x_expanded) # Shape: (3, n)
f = function([x], s, mode=rewrite_mode)

# Check numerical correctness
test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
result = f(test_val)
expected = np.array(
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX
)
assert np.allclose(result, expected)

# Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test with matrix - add dimension and concatenate along new axis
a = matrix("a") # Shape: (m, n)
a_expanded = a[None, :, :] # Shape: (1, m, n)
s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # Shape: (4, m, n)
f = function([a], s, mode=rewrite_mode)

test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
result = f(test_mat)
expected = np.array([test_mat, test_mat, test_mat, test_mat])
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test with matrix - expand along axis 1 and concatenate
a_expanded_ax1 = a[:, None, :] # Shape: (m, 1, n)
s = join(1, a_expanded_ax1, a_expanded_ax1) # Shape: (m, 2, n)
f = function([a], s, mode=rewrite_mode)

result = f(test_mat)
expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]])
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test that it does NOT apply when tensors are different
y = vector("y")
s = join(0, x[None], y[None])
f = function([x, y], s, mode=rewrite_mode)

test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
result = f(test_vec1, test_vec2)
expected = np.array([[1.0, 2.0], [3.0, 4.0]])
assert np.allclose(result, expected)

# Join should still be present (not optimized)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1

# Test that it does NOT apply when tensor doesn't have size 1 along join axis
# (regular concatenation without ExpandDims)
s = join(0, x, x, x) # Shape: (3n,) not using ExpandDims
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
assert np.allclose(result, expected)

# Join should still be present (optimization doesn't apply)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1

# Test with 5 repetitions to ensure it works with larger counts
s = join(0, x[None], x[None], x[None], x[None], x[None])
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX)
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1


def test_local_join_empty():
# Vector case
empty_vec = np.asarray([], dtype=config.floatX)
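Aside (not part of the diff): the Alloc assertions in the new test rely on `repeat` with a constant Python-int count being lowered via broadcasting and Alloc rather than a dedicated Repeat op, as the test comments note. A small sketch of checking that expectation directly, assuming the public `pytensor.function` API:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Alloc
from pytensor.tensor.extra_ops import repeat

x = pt.vector("x")
# The same graph the rewrite produces for join(0, x[None], x[None], x[None]).
out = repeat(x[None], 3, axis=0)

f = pytensor.function([x], out)
ops = f.maker.fgraph.toposort()
print([type(node.op).__name__ for node in ops])
print(any(isinstance(node.op, Alloc) for node in ops))  # expected: True
print(f(np.array([1.0, 2.0], dtype=pytensor.config.floatX)))
```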