Skip to content

Commit 8e4d639

Browse files
Use split
1 parent cd5e00d commit 8e4d639

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pytensor.scalar import upcast
2929
from pytensor.tensor import TensorLike, as_tensor_variable
3030
from pytensor.tensor import basic as ptb
31-
from pytensor.tensor.basic import alloc, arange, join, second
31+
from pytensor.tensor.basic import alloc, join, second, split
3232
from pytensor.tensor.exceptions import NotScalarConstantError
3333
from pytensor.tensor.math import abs as pt_abs
3434
from pytensor.tensor.math import all as pt_all
@@ -47,7 +47,7 @@
4747
from pytensor.tensor.math import max as pt_max
4848
from pytensor.tensor.math import sum as pt_sum
4949
from pytensor.tensor.shape import Shape_i
50-
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor, take
50+
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
5151
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
5252
from pytensor.tensor.utils import normalize_reduce_axis
5353
from pytensor.tensor.variable import TensorVariable
@@ -2128,17 +2128,15 @@ def unpack(
21282128
if not packed_shapes:
21292129
raise ValueError("Cannot unpack an empty list of shapes.")
21302130

2131-
start = 0
2132-
unpacked_tensors = []
2133-
for shape in packed_shapes:
2134-
size = prod(shape, no_zeros_in_input=True)
2135-
end = start + size
2136-
unpacked_tensors.append(
2137-
take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape)
2138-
)
2139-
start = end
2131+
n_splits = len(packed_shapes)
2132+
split_size = [
2133+
prod(shape, no_zeros_in_input=True).astype(int) for shape in packed_shapes
2134+
]
2135+
unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits)
21402136

2141-
return tuple(unpacked_tensors)
2137+
return tuple(
2138+
[x.reshape(shape) for x, shape in zip(unpacked_tensors, packed_shapes)]
2139+
)
21422140

21432141

21442142
__all__ = [

0 commit comments

Comments
 (0)