48 changes: 48 additions & 0 deletions pytensor/tensor/rewriting/basic.py
@@ -77,6 +77,7 @@
register_infer_shape,
switch,
tensor_copy,
tile,
zeros,
zeros_like,
)
@@ -910,6 +911,53 @@ def local_join_make_vector(fgraph, node):
return [ret]


@register_canonicalize
@node_rewriter([Join])
def local_join_to_repeat(fgraph, node):
"""Join(axis, x, x, x, ...) -> tile(x, reps)

When the same tensor is concatenated multiple times along an axis,
replace the join with a single tile operation, which is more efficient.

Examples
--------
join(0, x, x, x) -> tile(x, (3, 1, 1, ...))
join(1, x, x) -> tile(x, (1, 2, 1, ...))
"""
# Extract axis and the tensors being joined
axis, *tensors = node.inputs

# Optimization only applies when the join axis is a constant
if not isinstance(axis, Constant):
return None

# Need at least 2 tensors to consider the optimization
if len(tensors) <= 1:
return None

# Check that every joined input is literally the same variable
if not all(t is tensors[0] for t in tensors[1:]):
return None

n_reps = len(tensors)
first_tensor = tensors[0]
ndim = first_tensor.ndim

# Extract the Python integer from the constant and normalize a
# possibly negative axis against the number of dimensions
axis_val = int(axis.data)
if axis_val < 0:
axis_val += ndim

# Build reps tuple to repeat only along the join axis
# For shape (a, b, c) joining at axis 1: reps = (1, n_reps, 1)
# This directly concatenates n_reps copies along axis_val
reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))

result = tile(first_tensor, reps)

# Preserve debugging information
copy_stack_trace(node.outputs[0], result)
return [result]
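
The identity the rewrite relies on: concatenating k references to the same array along axis a is elementwise equal to tiling it with reps that are 1 everywhere except k at position a. A minimal NumPy sketch checking that identity (illustrative only; not part of the diff):

import numpy as np

x = np.arange(6.0).reshape(2, 3)
k, ax = 3, 1

# reps is 1 on every axis except the join axis, where it is k
reps = tuple(k if i == ax else 1 for i in range(x.ndim))

assert np.array_equal(np.concatenate([x] * k, axis=ax), np.tile(x, reps))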


@register_specialize
@register_canonicalize
@register_useless
119 changes: 112 additions & 7 deletions tests/tensor/rewriting/test_basic.py
@@ -1237,33 +1237,138 @@ def test_local_join_1():
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.fgraph.outputs[0].dtype == config.floatX

# test that join with 2 identical inputs now gets optimized to tile
s = join(1, a, a)
f = function([a], s, mode=rewrite_mode)
val = f([[1]])
assert np.all(val == [[1, 1]])  # joined along axis 1
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 0  # optimized away
assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_tile():
"""Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.

This optimization applies whenever we concatenate the *same* tensor multiple
times along a given axis (no need for size-1 dims / ExpandDims). It replaces
the Join/concatenate with a single Tile op.
"""

# Helpers to inspect the graph
def count_join_ops(ops):
return sum(1 for n in ops if isinstance(n.op, Join))

def has_no_join(fgraph_ops):
return count_join_ops(fgraph_ops) == 0

# ---- Case 1: vector expanded to (1, n), concat along axis 0 ----
x = vector("x")
x_expanded = x[None] # (1, n)
s = join(0, x_expanded, x_expanded, x_expanded) # (3, n)
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
result = f(test_val)
expected = np.array(
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX
)
assert np.allclose(result, expected)

ops = f.maker.fgraph.toposort()
assert has_no_join(ops)
# Note: Tile may be further optimized to Alloc, so we don't check for it

# ---- Case 2: matrix, concat along new leading axis ----
a = matrix("a") # (m, n)
a_expanded = a[None, :, :] # (1, m, n)
s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # (4, m, n)
f = function([a], s, mode=rewrite_mode)

test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
result = f(test_mat)
expected = np.array([test_mat, test_mat, test_mat, test_mat], dtype=config.floatX)
assert np.allclose(result, expected)

ops = f.maker.fgraph.toposort()
assert has_no_join(ops)

# ---- Case 3: matrix, expand along axis 1 then concat ----
a_expanded_ax1 = a[:, None, :] # (m, 1, n)
s = join(1, a_expanded_ax1, a_expanded_ax1) # (m, 2, n)
f = function([a], s, mode=rewrite_mode)

result = f(test_mat)
expected = np.array(
[[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]],
dtype=config.floatX,
)
assert np.allclose(result, expected)

ops = f.maker.fgraph.toposort()
assert has_no_join(ops)

# ---- Case 4: different tensors -> should NOT optimize ----
y = vector("y")
s = join(0, x[None], y[None]) # inputs differ
f = function([x, y], s, mode=rewrite_mode)

test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
result = f(test_vec1, test_vec2)
expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
assert np.allclose(result, expected)

ops = f.maker.fgraph.toposort()
# Join should still be present since inputs aren't identical
assert count_join_ops(ops) == 1

# ---- Case 5: plain concat without ExpandDims should now optimize ----
s = join(0, x, x, x) # (3n,)
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
assert np.allclose(result, expected)

ops = f.maker.fgraph.toposort()
assert has_no_join(ops)

# ---- Case 6: larger repetition count ----
s = join(0, x[None], x[None], x[None], x[None], x[None]) # (5, n)
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX)
assert np.allclose(result, expected)

ops = f.maker.fgraph.toposort()
assert has_no_join(ops)
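
For reference, a rough sketch of how the rewrite can be observed outside the test suite, assuming a PyTensor build that includes this branch; the compiled graph should contain Tile (or the Alloc it may be further rewritten into) rather than Join:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Join

x = pt.vector("x")
f = pytensor.function([x], pt.join(0, x, x, x))

# No Join node should remain after rewriting
assert not any(isinstance(node.op, Join) for node in f.maker.fgraph.toposort())
pytensor.dprint(f)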


def test_local_join_empty():
# Vector case - empty tensors should be removed
empty_vec = np.asarray([], dtype=config.floatX)
vec = vector("vec")
s = pt.join(0, vec, vec, empty_vec)
new_s = rewrite_graph(s)
assert new_s.dtype == s.dtype
# Compare to the expected form (also rewritten, since join itself gets optimized)
expected = rewrite_graph(pt.join(0, vec, vec))
assert equal_computations([new_s], [expected])

# Matrix case - empty tensors should be removed
empty_mat = np.zeros((2, 0), dtype=config.floatX)
empty_sym_mat = matrix("m", shape=(2, 0))
mat = matrix("mat", shape=(2, 10))
s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
new_s = rewrite_graph(s)
assert new_s.dtype == s.dtype
# Compare to the expected form (also rewritten, since join itself gets optimized)
expected = rewrite_graph(join(1, mat, mat, mat))
assert equal_computations([new_s], [expected])

# Join can be completely removed, but casting and specify_shape are propagated
int_mat = matrix("int_mat", dtype=int)