Commit 9568a83

Implement pack/unpack helpers
1 parent 1f9a67b commit 9568a83

2 files changed: +143 -3 lines changed

pytensor/tensor/extra_ops.py

Lines changed: 71 additions & 2 deletions
@@ -25,7 +25,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, join, second
+from pytensor.tensor.basic import alloc, arange, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -45,7 +45,7 @@
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
-from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
+from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import TensorVariable
 from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
@@ -2011,6 +2011,73 @@ def concat_with_broadcast(tensor_list, axis=0):
     return join(axis, *bcast_tensor_inputs)
 
 
+def pack(
+    *tensors: TensorVariable,
+) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
+    """
+    Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
+
+    Parameters
+    ----------
+    tensors: TensorVariable
+        Tensors to be packed into a single vector.
+
+    Returns
+    -------
+    flat_tensor: TensorVariable
+        A new symbolic variable representing the concatenated 1d vector of all tensor inputs
+    packed_shapes: list of tuples of TensorVariable
+        A list of tuples, where each tuple contains the symbolic shape of the original tensors.
+    """
+    if not tensors:
+        raise ValueError("Cannot pack an empty list of tensors.")
+
+    # Get the shapes of the input tensors
+    packed_shapes = [
+        t.type.shape if not any(s is None for s in t.type.shape) else t.shape
+        for t in tensors
+    ]
+
+    # Flatten each tensor and concatenate them into a single 1D vector
+    flat_tensor = join(0, *[t.ravel() for t in tensors])
+
+    return flat_tensor, packed_shapes
+
+
+def unpack(
+    flat_tensor: TensorVariable, packed_shapes: list[tuple[TensorVariable | int]]
+) -> tuple[TensorVariable, ...]:
+    """
+    Unpack a flat tensor into its original shapes based on the provided packed shapes.
+
+    Parameters
+    ----------
+    flat_tensor: TensorVariable
+        A 1D tensor that contains the concatenated values of the original tensors.
+    packed_shapes: list of tuples of TensorVariable
+        A list of tuples, where each tuple contains the symbolic shape of the original tensors.
+
+    Returns
+    -------
+    unpacked_tensors: tuple of TensorVariable
+        A tuple containing the unpacked tensors with their original shapes.
+    """
+    if not packed_shapes:
+        raise ValueError("Cannot unpack an empty list of shapes.")
+
+    start = 0
+    unpacked_tensors = []
+    for shape in packed_shapes:
+        size = prod(shape, no_zeros_in_input=True)
+        end = start + size
+        unpacked_tensors.append(
+            take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape)
+        )
+        start = end
+
+    return tuple(unpacked_tensors)
+
+
 __all__ = [
     "bartlett",
     "bincount",
@@ -2033,4 +2100,6 @@ def concat_with_broadcast(tensor_list, axis=0):
     "squeeze",
     "unique",
     "unravel_index",
+    "pack",
+    "unpack",
 ]
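
For orientation, a minimal usage sketch of the round trip the new helpers provide; the variable names here are illustrative, not part of the commit:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import pack, unpack

# Two inputs with different shapes are raveled and joined into one
# length-10 vector, alongside a record of the original shapes.
x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(4,))
flat, shapes = pack(x, y)

# unpack slices the flat vector back apart and reshapes each piece.
x2, y2 = unpack(flat, shapes)

fn = pytensor.function([x, y], [x2, y2])
out_x, out_y = fn(np.ones((2, 3), dtype=x.dtype), np.zeros(4, dtype=y.dtype))
# out_x and out_y equal the inputs: the round trip is the identity.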

tests/tensor/test_extra_ops.py

Lines changed: 72 additions & 1 deletion
@@ -9,7 +9,7 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph import rewrite_graph
-from pytensor.graph.basic import Constant, equal_computations
+from pytensor.graph.basic import Constant, Variable, equal_computations
 from pytensor.graph.traversal import applys_between
 from pytensor.npy_2_compat import old_np_unique
 from pytensor.raise_op import Assert
@@ -38,11 +38,13 @@
     diff,
     fill_diagonal,
     fill_diagonal_offset,
+    pack,
     ravel_multi_index,
     repeat,
     searchsorted,
     squeeze,
     to_one_hot,
+    unpack,
     unravel_index,
 )
 from pytensor.tensor.type import (
@@ -1387,3 +1389,72 @@ def test_concat_with_broadcast():
     a = pt.tensor("a", shape=(1, 3, 5))
     b = pt.tensor("b", shape=(3, 5))
     pt.concat_with_broadcast([a, b], axis=1)
+
+
+@pytest.mark.parametrize(
+    "shapes, expected_flat_shape",
+    [([(), (5,), (3, 3)], 15), ([(), (None,), (None, None)], None)],
+    ids=["static", "symbolic"],
+)
+def test_pack(shapes, expected_flat_shape):
+    rng = np.random.default_rng()
+
+    x = pt.tensor("x", shape=shapes[0])
+    y = pt.tensor("y", shape=shapes[1])
+    z = pt.tensor("z", shape=shapes[2])
+
+    has_static_shape = [not any(s is None for s in shape) for shape in shapes]
+
+    flat_packed, packed_shapes = pack(x, y, z)
+
+    assert flat_packed.type.shape[0] == expected_flat_shape
+
+    for i, (packed_shape, has_static) in enumerate(
+        zip(packed_shapes, has_static_shape)
+    ):
+        if has_static:
+            assert packed_shape == shapes[i]
+        else:
+            assert isinstance(packed_shape, Variable)
+
+    new_outputs = unpack(flat_packed, packed_shapes)
+
+    assert len(new_outputs) == 3
+    assert all(
+        out.type.shape == var.type.shape for out, var in zip(new_outputs, [x, y, z])
+    )
+
+    fn = function([x, y, z], new_outputs, mode="FAST_COMPILE")
+
+    input_vals = [
+        rng.normal(size=shape).astype(config.floatX)
+        for var, shape in zip([x, y, z], [(), (5,), (3, 3)])
+    ]
+    new_output_vals = fn(*input_vals)
+    for input, output in zip(input_vals, new_output_vals):
+        np.testing.assert_allclose(input, output)
+
+
+def test_make_replacements_with_pack_unpack():
+    rng = np.random.default_rng()
+
+    x = pt.tensor("x", shape=())
+    y = pt.tensor("y", shape=(5,))
+    z = pt.tensor("z", shape=(3, 3))
+
+    loss = (x + y.sum() + z.sum()) ** 2
+
+    flat_packed, packed_shapes = pack(x, y, z)
+    new_input = flat_packed.type()
+    new_outputs = unpack(new_input, packed_shapes)
+
+    loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
+    fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")
+
+    input_vals = [
+        rng.normal(size=(var.type.shape)).astype(config.floatX) for var in [x, y, z]
+    ]
+    flat_inputs = np.concatenate([input.ravel() for input in input_vals], axis=0)
+    output_val = fn(flat_inputs)
+
+    assert np.allclose(output_val, sum([input.sum() for input in input_vals]) ** 2)
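
The second test exercises the intended workflow: pack several inputs, rebuild them from a fresh flat input with unpack, and use graph_replace to re-express the loss as a function of a single 1-D vector, the calling convention expected by flat-parameter optimizers such as scipy.optimize.minimize. For reference, the same bookkeeping in plain NumPy, with hypothetical helper names:

import numpy as np

def np_pack(*arrays):
    # Record each shape, then concatenate the raveled arrays end to end.
    shapes = [a.shape for a in arrays]
    return np.concatenate([a.ravel() for a in arrays]), shapes

def np_unpack(flat, shapes):
    # Walk the flat vector, slicing out prod(shape) elements per array
    # and restoring the recorded shape.
    out, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(flat[start:start + size].reshape(shape))
        start += size
    return tuple(out)

flat, shapes = np_pack(np.zeros(()), np.ones(5), np.eye(3))
restored = np_unpack(flat, shapes)  # shapes (), (5,), (3, 3), as in the test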
