From 9568a833362573b3665ecb6d184952c68d90613e Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 10 Aug 2025 15:33:50 +0800 Subject: [PATCH 1/8] Implement pack/unpack helpers --- pytensor/tensor/extra_ops.py | 73 +++++++++++++++++++++++++++++++++- tests/tensor/test_extra_ops.py | 73 +++++++++++++++++++++++++++++++++- 2 files changed, 143 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 33a5a6b8dc..92cc310311 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -25,7 +25,7 @@ from pytensor.scalar import upcast from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb -from pytensor.tensor.basic import alloc, join, second +from pytensor.tensor.basic import alloc, arange, join, second from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.math import all as pt_all @@ -45,7 +45,7 @@ from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.shape import Shape_i from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor -from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes +from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.utils import normalize_reduce_axis from pytensor.tensor.variable import TensorVariable from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH @@ -2011,6 +2011,73 @@ def concat_with_broadcast(tensor_list, axis=0): return join(axis, *bcast_tensor_inputs) +def pack( + *tensors: TensorVariable, +) -> tuple[TensorVariable, list[tuple[TensorVariable]]]: + """ + Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector. + + Parameters + ---------- + tensors: TensorVariable + Tensors to be packed into a single vector. 
+ + Returns + ------- + flat_tensor: TensorVariable + A new symbolic variable representing the concatenated 1d vector of all tensor inputs + packed_shapes: list of tuples of TensorVariable + A list of tuples, where each tuple contains the symbolic shape of the original tensors. + """ + if not tensors: + raise ValueError("Cannot pack an empty list of tensors.") + + # Get the shapes of the input tensors + packed_shapes = [ + t.type.shape if not any(s is None for s in t.type.shape) else t.shape + for t in tensors + ] + + # Flatten each tensor and concatenate them into a single 1D vector + flat_tensor = join(0, *[t.ravel() for t in tensors]) + + return flat_tensor, packed_shapes + + +def unpack( + flat_tensor: TensorVariable, packed_shapes: list[tuple[TensorVariable | int]] +) -> tuple[TensorVariable, ...]: + """ + Unpack a flat tensor into its original shapes based on the provided packed shapes. + + Parameters + ---------- + flat_tensor: TensorVariable + A 1D tensor that contains the concatenated values of the original tensors. + packed_shapes: list of tuples of TensorVariable + A list of tuples, where each tuple contains the symbolic shape of the original tensors. + + Returns + ------- + unpacked_tensors: tuple of TensorVariable + A tuple containing the unpacked tensors with their original shapes. 
+ """ + if not packed_shapes: + raise ValueError("Cannot unpack an empty list of shapes.") + + start = 0 + unpacked_tensors = [] + for shape in packed_shapes: + size = prod(shape, no_zeros_in_input=True) + end = start + size + unpacked_tensors.append( + take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape) + ) + start = end + + return tuple(unpacked_tensors) + + __all__ = [ "bartlett", "bincount", @@ -2033,4 +2100,6 @@ def concat_with_broadcast(tensor_list, axis=0): "squeeze", "unique", "unravel_index", + "pack", + "unpack", ] diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 01de6cb517..4eca5b81d3 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -9,7 +9,7 @@ from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph import rewrite_graph -from pytensor.graph.basic import Constant, equal_computations +from pytensor.graph.basic import Constant, Variable, equal_computations from pytensor.graph.traversal import applys_between from pytensor.npy_2_compat import old_np_unique from pytensor.raise_op import Assert @@ -38,11 +38,13 @@ diff, fill_diagonal, fill_diagonal_offset, + pack, ravel_multi_index, repeat, searchsorted, squeeze, to_one_hot, + unpack, unravel_index, ) from pytensor.tensor.type import ( @@ -1387,3 +1389,72 @@ def test_concat_with_broadcast(): a = pt.tensor("a", shape=(1, 3, 5)) b = pt.tensor("b", shape=(3, 5)) pt.concat_with_broadcast([a, b], axis=1) + + +@pytest.mark.parametrize( + "shapes, expected_flat_shape", + [([(), (5,), (3, 3)], 15), ([(), (None,), (None, None)], None)], + ids=["static", "symbolic"], +) +def test_pack(shapes, expected_flat_shape): + rng = np.random.default_rng() + + x = pt.tensor("x", shape=shapes[0]) + y = pt.tensor("y", shape=shapes[1]) + z = pt.tensor("z", shape=shapes[2]) + + has_static_shape = [not any(s is None for s in shape) for shape in shapes] + + flat_packed, packed_shapes = pack(x, y, z) + 
+ assert flat_packed.type.shape[0] == expected_flat_shape + + for i, (packed_shape, has_static) in enumerate( + zip(packed_shapes, has_static_shape) + ): + if has_static: + assert packed_shape == shapes[i] + else: + assert isinstance(packed_shape, Variable) + + new_outputs = unpack(flat_packed, packed_shapes) + + assert len(new_outputs) == 3 + assert all( + out.type.shape == var.type.shape for out, var in zip(new_outputs, [x, y, z]) + ) + + fn = function([x, y, z], new_outputs, mode="FAST_COMPILE") + + input_vals = [ + rng.normal(size=shape).astype(config.floatX) + for var, shape in zip([x, y, z], [(), (5,), (3, 3)]) + ] + new_output_vals = fn(*input_vals) + for input, output in zip(input_vals, new_output_vals): + np.testing.assert_allclose(input, output) + + +def test_make_replacements_with_pack_unpack(): + rng = np.random.default_rng() + + x = pt.tensor("x", shape=()) + y = pt.tensor("y", shape=(5,)) + z = pt.tensor("z", shape=(3, 3)) + + loss = (x + y.sum() + z.sum()) ** 2 + + flat_packed, packed_shapes = pack(x, y, z) + new_input = flat_packed.type() + new_outputs = unpack(new_input, packed_shapes) + + loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs))) + fn = pytensor.function([new_input], loss, mode="FAST_COMPILE") + + input_vals = [ + rng.normal(size=(var.type.shape)).astype(config.floatX) for var in [x, y, z] + ] + flat_inputs = np.concatenate([input.ravel() for input in input_vals], axis=0) + output_val = fn(flat_inputs) + + assert np.allclose(output_val, sum([input.sum() for input in input_vals]) ** 2) From 2e22d34f077c1e1031111d7adc503234d2976422 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 19 Sep 2025 19:02:23 -0500 Subject: [PATCH 2/8] Use split --- pytensor/tensor/extra_ops.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 92cc310311..31ba4f0cb0 100644 --- a/pytensor/tensor/extra_ops.py +++ 
b/pytensor/tensor/extra_ops.py @@ -25,7 +25,7 @@ from pytensor.scalar import upcast from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb -from pytensor.tensor.basic import alloc, arange, join, second +from pytensor.tensor.basic import alloc, join, second, split from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.math import all as pt_all @@ -2065,17 +2065,15 @@ def unpack( if not packed_shapes: raise ValueError("Cannot unpack an empty list of shapes.") - start = 0 - unpacked_tensors = [] - for shape in packed_shapes: - size = prod(shape, no_zeros_in_input=True) - end = start + size - unpacked_tensors.append( - take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape) - ) - start = end + n_splits = len(packed_shapes) + split_size = [ + prod(shape, no_zeros_in_input=True).astype(int) for shape in packed_shapes + ] + unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits) - return tuple(unpacked_tensors) + return tuple( + [x.reshape(shape) for x, shape in zip(unpacked_tensors, packed_shapes)] + ) __all__ = [ From 79d966274634bebc7d4acf09f754b6253612ffe3 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 19 Sep 2025 19:03:30 -0500 Subject: [PATCH 3/8] Allow zero shapes --- pytensor/tensor/extra_ops.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 31ba4f0cb0..122078a478 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -2066,9 +2066,7 @@ def unpack( raise ValueError("Cannot unpack an empty list of shapes.") n_splits = len(packed_shapes) - split_size = [ - prod(shape, no_zeros_in_input=True).astype(int) for shape in packed_shapes - ] + split_size = [prod(shape).astype(int) for shape in packed_shapes] unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits) 
return tuple( From 58c0286666ef66c41ac81b45f28c0c4cb99a9719 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 20 Sep 2025 12:10:59 -0400 Subject: [PATCH 4/8] Remove unnecessary comments --- pytensor/tensor/extra_ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 122078a478..64710a7093 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -1,5 +1,5 @@ import warnings -from collections.abc import Collection, Iterable +from collections.abc import Collection, Iterable, Sequence from textwrap import dedent import numpy as np @@ -2012,7 +2012,7 @@ def concat_with_broadcast(tensor_list, axis=0): def pack( - *tensors: TensorVariable, + *tensors: TensorVariable, axes: int | Sequence[int] | None = None ) -> tuple[TensorVariable, list[tuple[TensorVariable]]]: """ Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector. @@ -2021,6 +2021,9 @@ def pack( ---------- tensors: TensorVariable Tensors to be packed into a single vector. + axes: int or sequence of int, optional + Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled + and joined. 
Returns ------- @@ -2032,13 +2035,11 @@ def pack( if not tensors: raise ValueError("Cannot pack an empty list of tensors.") - # Get the shapes of the input tensors packed_shapes = [ t.type.shape if not any(s is None for s in t.type.shape) else t.shape for t in tensors ] - # Flatten each tensor and concatenate them into a single 1D vector flat_tensor = join(0, *[t.ravel() for t in tensors]) return flat_tensor, packed_shapes From 57883338d2b91de807537a4d5e2ad8397d7e8d2b Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 4 Oct 2025 14:34:58 -0500 Subject: [PATCH 5/8] Feature complete Pack Op --- pytensor/tensor/extra_ops.py | 298 +++++++++++++++++++++++++++++++-- tests/tensor/test_extra_ops.py | 194 ++++++++++++++++----- 2 files changed, 444 insertions(+), 48 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 64710a7093..640ee6754d 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -1,5 +1,6 @@ import warnings from collections.abc import Collection, Iterable, Sequence +from itertools import pairwise from textwrap import dedent import numpy as np @@ -45,7 +46,7 @@ from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.shape import Shape_i from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor -from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector +from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes from pytensor.tensor.utils import normalize_reduce_axis from pytensor.tensor.variable import TensorVariable from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH @@ -2011,6 +2012,287 @@ def concat_with_broadcast(tensor_list, axis=0): return join(axis, *bcast_tensor_inputs) +class Pack(Op): + __props__ = ("axes",) + + def __init__(self, axes: int | Sequence[int] | None): + self.axes = tuple(axes) if isinstance(axes, list) else axes + + def _analyze_axes_list(self) -> tuple[int, int, int, int | 
None]: + """ + Analyze the provided axes list to determine how many axes are before and after the interval to be raveled, as + well as the minimum and maximum number of axes that the inputs can have. + + The rules are: + - Axes must be strictly increasing in both the positive and negative parts of the list. + - Negative axes must come after positive axes. + - There can be at most one "hole" in the axes list, which can be either an implicit hole on an endpoint + (e.g. [0, 1]) or an explicit hole in the middle (e.g. [0, 2] or [1, -1]). + + Returns + ------- + n_axes_before: int + The number of axes before the interval to be raveled. + n_axes_after: int + The number of axes after the interval to be raveled. + min_axes: int + The minimum number of axes that the inputs must have. + max_axes: int or None + The maximum number of axes that the inputs can have, or None if there is no strict maximum. A maximum is + only introduced when it would resolve ambiguities in the interpretation of the axes list. For example, + [2, 3] can be either interpreted as having two ravel intervals [:2] and [4:], which is illegal, + unless 3 is interpreted as -1, which is only possible if all inputs have exactly 4 axes. Likewise, + [-3, -1] can be interpreted as having two ravel intervals [:-3], [-3:], unless -3 is interpreted as 0, + which is only possible if all inputs have exactly 3 axes. 
+ """ + axes = self.axes + if axes is None: + return 0, 0, 0, None + + if isinstance(axes, int): + axes = [axes] + + if len(set(axes)) != len(axes): + raise ValueError("axes must have no duplicates") + if axes is not None and len(axes) == 0: + raise ValueError("axes=[] is ambiguous; use None to ravel all") + + first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes)) + positive_axes = list(axes[:first_negative_idx]) + negative_axes = list(axes[first_negative_idx:]) + + if not all(a < 0 for a in negative_axes): + raise ValueError("Negative axes must come after positive") + + def strictly_increasing(s): + return all(b > a for a, b in pairwise(s)) + + if (positive_axes and not strictly_increasing(positive_axes)) or ( + negative_axes and not strictly_increasing(negative_axes) + ): + raise ValueError("Axes must be strictly increasing") + + def find_gaps(s): + return [i for i, (a, b) in enumerate(pairwise(s)) if b - a > 1] + + pos_gaps = find_gaps(positive_axes) + neg_gaps = find_gaps(negative_axes) + positive_only = positive_axes and not negative_axes + negative_only = negative_axes and not positive_axes + mixed_case = positive_axes and negative_axes + + max_axes: int | None = None + + n_explicit_holes = len(pos_gaps) + len(neg_gaps) + if n_explicit_holes > 1: + raise ValueError( + "Too many holes in axes list. There can be at most one hole in the axes list, " + "including implict holes resulting from omitting the 0 or -1 axis." + ) + + if mixed_case: + if pos_gaps or neg_gaps: + raise ValueError( + "Too many holes in axes list. There can be at most one hole in the axes list, " + "including implict holes resulting from omitting the 0 or -1 axis. Because both " + "positive and negative axes are present, there is always assume to be an explit hole " + "between them." 
+ ) + n_before = len(positive_axes) + n_after = len(negative_axes) + min_axes = n_before + n_after + + if positive_only: + # There are four cases to consider when all axes are positive: + # 0. There are two implicit gaps (0 is not present) and an explicit gap (e.g. [2, 4]) + # This case is always illegal, as there is no interpretation that would result in having only one hole. + # 1. There is only an implicit right hole (e.g. [0, 1]) + # This case is legal, and requires no special interpretation. It corresponds to 'i j *' in einops + # 2. There is an explicit internal hole (e.g. [0, 2]) + # This case is legal, but requires interpreting the last axis as -1, which introduces a maximum number + # of axes. It corresponds to 'i * j' in einops, and requires at least one input to have 3 dimensions, and + # no input to have more than 3 dimensions. + # 3. The axes start at an index greater than 0, but have no internal holes (e.g. [2, 3]) + # This case is legal, but requires flipping the axes to negative indexing, so that the largest axis is + # -1, followed by -2, etc. This introduces a maximum number of axes. + if pos_gaps and positive_axes[0] != 0: + raise ValueError( + "Too many holes in axes list. There can be at most one hole in the axes list, " + "including implict holes resulting from omitting the 0 or -1 axis. In this case, " + "there is an explicit internal hole as well as an implicit left hole." + ) + + elif positive_axes[0] == 0 and not pos_gaps: + # Case 1: Only right implicit hole. No ambiguities. + n_before = positive_axes[-1] + 1 + n_after = 0 + min_axes = n_before + n_after + max_axes = None + + elif pos_gaps: + # Case 2: Explicit hole in the positives, plus right implicit hole. 
+ split = pos_gaps[0] + 1 + n_before = split + n_after = len(positive_axes) - split + min_axes = n_before + n_after + + # Close the right implicit hole + max_axes = positive_axes[-1] + 1 + + else: + # Case 3: Left and right implicit holes, but the right can be closed by flipping to negative axes and + # adding a maximum number of axes. + # Compute min_axes and max_axes under Case 1 of the negative_only scenario, with a max_axes constraint. + max_axes = positive_axes[-1] + 1 + n_before = 0 + n_after = len(positive_axes) + min_axes = n_before + n_after + + if negative_only: + # The same four cases are considered when all axes are negative, but ordering is reversed. + # 0. There are two implicit holes (-1 is not present) and an explicit hole (e.g. [-4, -2]) + # This case is always illegal, as there is no interpretation that would result in having only one hole + # in the axis list. + # 1. There is only an implicit left hole (e.g. [-2, -1]) + # This case is legal, and requires no special interpretation. It corresponds to '* i j' in einops + # 2. There is an explicit internal hole (e.g. [-3, -1]) + # This case is legal, but requires interpreting the smallest axis as 0, which introduces a maximum number + # of axes. It corresponds to 'i * j' in einops, and requires at least one input to have 3 dimensions, and + # no input to have more than 3 dimensions. + # 3. The axes end at an index less than -1, but have no internal holes (e.g. [-4, -3]). Flip to positive + # axes, adding a maximum number of axes. Interpret the smallest axis as 0 to resolve ambiguity. + if neg_gaps and negative_axes[-1] != -1: + raise ValueError( + "Too many holes in axes list. There can be at most one hole in the axes list, " + "including implict holes resulting from omitting the 0 or -1 axis. In this case, " + "there is an explicit internal hole as well as an implicit right hole." + ) + elif negative_axes[-1] == -1 and not neg_gaps: + # Case 1: No ambiguities, only left implicit hole. 
+ n_before = 0 + n_after = len(negative_axes) + min_axes = n_before + n_after + max_axes = None + elif neg_gaps: + # Case 2: Explicit hole in the negatives, plus left implicit hole. + split = neg_gaps[0] + 1 + n_before = split + n_after = len(negative_axes) - split + min_axes = n_before + n_after + + # Close the left implicit hole + max_axes = abs(min(negative_axes)) + else: + # Case 3: Left and right implicit holes, but the left can be closed by flipping to positive axes and + # adding a maximum number of axes. + max_axes = abs(negative_axes[0]) + n_before = negative_axes[-1] + max_axes + 1 + n_after = 0 + min_axes = n_before + n_after + + return n_before, n_after, min_axes, max_axes + + def make_node(self, *tensors: TensorVariable): + tensors = [ptb.as_tensor_variable(t) for t in tensors] + n_axes_before, n_axes_after, min_axes, max_axes = self._analyze_axes_list() + + if min([t.ndim for t in tensors]) < min_axes: + raise ValueError( + f"All input tensors to {self!s} must have at least {min_axes} dimensions, but the minimum " + f"number of dimensions found was {min([t.ndim for t in tensors])}." + ) + + max_ndim = max([t.ndim for t in tensors]) + if max_axes is not None and max_ndim > max_axes: + raise ValueError( + f"All input tensors to {self!s} must have at most {max_axes} dimensions, but the maximum " + f"number of dimensions found was {max_ndim}." 
+ ) + + def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None: + unique_shapes = {s for s in shapes if s is not None} + if not unique_shapes: + return None + if len(unique_shapes) > 1: + raise ValueError( + f"Input tensors to Pack op have incompatible sizes on dimension {axis} : {shapes}" + ) + return unique_shapes.pop() + + shapes_to_pack = [ + t.type.shape[n_axes_before : t.ndim - n_axes_after] for t in tensors + ] + packed_shape = ( + None + if any( + shape is None + for packed_shape in shapes_to_pack + for shape in packed_shape + ) + else int(sum(np.prod(shapes) for shapes in shapes_to_pack)) + ) + prefix_shapes = [ + _coalesce_dim([t.type.shape[i] for t in tensors], i) + for i in range(n_axes_before) + ] + suffix_shapes = [ + _coalesce_dim( + [t.type.shape[t.ndim - n_axes_after + i] for t in tensors], + n_axes_before + i, + ) + for i in range(n_axes_after) + ] + out_shape = (*prefix_shapes, packed_shape, *suffix_shapes) + + packed_output = ptb.tensor(dtype=tensors[0].dtype, shape=out_shape) + packed_shapes = [ + ptb.tensor(dtype="int64", shape=(len(shapes),)) for shapes in shapes_to_pack + ] + + return Apply(self, tensors, [packed_output, *packed_shapes]) + + def perform(self, node, inputs, outputs): + tensors = inputs + packed_output, *packed_shapes = outputs + + reshaped_tensors = [] + tmp_shapes = [] + + n_axes_before, n_axes_after, min_axes, max_axes = self._analyze_axes_list() + + if ( + max_axes is not None + and any(t.ndim > max_axes for t in tensors) + and not any(t.ndim == max_axes for t in tensors) + ): + raise ValueError( + f"All input tensors must have at most {max_axes} axes, and at least one input tensor must have exactly " + f"{max_axes} axes to resolve ambiguities in the interpretation of the axes list {self.axes}. A less" + f"ambiguous axes list can be used to avoid this restriction, usually by including 0 or -1 in the axes " + f"list." 
+ ) + + for i, tensor in enumerate(tensors): + shape = tensor.shape + ndim = tensor.ndim + if tensor.ndim < min_axes: + raise ValueError( + f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, " + f"while pattern {self.axes} assumes at least {min_axes} axes" + ) + axis_after_packed_axes = ndim - n_axes_after + tmp_shapes.append(shape[n_axes_before:axis_after_packed_axes]) + reshaped_tensors.append( + tensor.reshape( + (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:]) + ) + ) + + packed_output[0] = np.concatenate(reshaped_tensors, axis=n_axes_before) + for i, packed_shape in enumerate(tmp_shapes): + packed_shapes[i][0] = np.array(packed_shape).astype("int64") + + def pack( *tensors: TensorVariable, axes: int | Sequence[int] | None = None ) -> tuple[TensorVariable, list[tuple[TensorVariable]]]: @@ -2035,14 +2317,10 @@ def pack( if not tensors: raise ValueError("Cannot pack an empty list of tensors.") - packed_shapes = [ - t.type.shape if not any(s is None for s in t.type.shape) else t.shape - for t in tensors - ] - - flat_tensor = join(0, *[t.ravel() for t in tensors]) + pack_op = Pack(axes=axes) + packed_tensor, *packed_shapes = pack_op(*tensors) - return flat_tensor, packed_shapes + return packed_tensor, packed_shapes def unpack( @@ -2091,12 +2369,12 @@ def unpack( "geomspace", "linspace", "logspace", + "pack", "ravel_multi_index", "repeat", "searchsorted", "squeeze", "unique", - "unravel_index", - "pack", "unpack", + "unravel_index", ] diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 4eca5b81d3..aac897a596 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -9,7 +9,7 @@ from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph import rewrite_graph -from pytensor.graph.basic import Constant, Variable, equal_computations +from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.traversal import 
applys_between from pytensor.npy_2_compat import old_np_unique from pytensor.raise_op import Assert @@ -21,6 +21,7 @@ CumOp, FillDiagonal, FillDiagonalOffset, + Pack, RavelMultiIndex, Repeat, SearchsortedOp, @@ -1391,48 +1392,165 @@ def test_concat_with_broadcast(): pt.concat_with_broadcast([a, b], axis=1) -@pytest.mark.parametrize( - "shapes, expected_flat_shape", - [([(), (5,), (3, 3)], 15), ([(), (None,), (None, None)], None)], - ids=["static", "symbolic"], -) -def test_pack(shapes, expected_flat_shape): - rng = np.random.default_rng() - - x = pt.tensor("x", shape=shapes[0]) - y = pt.tensor("y", shape=shapes[1]) - z = pt.tensor("z", shape=shapes[2]) +class TestPack: + @pytest.mark.parametrize( + "axes, expected", + [ + ([0, 1], [2, 0, 2, None]), # 'i j *' + ([-1], [0, 1, 1, None]), # '* k' + ([0, 1, 3], [2, 1, 3, 4]), # 'i j * k' + ([-3, -1], [1, 1, 2, 3]), # 'i * j' + ([2, 3], [0, 2, 2, 4]), # '* i j' + ([-3, -2], [2, 0, 2, 3]), # 'i j *' + ([0, -1], [1, 1, 2, None]), # 'i * k' + ([0, 1, 2, -1], [3, 1, 4, None]), # 'i j k * l' + ([0, 1, 4], [2, 1, 3, 5]), + ([-4, -1], [1, 1, 2, 4]), + ], + ids=[ + "basic", + "keep_last", + "ravel_middle_implicit_end", + "implicit_start", + "ravel_start", + "implicit_end", + "mix_pos_neg", + "ravel_middle_explicit_end", + "pos_internal_bigger_gap", + "neg_internal_bigger_gap", + ], + ) + def test_analyze_axes_list_valid(self, axes, expected): + op = Pack(axes) + outputs = op._analyze_axes_list() + names = ["n_before", "n_after", "min_axes", "max_axes"] + for out, exp, name in zip(outputs, expected, names, strict=True): + assert out == exp, f"Expected {exp}, got {out} for {name}" + + def test_analyze_axes_list_invalid(self): + # Two explicit holes + op = Pack([0, 2, -1]) + with pytest.raises(ValueError, match="Too many holes"): + op._analyze_axes_list() + + # Explicit hole + two implicit holes + op = Pack([1, 3]) + with pytest.raises(ValueError, match="Too many holes"): + op._analyze_axes_list() + + # Two explicit holes, all 
positive + op = Pack([0, 2, 4]) + with pytest.raises(ValueError, match="Too many holes"): + op._analyze_axes_list() + + # Explicit hole + two implicit hole, all negative + op = Pack([-4, -2]) + with pytest.raises(ValueError, match="Too many holes"): + op._analyze_axes_list() + + # Two explicit holes + implicit hole, all negative + op = Pack([-5, -3, -1]) + with pytest.raises(ValueError, match="Too many holes"): + op._analyze_axes_list() + + # Duplicate axes + op = Pack([0, 0]) + with pytest.raises(ValueError, match="axes must have no duplicates"): + op._analyze_axes_list() + + # Not monotonic + op = Pack([0, 2, 1]) + with pytest.raises(ValueError, match="Axes must be strictly increasing"): + op._analyze_axes_list() + + # Negative before positive + op = Pack([-1, 0]) + with pytest.raises(ValueError, match="Negative axes must come after positive"): + op._analyze_axes_list() + + def test_pack_basic(self): + # rng = np.random.default_rng() + x = pt.tensor("x", shape=()) + y = pt.tensor("y", shape=(5,)) + z = pt.tensor("z", shape=(3, 3)) + + input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]} + + # Simple case, reduce all axes, equivalent to einops '*' + packed_tensor, packed_shapes = pack(x, y, z, axes=None) + assert packed_tensor.type.shape == (15,) + for tensor, packed_shape in zip([x, y, z], packed_shapes): + assert packed_shape.type.shape == (tensor.ndim,) + np.testing.assert_allclose(packed_shape.eval(input_dict), tensor.type.shape) + + # To preserve an axis, all inputs need at least one dimension, and the preserved axis has to agree. 
+ # x is scalar, so pack will raise: + with pytest.raises( + ValueError, + match=r"All input tensors to Pack{axes=0} must have at least 1 dimensions", + ): + pack(x, y, z, axes=0) - has_static_shape = [not any(s is None for s in shape) for shape in shapes] + # With valid x, pack should still raise, because the axis of concatenation doesn't agree across all inputs + x = pt.tensor("x", shape=(3,)) + with pytest.raises( + ValueError, + match=r"Input tensors to Pack op have incompatible sizes on dimension 0 : " + r"\[3, 5, 3\]", + ): + pack(x, y, z, axes=0) + + # Valid case, preserve first axis, equivalent to einops 'i *' + y = pt.tensor("y", shape=(3, 5)) + z = pt.tensor("z", shape=(3, 3, 3)) + packed_tensor, packed_shapes = pack(x, y, z, axes=0) + input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]} + assert packed_tensor.type.shape == (3, 15) + for tensor, packed_shape in zip([x, y, z], packed_shapes): + assert packed_shape.type.shape == (tensor.ndim - 1,) + np.testing.assert_allclose( + packed_shape.eval(input_dict), tensor.type.shape[1:] + ) + # More complex case, preserve last axis implicitly, equivalent to einops 'i * k'. 
This introduces a max + # dimension condition on the input shapes + x = pt.tensor("x", shape=(3, 2)) + y = pt.tensor("y", shape=(3, 5, 2)) + z = pt.tensor("z", shape=(3, 1, 7, 5, 2)) - flat_packed, packed_shapes = pack(x, y, z) + with pytest.raises( + ValueError, + match=r"All input tensors to Pack{axes=\(0, 3\)} must have at most 4 " + r"dimensions, but the maximum number of dimensions found was 5", + ): + pack(x, y, z, axes=[0, 3]) + + z = pt.tensor("z", shape=(3, 1, 7, 2)) + packed_tensor, packed_shapes = pack(x, y, z, axes=[0, 3]) + input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]} + assert packed_tensor.type.shape == (3, 13, 2) + for tensor, packed_shape in zip([x, y, z], packed_shapes): + assert packed_shape.type.shape == (tensor.ndim - 2,) + np.testing.assert_allclose( + packed_shape.eval(input_dict), tensor.type.shape[1:-1] + ) - assert flat_packed.type.shape[0] == expected_flat_shape + def test_pack_unpack_round_trip(self): + rng = np.random.default_rng() - for i, (packed_shape, has_static) in enumerate( - zip(packed_shapes, has_static_shape) - ): - if has_static: - assert packed_shape == shapes[i] - else: - assert isinstance(packed_shape, Variable) + x = pt.tensor("x", shape=(5,)) + y = pt.tensor("y", shape=(3, 3)) + z = pt.tensor("z", shape=()) - new_outputs = unpack(flat_packed, packed_shapes) + flat_packed, packed_shapes = pack(x, y, z, axes=None) + new_outputs = unpack(flat_packed, packed_shapes) - assert len(new_outputs) == 3 - assert all( - out.type.shape == var.type.shape for out, var in zip(new_outputs, [x, y, z]) - ) + fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE") - fn = function([x, y, z], new_outputs, mode="FAST_COMPILE") + input_vals = [rng.normal(size=var.type.shape) for var in [x, y, z]] + output_vals = fn(*input_vals) - input_vals = [ - rng.normal(size=shape).astype(config.floatX) - for var, shape in zip([x, y, z], [(), (5,), (3, 3)]) - ] - new_output_vals = fn(*input_vals) - for input, 
output in zip(input_vals, new_output_vals): - np.testing.assert_allclose(input, output) + for input_val, output_val in zip(input_vals, output_vals, strict=True): + np.testing.assert_allclose(input_val, output_val) def test_make_replacements_with_pack_unpack(): @@ -1444,17 +1562,17 @@ def test_make_replacements_with_pack_unpack(): loss = (x + y.sum() + z.sum()) ** 2 - flat_packed, packed_shapes = pack(x, y, z) + flat_packed, packed_shapes = pack(x, y, z, axes=None) new_input = flat_packed.type() new_outputs = unpack(new_input, packed_shapes) loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs))) - fn = pytensor.function([new_input], loss, mode="FAST_COMPILE") + fn = pytensor.function([new_input, x, y, z], loss, mode="FAST_COMPILE") input_vals = [ rng.normal(size=(var.type.shape)).astype(config.floatX) for var in [x, y, z] ] flat_inputs = np.concatenate([input.ravel() for input in input_vals], axis=0) - output_val = fn(flat_inputs) + output_val = fn(flat_inputs, *input_vals) assert np.allclose(output_val, sum([input.sum() for input in input_vals]) ** 2) From ed60651f599fad343f066f31256dd6b4a3294b40 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 2 Nov 2025 01:06:02 -0500 Subject: [PATCH 6/8] Implement Pack as OpFromGraph --- pytensor/tensor/extra_ops.py | 115 ++++++++++++++++----------------- tests/tensor/test_extra_ops.py | 38 +++++------ 2 files changed, 76 insertions(+), 77 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 640ee6754d..c5ae79635d 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -8,6 +8,7 @@ import pytensor import pytensor.scalar.basic as ps +from pytensor.compile.builders import OpFromGraph from pytensor.gradient import ( DisconnectedType, _float_zeros_like, @@ -44,7 +45,7 @@ ) from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import sum as pt_sum -from pytensor.tensor.shape import Shape_i +from pytensor.tensor.shape import 
Shape_i, specify_shape from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes from pytensor.tensor.utils import normalize_reduce_axis @@ -2012,11 +2013,10 @@ def concat_with_broadcast(tensor_list, axis=0): return join(axis, *bcast_tensor_inputs) -class Pack(Op): - __props__ = ("axes",) - +class PackHelper: def __init__(self, axes: int | Sequence[int] | None): self.axes = tuple(axes) if isinstance(axes, list) else axes + self.op_name = "Pack{axes=" + str(self.axes) + "}" def _analyze_axes_list(self) -> tuple[int, int, int, int | None]: """ @@ -2192,23 +2192,31 @@ def find_gaps(s): return n_before, n_after, min_axes, max_axes - def make_node(self, *tensors: TensorVariable): + def validate_inputs(self, tensors: list[TensorLike]): tensors = [ptb.as_tensor_variable(t) for t in tensors] - n_axes_before, n_axes_after, min_axes, max_axes = self._analyze_axes_list() + _, _, min_axes, max_axes = self._analyze_axes_list() if min([t.ndim for t in tensors]) < min_axes: raise ValueError( - f"All input tensors to {self!s} must have at least {min_axes} dimensions, but the minimum " + f"All input tensors to {self.op_name} must have at least {min_axes} dimensions, but the minimum " f"number of dimensions found was {min([t.ndim for t in tensors])}." ) max_ndim = max([t.ndim for t in tensors]) - if max_axes is not None and max_ndim > max_axes: + if ( + max_axes is not None + and max_ndim > max_axes + and not any(t.ndim == max_axes for t in tensors) + ): raise ValueError( - f"All input tensors to {self!s} must have at most {max_axes} dimensions, but the maximum " + f"All input tensors to {self.op_name} must have at most {max_axes} dimensions, but the maximum " f"number of dimensions found was {max_ndim}." 
) + def infer_shape(self, tensors: list[TensorLike]) -> tuple[int | None, ...]: + tensors = [ptb.as_tensor_variable(t) for t in tensors] + n_axes_before, n_axes_after, _, _ = self._analyze_axes_list() + def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None: unique_shapes = {s for s in shapes if s is not None} if not unique_shapes: @@ -2242,55 +2250,12 @@ def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None: ) for i in range(n_axes_after) ] - out_shape = (*prefix_shapes, packed_shape, *suffix_shapes) - - packed_output = ptb.tensor(dtype=tensors[0].dtype, shape=out_shape) - packed_shapes = [ - ptb.tensor(dtype="int64", shape=(len(shapes),)) for shapes in shapes_to_pack - ] - - return Apply(self, tensors, [packed_output, *packed_shapes]) - - def perform(self, node, inputs, outputs): - tensors = inputs - packed_output, *packed_shapes = outputs - - reshaped_tensors = [] - tmp_shapes = [] - n_axes_before, n_axes_after, min_axes, max_axes = self._analyze_axes_list() - - if ( - max_axes is not None - and any(t.ndim > max_axes for t in tensors) - and not any(t.ndim == max_axes for t in tensors) - ): - raise ValueError( - f"All input tensors must have at most {max_axes} axes, and at least one input tensor must have exactly " - f"{max_axes} axes to resolve ambiguities in the interpretation of the axes list {self.axes}. A less" - f"ambiguous axes list can be used to avoid this restriction, usually by including 0 or -1 in the axes " - f"list." 
- ) + return (*prefix_shapes, packed_shape, *suffix_shapes) - for i, tensor in enumerate(tensors): - shape = tensor.shape - ndim = tensor.ndim - if tensor.ndim < min_axes: - raise ValueError( - f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, " - f"while pattern {self.axes} assumes at least {min_axes} axes" - ) - axis_after_packed_axes = ndim - n_axes_after - tmp_shapes.append(shape[n_axes_before:axis_after_packed_axes]) - reshaped_tensors.append( - tensor.reshape( - (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:]) - ) - ) - packed_output[0] = np.concatenate(reshaped_tensors, axis=n_axes_before) - for i, packed_shape in enumerate(tmp_shapes): - packed_shapes[i][0] = np.array(packed_shape).astype("int64") +class Pack(OpFromGraph): + "Wrapper for the Pack Op" def pack( @@ -2317,10 +2282,44 @@ def pack( if not tensors: raise ValueError("Cannot pack an empty list of tensors.") - pack_op = Pack(axes=axes) - packed_tensor, *packed_shapes = pack_op(*tensors) + tensors = [ptb.as_tensor(tensor) for tensor in tensors] + + pack_helper = PackHelper(axes=axes) + + reshaped_tensors = [] + tmp_shapes = [] + + n_axes_before, n_axes_after, _, _ = pack_helper._analyze_axes_list() + pack_helper.validate_inputs(tensors) + output_shape = pack_helper.infer_shape(tensors) + + for i, tensor in enumerate(tensors): + shape = tensor.shape + ndim = tensor.ndim + axis_after_packed_axes = ndim - n_axes_after + tmp_shapes.append(shape[n_axes_before:axis_after_packed_axes]) + reshaped_tensors.append( + tensor.reshape( + (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:]) + ) + ) + + packed_output_tensor = specify_shape( + ptb.join(n_axes_before, *reshaped_tensors), output_shape + ) + packed_output_shapes = [ + ptb.as_tensor_variable(packed_shape).astype("int64") + for i, packed_shape in enumerate(tmp_shapes) + ] + + pack_op = Pack( + inputs=tensors, + outputs=[packed_output_tensor, *packed_output_shapes], + name="Pack{axes=" + str(axes) + "}", + ) - 
return packed_tensor, packed_shapes + outputs = pack_op(*tensors) + return outputs[0], outputs[1:] def unpack( diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index aac897a596..a8b151dd70 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -21,7 +21,7 @@ CumOp, FillDiagonal, FillDiagonalOffset, - Pack, + PackHelper, RavelMultiIndex, Repeat, SearchsortedOp, @@ -1421,52 +1421,52 @@ class TestPack: ], ) def test_analyze_axes_list_valid(self, axes, expected): - op = Pack(axes) - outputs = op._analyze_axes_list() + helper = PackHelper(axes) + outputs = helper._analyze_axes_list() names = ["n_before", "n_after", "min_axes", "max_axes"] for out, exp, name in zip(outputs, expected, names, strict=True): assert out == exp, f"Expected {exp}, got {out} for {name}" def test_analyze_axes_list_invalid(self): # Two explicit holes - op = Pack([0, 2, -1]) + helper = PackHelper([0, 2, -1]) with pytest.raises(ValueError, match="Too many holes"): - op._analyze_axes_list() + helper._analyze_axes_list() # Explict hole + two implicit holes - op = Pack([1, 3]) + helper = PackHelper([1, 3]) with pytest.raises(ValueError, match="Too many holes"): - op._analyze_axes_list() + helper._analyze_axes_list() # Two explicit holes, all positive - op = Pack([0, 2, 4]) + helper = PackHelper([0, 2, 4]) with pytest.raises(ValueError, match="Too many holes"): - op._analyze_axes_list() + helper._analyze_axes_list() # Explicit hole + two implicit hole, all negative - op = Pack([-4, -2]) + helper = PackHelper([-4, -2]) with pytest.raises(ValueError, match="Too many holes"): - op._analyze_axes_list() + helper._analyze_axes_list() # Two explicit holes + implicit hole, all negative - op = Pack([-5, -3, -1]) + helper = PackHelper([-5, -3, -1]) with pytest.raises(ValueError, match="Too many holes"): - op._analyze_axes_list() + helper._analyze_axes_list() # Duplicate axes - op = Pack([0, 0]) + helper = PackHelper([0, 0]) with pytest.raises(ValueError, 
match="axes must have no duplicates"): - op._analyze_axes_list() + helper._analyze_axes_list() # Not monotonic - op = Pack([0, 2, 1]) + helper = PackHelper([0, 2, 1]) with pytest.raises(ValueError, match="Axes must be strictly increasing"): - op._analyze_axes_list() + helper._analyze_axes_list() # Negative before positive - op = Pack([-1, 0]) + helper = PackHelper([-1, 0]) with pytest.raises(ValueError, match="Negative axes must come after positive"): - op._analyze_axes_list() + helper._analyze_axes_list() def test_pack_basic(self): # rng = np.random.default_rng() From 0b86851235f14792614311fcb184b036864f64d3 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 2 Nov 2025 01:08:12 -0500 Subject: [PATCH 7/8] float32 in tests --- tests/tensor/test_extra_ops.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index a8b151dd70..250142bf7f 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -1474,7 +1474,10 @@ def test_pack_basic(self): y = pt.tensor("y", shape=(5,)) z = pt.tensor("z", shape=(3, 3)) - input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]} + input_dict = { + variable: np.zeros(variable.type.shape, dtype=config.floatX) + for variable in [x, y, z] + } # Simple case, reduce all axes, equivalent to einops '*' packed_tensor, packed_shapes = pack(x, y, z, axes=None) @@ -1504,7 +1507,10 @@ def test_pack_basic(self): y = pt.tensor("y", shape=(3, 5)) z = pt.tensor("z", shape=(3, 3, 3)) packed_tensor, packed_shapes = pack(x, y, z, axes=0) - input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]} + input_dict = { + variable: np.zeros(variable.type.shape, dtype=config.floatX) + for variable in [x, y, z] + } assert packed_tensor.type.shape == (3, 15) for tensor, packed_shape in zip([x, y, z], packed_shapes): assert packed_shape.type.shape == (tensor.ndim - 1,) @@ -1526,7 +1532,10 
@@ def test_pack_basic(self): z = pt.tensor("z", shape=(3, 1, 7, 2)) packed_tensor, packed_shapes = pack(x, y, z, axes=[0, 3]) - input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]} + input_dict = { + variable: np.zeros(variable.type.shape, dtype=config.floatX) + for variable in [x, y, z] + } assert packed_tensor.type.shape == (3, 13, 2) for tensor, packed_shape in zip([x, y, z], packed_shapes): assert packed_shape.type.shape == (tensor.ndim - 2,) @@ -1546,7 +1555,9 @@ def test_pack_unpack_round_trip(self): fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE") - input_vals = [rng.normal(size=var.type.shape) for var in [x, y, z]] + input_vals = [ + rng.normal(size=var.type.shape).astype(config.floatX) for var in [x, y, z] + ] output_vals = fn(*input_vals) for input_val, output_val in zip(input_vals, output_vals, strict=True): From 20ab4e33214b08b46fe0df2b4fb21d37155aa0b3 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 2 Nov 2025 08:44:56 -0600 Subject: [PATCH 8/8] Docs --- pytensor/tensor/extra_ops.py | 103 ++++++++++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index c5ae79635d..07868d4fcd 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -2267,10 +2267,13 @@ def pack( Parameters ---------- tensors: TensorVariable - Tensors to be packed into a single vector. + Tensors to be packed. Tensors can have varying shapes and dimensions, but must have the same size along each + of the dimensions specified in the `axes` parameter. axes: int or sequence of int, optional - Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled - and joined. + Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating + on the new raveled dimension. If None, all axes will be raveled and joined. 
Axes can be either positive or + negative, but must be strictly increasing in both the positive and negative parts of the list. Negative axes + must come after positive axes. Returns ------- @@ -2278,6 +2281,99 @@ def pack( A new symbolic variable representing the concatenated 1d vector of all tensor inputs packed_shapes: list of tuples of TensorVariable A list of tuples, where each tuple contains the symbolic shape of the original tensors. + + Notes + ----- + This function is a helper for joining tensors of varying shapes into a single tensor. This is done by choosing a + list of axes to preserve, and raveling all other axes. The resulting tensors are then concatenated along the + raveled axis. The original shapes of the tensors are also returned, so that they can be unpacked later. + + The `axes` parameter determines which dimensions are *not* raveled. The requested axes must exist in all input + tensors, but there are otherwise no restrictions on the shapes or dimensions of the input tensors. For example, if + `axes=[0]`, then the first dimension of each tensor is preserved, and all other dimensions are raveled: + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(2, 3, 4)) + y = pt.tensor("y", shape=(2, 5)) + packed_output, shapes = pack(x, y, axes=0) + # packed_output will have shape (2, 3 * 4 + 5) = (2, 17) + + Since axes = 0, the first dimension of both `x` and `y` is preserved. This first example is equivalent to a simple + reshape and concat operation: + + .. code-block:: python + + x_reshaped = x.reshape(2, -1) # shape (2, 12) + y_reshaped = y.reshape(2, -1) # shape (2, 5) + packed_output = pt.concatenate( + [x_reshaped, y_reshaped], axis=1 + ) # shape (2, 17) + + `axes` can also be negative, in which case the axes are counted from the end of the tensor shape. For example, + if `axes=[-1]`, then the last dimension of each tensor is preserved, and all other dimensions are raveled: + + ..
code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(3, 4, 7)) + y = pt.tensor("y", shape=(6, 2, 1, 7)) + packed_output, shapes = pack(x, y, axes=-1) + # packed_output will have shape (3 * 4 + 6 * 2 * 1, 7) = (24, 7) + + The most important restriction of `axes` is that there can be at most one "hole" in the axes list. A hole is + defined as a missing axis in the sequence of axes. The easiest way to define a hole is by using both positive + and negative axes together. For example, `axes=[0, -1]` has a hole between the first and last axes. In this case, + the first and last dimensions of each tensor are preserved, and all other dimensions are raveled: + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(2, 3, 2, 3, 7)) + y = pt.tensor("y", shape=(2, 6, 7)) + packed_output, shapes = pack(x, y, axes=[0, -1]) + # packed_output will have shape (2, 3 * 2 * 3 + 6, 7) = (2, 24, 7) + + Multiple explicit holes are not allowed. For example, `axes = [0, 2, -1]` is illegal because there are two holes, + one between axes 0 and 2, and another between axes 2 and -1. + + Implicit holes are also possible when using only positive or only negative axes. `axes = [0]` already has an + implicit hole to the right of axis 0. `axes = [2, 3]` has two implicit holes, one to the left of axis 2, and another + to the right. This is illegal, since there are two holes. However, `axes = [2, 3]` can be made legal if we interpret + axis 3 as the last axis (-1), which closes the right implicit hole. The interpretation requires that at least one + input tensor has exactly 4 dimensions: + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(5, 2, 3, 4)) + y = pt.tensor("y", shape=(2, 3, 4)) + packed_output, shapes = pack(x, y, axes=[2, 3]) + # packed_output will have shape (5 * 2 + 2, 3, 4) = (12, 3, 4) + + Note here that `y` has only 3 dimensions, so axis 3 is interpreted as -1, the last axis. 
If no input has 4 + dimensions, or if any input has more than 4 dimensions, an error is raised in this case. + + Negative axes have similar rules regarding implicit holes. `axes = [-1]` has an implicit hole to the left of + axis -1. `axes = [-3, -2]` has two implicit holes. To arrive at a valid interpretation, we take -3 to be axis 0, + which closes the left implicit hole. This requires that at least one input tensor has exactly 3 dimensions: + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(2, 3, 4)) + y = pt.tensor("y", shape=(6, 4)) + packed_output, shapes = pack(x, y, axes=[-3, -2]) + # packed_output will have shape (2 + 6, 3, 4) = (8, 3, 4) + + Similarly to the previous example, if no input has 3 dimensions, or if any input has more than 3 dimensions, an + error would be raised in this example. """ if not tensors: raise ValueError("Cannot pack an empty list of tensors.") @@ -2316,6 +2412,7 @@ def pack( inputs=tensors, outputs=[packed_output_tensor, *packed_output_shapes], name="Pack{axes=" + str(axes) + "}", + inline=True, ) outputs = pack_op(*tensors)