11import numpy as np
22
33from pytensor .gradient import grad_undefined
4- from pytensor .graph .basic import Apply , Constant
4+ from pytensor .graph .basic import Apply
55from pytensor .graph .op import Op
66from pytensor .misc .safe_asarray import _asarray
77from pytensor .tensor .basic import arange , as_tensor_variable , switch
8- from pytensor .tensor .math import eq , ge , mul
8+ from pytensor .tensor .math import eq , ge
99from pytensor .tensor .type import TensorType
1010
1111
12- def _variable_is_none (var ):
13- return isinstance (var , Constant ) and var .data is None
14-
15-
16- def _check_tensor_is_scalar (var ):
17- """
18- Checks if a tensor variable is scalar, raise ValueError otherwise
19- """
20- msg = "%(var)s is expected to be 0d tensor, got %(ndim)d"
21- if var .ndim != 0 :
22- raise ValueError (msg % (var , var .ndim ))
23-
24-
2512class SortOp (Op ):
2613 """
2714 This class is a wrapper for numpy sort function.
@@ -39,28 +26,16 @@ def __str__(self):
3926
def make_node(self, input, axis=-1):
    """Build the Apply node for a sort.

    The output variable has exactly the same type as `input`; `axis`
    is normalized to a 0-d integer tensor.
    """
    tensor = as_tensor_variable(input)
    axis = as_tensor_variable(axis, ndim=0, dtype=int)
    return Apply(self, [tensor, axis], [tensor.type()])
4532
def perform(self, node, inputs, output_storage):
    """Sort the input array in place of the output storage slot.

    `self.kind` and `self.order` are forwarded to ``np.sort``; the
    axis input is coerced to a plain int before the call.
    """
    arr, ax = inputs
    result = np.sort(arr, int(ax), self.kind, self.order)
    output_storage[0][0] = result
5537
def infer_shape(self, fgraph, node, inputs_shapes):
    """Sorting preserves the input shape; the axis input is scalar."""
    in_var = node.inputs[0]
    out_var = node.outputs[0]
    # Sorting along an axis never changes dimensionality.
    assert in_var.ndim == out_var.ndim
    # The axis input must be 0-d (enforced by make_node).
    assert inputs_shapes[1] == ()
    return [inputs_shapes[0]]
@@ -172,30 +147,22 @@ def __str__(self):
172147
def make_node(self, input, axis=-1):
    """Build the Apply node for an argsort.

    The output is an int64 tensor of indices with the same static
    shape as `input`; `axis` is normalized to a 0-d integer tensor.
    """
    tensor = as_tensor_variable(input)
    axis = as_tensor_variable(axis, ndim=0, dtype=int)
    out_var = TensorType(dtype="int64", shape=tensor.type.shape)()
    return Apply(self, [tensor, axis], [out_var])
181156
def perform(self, node, inputs, output_storage):
    """Compute np.argsort and cast to the declared output dtype."""
    arr, ax = inputs
    indices = np.argsort(arr, int(ax), self.kind, self.order)
    # Cast through the project helper so the stored result matches
    # the int64 dtype declared in make_node.
    output_storage[0][0] = _asarray(indices, dtype=node.outputs[0].dtype)
193164
def infer_shape(self, fgraph, node, inputs_shapes):
    """Argsort output shape equals the input shape; axis is scalar."""
    # Dimensionality is preserved by argsort along an axis.
    assert node.inputs[0].ndim == node.outputs[0].ndim
    # Scalar axis guaranteed by make_node.
    assert inputs_shapes[1] == ()
    return [inputs_shapes[0]]
@@ -239,66 +206,3 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
239206 a = a .flatten ()
240207 axis = 0
241208 return ArgSortOp (kind , order )(a , axis )
242-
243-
244- def _topk_py_impl (op , x , k , axis , idx_dtype ):
245- ndim = x .ndim
246- assert - ndim <= axis < ndim
247- axis %= ndim
248- if k == 0 :
249- raise ValueError ("topk: kth cannot be zero" )
250- elif k > x .shape [axis ]:
251- raise ValueError (
252- f"topk: kth cannot be larger than the size of specified axis { int (axis )} "
253- )
254- if abs (k ) == 1 :
255- # negative k means min instead of max
256- fn_max = [None , np .max , np .min ][k ]
257- fn_argmax = [None , np .argmax , np .argmin ][k ]
258- if not op .return_indices :
259- return np .expand_dims (fn_max (x , axis = axis ), axis )
260- elif op .return_values :
261- zi = np .expand_dims (fn_argmax (x , axis = axis ), axis )
262- idx2 = tuple (
263- np .arange (s ).reshape ((s ,) + (1 ,) * (ndim - i - 1 )) if i != axis else zi
264- for i , s in enumerate (x .shape )
265- )
266- zv = x [idx2 ]
267- return zv , zi .astype (idx_dtype )
268- else :
269- zi = np .expand_dims (fn_argmax (x , axis = axis ), axis )
270- return zi .astype (idx_dtype )
271-
272- if x .shape [axis ] == abs (k ):
273- if not op .return_indices :
274- return x .copy ()
275- else :
276- l = axis
277- r = ndim - l
278- reps = list (x .shape )
279- reps [axis ] = 1
280- zi = np .arange (abs (k ), dtype = idx_dtype )
281- zi = zi .reshape ((1 ,) * l + (k ,) + (1 ,) * (r - 1 ))
282- zi = np .tile (zi , reps )
283- if op .return_values :
284- return x .copy (), zi
285- else :
286- return zi
287-
288- idx = [slice (None )] * ndim
289- idx [axis ] = slice (- k , None ) if k > 0 else slice (- k )
290-
291- if not op .return_indices :
292- zv = np .partition (x , - k , axis = axis )[tuple (idx )]
293- return zv
294- elif op .return_values :
295- zi = np .argpartition (x , - k , axis = axis )[tuple (idx )]
296- idx2 = tuple (
297- np .arange (s ).reshape ((s ,) + (1 ,) * (ndim - i - 1 )) if i != axis else zi
298- for i , s in enumerate (x .shape )
299- )
300- zv = x [idx2 ]
301- return zv , zi .astype (idx_dtype )
302- else :
303- zi = np .argpartition (x , - k , axis = axis )[tuple (idx )]
304- return zi .astype (idx_dtype )