1313from pytensor .graph .replace import _vectorize_node , _vectorize_not_needed
1414from pytensor .graph .utils import MethodNotDefined
1515from pytensor .link .c .basic import failure_code
16- from pytensor .link .c .op import COp , ExternalCOp , OpenMPOp
17- from pytensor .link .c .params_type import ParamsType
16+ from pytensor .link .c .op import COp , OpenMPOp
1817from pytensor .misc .frozendict import frozendict
1918from pytensor .npy_2_compat import normalize_axis_tuple
2019from pytensor .printing import Printer , pprint
2120from pytensor .scalar import get_scalar_type
2221from pytensor .scalar .basic import identity as scalar_identity
23- from pytensor .scalar .basic import int64 , transfer_type , upcast
22+ from pytensor .scalar .basic import transfer_type , upcast
2423from pytensor .tensor import elemwise_cgen as cgen
2524from pytensor .tensor import get_vector_length
2625from pytensor .tensor .basic import _get_vector_length , as_tensor_variable
2928 continuous_dtypes ,
3029 discrete_dtypes ,
3130 float_dtypes ,
32- lvector ,
3331)
3432from pytensor .tensor .utils import (
3533 broadcast_static_dim_lengths ,
4038from pytensor .utils import uniq
4139
4240
43- class DimShuffle (ExternalCOp ):
41+ class DimShuffle (COp ):
4442 """
4543 Allows to reorder the dimensions of a tensor or insert or remove
4644 broadcastable dimensions.
@@ -114,74 +112,54 @@ class DimShuffle(ExternalCOp):
114112 _f16_ok = True
115113 check_input = False
116114 __props__ = ("input_ndim" , "new_order" )
117- c_func_file = "c_code/dimshuffle.c"
118- c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
119115 view_map = {0 : [0 ]}
120116
121- @property
122- def params_type (self ):
123- return ParamsType (
124- _new_order = lvector ,
125- input_ndim = int64 ,
126- )
127-
128117 def __init__ (self , * , input_ndim : int , new_order : Sequence [int | Literal ["x" ]]):
129- super ().__init__ ([self .c_func_file ], self .c_func_name )
130-
131118 if not isinstance (input_ndim , int ):
132119 raise TypeError (f"input_ndim must be an integer, got { type (int )} " )
133120
134121 self .input_ndim = input_ndim
135122 self .new_order = tuple (new_order )
136123 self ._new_order = [(- 1 if x == "x" else x ) for x in self .new_order ]
137124
138- for i , j in enumerate (new_order ):
139- if j != "x" :
140- if not isinstance (j , int | np .integer ):
141- raise TypeError (
142- "DimShuffle indices must be Python ints; got "
143- f"{ j } of type { type (j )} ."
144- )
145- if j >= input_ndim :
146- raise ValueError (
147- f"new_order[{ i } ] is { j } , but the input only has "
148- f"{ input_ndim } axes."
149- )
150- if j in new_order [(i + 1 ) :]:
151- raise ValueError (
152- "The same input dimension may not appear "
153- f"twice in the list of output dimensions: { new_order } "
154- )
155-
156125 # List of input dimensions to drop
157- drop = [i for i in range (input_ndim ) if i not in new_order ]
126+ self . drop = drop = [i for i in range (input_ndim ) if i not in new_order ]
158127
159128 # This is the list of the original dimensions that we keep
160- self .shuffle = [x for x in new_order if x != "x" ]
161- self .transposition = self .shuffle + drop
162- # List of dimensions of the output that are broadcastable and were not
163- # in the original input
164- self .augment = augment = sorted (i for i , x in enumerate (new_order ) if x == "x" )
165- self .drop = drop
129+ self .shuffle = shuffle = [x for x in new_order if x != "x" ]
130+
131+ # Input validation
132+ if not all (isinstance (x , int | np .integer ) for x in shuffle ):
133+ raise TypeError (
134+ "DimShuffle indices must be Python ints; got "
135+ f"{ shuffle } of type { [type (x ) for x in shuffle ]} ."
136+ )
137+ if len (shuffle ) != len (set (shuffle )):
138+ raise ValueError (
139+ f"Some dimensions were duplicated in new_order: { new_order } "
140+ )
141+ if max (shuffle , default = 0 ) > input_ndim :
142+ raise ValueError (
143+ f"Some dimensions in new_order are too large for input_ndim { input_ndim } : { new_order } "
144+ )
166145
167- dims_are_shuffled = sorted (self .shuffle ) != self .shuffle
146+ self .transposition = self .shuffle + drop
147+ # List of expand_dims positions
148+ self .augment = augment = [i for i , x in enumerate (new_order ) if x == "x" ]
168149
150+ # Properties that are useful for rewrites
151+ self .dims_are_shuffled = dims_are_shuffled = sorted (shuffle ) != shuffle
169152 self .is_transpose = dims_are_shuffled and not augment and not drop
170153 self .is_squeeze = drop and not dims_are_shuffled and not augment
171- self .is_expand_dims = augment and not dims_are_shuffled and not drop
172- self .is_left_expand_dims = self .is_expand_dims and (
154+ self .is_expand_dims = is_expand_dims = (
155+ augment and not dims_are_shuffled and not drop
156+ )
157+ self .is_left_expand_dims = is_expand_dims and (
173158 input_ndim == 0 or new_order [- input_ndim :] == list (range (input_ndim ))
174159 )
175- self .is_right_expand_dims = self .is_expand_dims and new_order [
176- :input_ndim
177- ] == list (range (input_ndim ))
178-
179- def __setstate__ (self , state ):
180- self .__dict__ .update (state )
181- if not hasattr (self , "func_files" ):
182- # Perhaps we are loading an old `Op` version of DimShuffle.
183- # Let's just build the ExternalCOp.
184- super ().__init__ ([self .c_func_file ], self .c_func_name )
160+ self .is_right_expand_dims = is_expand_dims and new_order [:input_ndim ] == list (
161+ range (input_ndim )
162+ )
185163
186164 def make_node (self , inp ):
187165 input = as_tensor_variable (inp )
@@ -193,22 +171,18 @@ def make_node(self, inp):
193171
194172 input_static_shape = input .type .shape
195173
196- # Runtime check for invalid drop
197- for d in self .drop :
198- if input_static_shape [d ] not in (1 , None ):
199- raise TypeError (
200- f"Input dropped dimension { d } must have length 1 but has { input_static_shape [d ]} "
201- )
202-
203- out_static_shape = []
204- for dim_idx in self .new_order :
205- if dim_idx == "x" :
206- out_static_shape .append (1 )
207- else :
208- out_static_shape .append (input_static_shape [dim_idx ])
209-
210- output = TensorType (dtype = input .type .dtype , shape = out_static_shape )()
174+ # Check for invalid drop
175+ if self .drop :
176+ for d in self .drop :
177+ if input_static_shape [d ] not in (1 , None ):
178+ raise TypeError (
179+ f"Input dropped dimension { d } must have length 1 but has { input_static_shape [d ]} "
180+ )
211181
182+ output = TensorType (
183+ dtype = input .type .dtype ,
184+ shape = [1 if d == "x" else input_static_shape [d ] for d in self .new_order ],
185+ )()
212186 return Apply (self , [input ], [output ])
213187
214188 def __str__ (self ):
@@ -273,6 +247,84 @@ def grad(self, inp, grads):
273247 else :
274248 return [gz .dimshuffle (grad_order )]
275249
250+ def c_code (self , node , name , input_names , output_names , sub ):
251+ [inp ] = input_names
252+ [out ] = output_names
253+ nd_in = node .inputs [0 ].ndim
254+ nd_out = node .outputs [0 ].ndim
255+ drop = self .drop
256+ fail = sub ["fail" ]
257+
258+ code = f"npy_intp dimensions[{ nd_out } ];\n "
259+ code += f"npy_intp strides[{ nd_out } ];\n "
260+
261+ code += dedent (
262+ f"""
263+ if (PyArray_NDIM({ inp } ) != { nd_in } ) {{
264+ PyErr_SetString(PyExc_ValueError, "ExpandDims: Input dimensions do not match expected.");
265+ { fail }
266+ }}
267+ """
268+ )
269+
270+ if drop :
271+ code += "npy_intp new_size = 1;\n "
272+ for i , o in enumerate (self .new_order ):
273+ if o == "x" :
274+ code += f"dimensions[{ i } ] = 1;\n "
275+ code += f"strides[{ i } ] = PyArray_ITEMSIZE({ inp } );\n "
276+ else :
277+ code += f"dimensions[{ i } ] = PyArray_DIMS({ inp } )[{ o } ];\n "
278+ code += f"strides[{ i } ] = PyArray_DIMS({ inp } )[{ o } ] == 1 ? PyArray_ITEMSIZE({ inp } ) : PyArray_STRIDES({ inp } )[{ o } ];\n "
279+ if drop :
280+ code += f"new_size *= dimensions[{ i } ];\n "
281+
282+ if drop :
283+ code += dedent (
284+ f"""
285+ if (PyArray_SIZE({ inp } ) != new_size) {{
286+ PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one.");
287+ { fail }
288+ }}
289+ """
290+ )
291+
292+ code += dedent (
293+ f"""
294+ Py_XDECREF({ out } );
295+
296+ Py_INCREF(PyArray_DESCR({ inp } ));
297+ { out } = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
298+ PyArray_DESCR({ inp } ),
299+ { nd_out } , dimensions,
300+ strides,
301+ PyArray_DATA({ inp } ),
302+ (PyArray_FLAGS({ inp } ) & ~NPY_ARRAY_OWNDATA),
303+ NULL);
304+
305+ if ({ out } == NULL) {{
306+ { fail }
307+ }}
308+
309+ // Declare it a view of the original input
310+ Py_INCREF((PyObject*){ inp } );
311+ PyArray_SetBaseObject({ out } , (PyObject*){ inp } );
312+ """
313+ )
314+
315+ if self .dims_are_shuffled :
316+ code += dedent (
317+ f"""
318+ // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
319+ PyArray_UpdateFlags({ out } , NPY_ARRAY_UPDATE_ALL);
320+ """
321+ )
322+
323+ return code
324+
325+ def c_code_cache_version (self ):
326+ return (0 ,)
327+
276328
277329class DimShufflePrinter (Printer ):
278330 def __p (self , new_order , pstate , r ):
0 commit comments