|
28 | 28 | from pytensor.scalar import upcast |
29 | 29 | from pytensor.tensor import TensorLike, as_tensor_variable |
30 | 30 | from pytensor.tensor import basic as ptb |
31 | | -from pytensor.tensor.basic import alloc, arange, join, second |
| 31 | +from pytensor.tensor.basic import alloc, join, second, split |
32 | 32 | from pytensor.tensor.exceptions import NotScalarConstantError |
33 | 33 | from pytensor.tensor.math import abs as pt_abs |
34 | 34 | from pytensor.tensor.math import all as pt_all |
|
47 | 47 | from pytensor.tensor.math import max as pt_max |
48 | 48 | from pytensor.tensor.math import sum as pt_sum |
49 | 49 | from pytensor.tensor.shape import Shape_i |
50 | | -from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor, take |
| 50 | +from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor |
51 | 51 | from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector |
52 | 52 | from pytensor.tensor.utils import normalize_reduce_axis |
53 | 53 | from pytensor.tensor.variable import TensorVariable |
@@ -2128,17 +2128,15 @@ def unpack( |
2128 | 2128 | if not packed_shapes: |
2129 | 2129 | raise ValueError("Cannot unpack an empty list of shapes.") |
2130 | 2130 |
|
2131 | | - start = 0 |
2132 | | - unpacked_tensors = [] |
2133 | | - for shape in packed_shapes: |
2134 | | - size = prod(shape, no_zeros_in_input=True) |
2135 | | - end = start + size |
2136 | | - unpacked_tensors.append( |
2137 | | - take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape) |
2138 | | - ) |
2139 | | - start = end |
| 2131 | + n_splits = len(packed_shapes) |
| 2132 | + split_size = [ |
| 2133 | + prod(shape, no_zeros_in_input=True).astype(int) for shape in packed_shapes |
| 2134 | + ] |
| 2135 | + unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits) |
2140 | 2136 |
|
2141 | | - return tuple(unpacked_tensors) |
| 2137 | + return tuple( |
| 2138 | + [x.reshape(shape) for x, shape in zip(unpacked_tensors, packed_shapes)] |
| 2139 | + ) |
2142 | 2140 |
|
2143 | 2141 |
|
2144 | 2142 | __all__ = [ |
|
0 commit comments