
Commit 89daa43

Make DimShuffle a regular COp
A regular COp is much faster to create/compile and produces better code for the most ubiquitous Op in PyTensor.
1 parent: 1db1f92
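
For context (an illustration, not part of the commit): DimShuffle is the Op behind transposes, squeezes, and expand_dims in a PyTensor graph, so the cost of building and compiling it is paid constantly.

import pytensor.tensor as pt

x = pt.matrix("x")
y = x.dimshuffle(1, "x", 0)  # swap the two axes and insert a broadcastable axis
print(y.type.shape)          # (None, 1, None)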

File tree

3 files changed, +125 -155 lines changed


pytensor/link/c/op.py

Lines changed: 4 additions & 0 deletions

@@ -331,6 +331,10 @@ def __init__(
             files overriding sections in previous files.
 
         """
+        warnings.warn(
+            "ExternalCOp is deprecated and will be removed in a future release. Use regular COp instead.",
+            FutureWarning,
+        )
         if not isinstance(func_files, list):
            self.func_files = [Path(func_files)]
         else:

pytensor/tensor/c_code/dimshuffle.c

Lines changed: 0 additions & 86 deletions
This file was deleted.

pytensor/tensor/elemwise.py

Lines changed: 121 additions & 69 deletions
@@ -13,14 +13,13 @@
 from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
 from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.c.basic import failure_code
-from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
-from pytensor.link.c.params_type import ParamsType
+from pytensor.link.c.op import COp, OpenMPOp
 from pytensor.misc.frozendict import frozendict
 from pytensor.npy_2_compat import normalize_axis_tuple
 from pytensor.printing import Printer, pprint
 from pytensor.scalar import get_scalar_type
 from pytensor.scalar.basic import identity as scalar_identity
-from pytensor.scalar.basic import int64, transfer_type, upcast
+from pytensor.scalar.basic import transfer_type, upcast
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -29,7 +28,6 @@
     continuous_dtypes,
     discrete_dtypes,
     float_dtypes,
-    lvector,
 )
 from pytensor.tensor.utils import (
     broadcast_static_dim_lengths,
@@ -40,7 +38,7 @@
 from pytensor.utils import uniq
 
 
-class DimShuffle(ExternalCOp):
+class DimShuffle(COp):
     """
     Allows to reorder the dimensions of a tensor or insert or remove
     broadcastable dimensions.
@@ -114,74 +112,54 @@ class DimShuffle(ExternalCOp):
     _f16_ok = True
     check_input = False
     __props__ = ("input_ndim", "new_order")
-    c_func_file = "c_code/dimshuffle.c"
-    c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
     view_map = {0: [0]}
 
-    @property
-    def params_type(self):
-        return ParamsType(
-            _new_order=lvector,
-            input_ndim=int64,
-        )
-
     def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
-        super().__init__([self.c_func_file], self.c_func_name)
-
         if not isinstance(input_ndim, int):
             raise TypeError(f"input_ndim must be an integer, got {type(int)}")
 
         self.input_ndim = input_ndim
         self.new_order = tuple(new_order)
         self._new_order = [(-1 if x == "x" else x) for x in self.new_order]
 
-        for i, j in enumerate(new_order):
-            if j != "x":
-                if not isinstance(j, int | np.integer):
-                    raise TypeError(
-                        "DimShuffle indices must be Python ints; got "
-                        f"{j} of type {type(j)}."
-                    )
-                if j >= input_ndim:
-                    raise ValueError(
-                        f"new_order[{i}] is {j}, but the input only has "
-                        f"{input_ndim} axes."
-                    )
-                if j in new_order[(i + 1) :]:
-                    raise ValueError(
-                        "The same input dimension may not appear "
-                        f"twice in the list of output dimensions: {new_order}"
-                    )
-
         # List of input dimensions to drop
-        drop = [i for i in range(input_ndim) if i not in new_order]
+        self.drop = drop = [i for i in range(input_ndim) if i not in new_order]
 
         # This is the list of the original dimensions that we keep
-        self.shuffle = [x for x in new_order if x != "x"]
-        self.transposition = self.shuffle + drop
-        # List of dimensions of the output that are broadcastable and were not
-        # in the original input
-        self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x")
-        self.drop = drop
+        self.shuffle = shuffle = [x for x in new_order if x != "x"]
+
+        # Input validation
+        if not all(isinstance(x, int | np.integer) for x in shuffle):
+            raise TypeError(
+                "DimShuffle indices must be Python ints; got "
+                f"{shuffle} of type {[type(x) for x in shuffle]}."
+            )
+        if len(shuffle) != len(set(shuffle)):
+            raise ValueError(
+                f"Some dimensions were duplicated in new_order: {new_order}"
+            )
+        if max(shuffle, default=0) > input_ndim:
+            raise ValueError(
+                f"Some dimensions in new_order are too large for input_ndim {input_ndim}: {new_order}"
+            )
 
-        dims_are_shuffled = sorted(self.shuffle) != self.shuffle
+        self.transposition = self.shuffle + drop
+        # List of expand_dims positions
+        self.augment = augment = [i for i, x in enumerate(new_order) if x == "x"]
 
+        # Properties that are useful for rewrites
+        self.dims_are_shuffled = dims_are_shuffled = sorted(shuffle) != shuffle
         self.is_transpose = dims_are_shuffled and not augment and not drop
         self.is_squeeze = drop and not dims_are_shuffled and not augment
-        self.is_expand_dims = augment and not dims_are_shuffled and not drop
-        self.is_left_expand_dims = self.is_expand_dims and (
+        self.is_expand_dims = is_expand_dims = (
+            augment and not dims_are_shuffled and not drop
+        )
+        self.is_left_expand_dims = is_expand_dims and (
             input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
         )
-        self.is_right_expand_dims = self.is_expand_dims and new_order[
-            :input_ndim
-        ] == list(range(input_ndim))
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        if not hasattr(self, "func_files"):
-            # Perhaps we are loading an old `Op` version of DimShuffle.
-            # Let's just build the ExternalCOp.
-            super().__init__([self.c_func_file], self.c_func_name)
+        self.is_right_expand_dims = is_expand_dims and new_order[:input_ndim] == list(
+            range(input_ndim)
+        )
 
     def make_node(self, inp):
         input = as_tensor_variable(inp)
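
A sketch of how the reworked constructor behaves (class and attribute names are taken from the diff above; exact error messages may differ):

from pytensor.tensor.elemwise import DimShuffle

op = DimShuffle(input_ndim=2, new_order=(1, "x", 0))
print(op.shuffle, op.augment, op.drop)     # [1, 0] [1] []
print(op.is_transpose, op.is_expand_dims)  # False False (axes are both shuffled and augmented)

try:
    DimShuffle(input_ndim=2, new_order=(0, 0))  # the same input dimension used twice
except ValueError as err:
    print(err)  # Some dimensions were duplicated in new_order: (0, 0)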
@@ -193,22 +171,18 @@ def make_node(self, inp):
 
         input_static_shape = input.type.shape
 
-        # Runtime check for invalid drop
-        for d in self.drop:
-            if input_static_shape[d] not in (1, None):
-                raise TypeError(
-                    f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
-                )
-
-        out_static_shape = []
-        for dim_idx in self.new_order:
-            if dim_idx == "x":
-                out_static_shape.append(1)
-            else:
-                out_static_shape.append(input_static_shape[dim_idx])
-
-        output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
+        # Check for invalid drop
+        if self.drop:
+            for d in self.drop:
+                if input_static_shape[d] not in (1, None):
+                    raise TypeError(
+                        f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
+                    )
 
+        output = TensorType(
+            dtype=input.type.dtype,
+            shape=[1 if d == "x" else input_static_shape[d] for d in self.new_order],
+        )()
         return Apply(self, [input], [output])
 
     def __str__(self):
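
For illustration (a sketch, not part of the commit), the static output shape that the streamlined make_node above infers for a mixed squeeze/transpose/expand_dims:

import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, 1, None))
y = x.dimshuffle(2, "x", 0)  # drop the length-1 axis, reorder, add a broadcastable axis
print(y.type.shape)          # (None, 1, 3)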
@@ -273,6 +247,84 @@ def grad(self, inp, grads):
         else:
             return [gz.dimshuffle(grad_order)]
 
+    def c_code(self, node, name, input_names, output_names, sub):
+        [inp] = input_names
+        [out] = output_names
+        nd_in = node.inputs[0].ndim
+        nd_out = node.outputs[0].ndim
+        drop = self.drop
+        fail = sub["fail"]
+
+        code = f"npy_intp dimensions[{nd_out}];\n"
+        code += f"npy_intp strides[{nd_out}];\n"
+
+        code += dedent(
+            f"""
+            if (PyArray_NDIM({inp}) != {nd_in}) {{
+                PyErr_SetString(PyExc_ValueError, "ExpandDims: Input dimensions do not match expected.");
+                {fail}
+            }}
+            """
+        )
+
+        if drop:
+            code += "npy_intp new_size = 1;\n"
+        for i, o in enumerate(self.new_order):
+            if o == "x":
+                code += f"dimensions[{i}] = 1;\n"
+                code += f"strides[{i}] = PyArray_ITEMSIZE({inp});\n"
+            else:
+                code += f"dimensions[{i}] = PyArray_DIMS({inp})[{o}];\n"
+                code += f"strides[{i}] = PyArray_DIMS({inp})[{o}] == 1 ? PyArray_ITEMSIZE({inp}) : PyArray_STRIDES({inp})[{o}];\n"
+            if drop:
+                code += f"new_size *= dimensions[{i}];\n"
+
+        if drop:
+            code += dedent(
+                f"""
+                if (PyArray_SIZE({inp}) != new_size) {{
+                    PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one.");
+                    {fail}
+                }}
+                """
+            )
+
+        code += dedent(
+            f"""
+            Py_XDECREF({out});
+
+            Py_INCREF(PyArray_DESCR({inp}));
+            {out} = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
+                                                         PyArray_DESCR({inp}),
+                                                         {nd_out}, dimensions,
+                                                         strides,
+                                                         PyArray_DATA({inp}),
+                                                         (PyArray_FLAGS({inp}) & ~NPY_ARRAY_OWNDATA),
+                                                         NULL);
+
+            if ({out} == NULL) {{
+                {fail}
+            }}
+
+            // Declare it a view of the original input
+            Py_INCREF((PyObject*){inp});
+            PyArray_SetBaseObject({out}, (PyObject*){inp});
+            """
+        )
+
+        if self.dims_are_shuffled:
+            code += dedent(
+                f"""
+                // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
+                PyArray_UpdateFlags({out}, NPY_ARRAY_UPDATE_ALL);
+                """
+            )
+
+        return code
+
+    def c_code_cache_version(self):
+        return (0,)
+
 
 class DimShufflePrinter(Printer):
     def __p(self, new_order, pstate, r):
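
A minimal end-to-end check (a sketch, not part of the commit): with a C compiler available, compiling a graph in the default mode exercises the inline c_code above, and the result matches the equivalent NumPy transpose plus expand_dims.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
f = pytensor.function([x], x.dimshuffle(1, "x", 0))

xv = np.arange(6, dtype=pytensor.config.floatX).reshape(2, 3)
out = f(xv)
print(out.shape)  # (3, 1, 2)
np.testing.assert_array_equal(out, xv.T[:, None, :])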
