1+ import typing
2+
13import numpy as np
24
35from pytensor .gradient import grad_undefined
911from pytensor .tensor .type import TensorType
1012
1113
14+ KIND = typing .Literal ["quicksort" , "mergesort" , "heapsort" , "stable" ]
15+ KIND_VALUES = typing .get_args (KIND )
16+
17+
18+ def _parse_sort_args (kind : KIND | None , order , stable : bool | None ) -> KIND :
19+ if order is not None :
20+ raise ValueError ("The order argument is not applicable to PyTensor graphs" )
21+ if stable is not None and kind is not None :
22+ raise ValueError ("kind and stable cannot be set at the same time" )
23+ if stable :
24+ kind = "stable"
25+ elif kind is None :
26+ kind = "quicksort"
27+ if kind not in KIND_VALUES :
28+ raise ValueError (f"kind must be one of { KIND_VALUES } , got { kind } " )
29+ return kind
30+
31+
1232class SortOp (Op ):
1333 """
1434 This class is a wrapper for numpy sort function.
1535
1636 """
1737
18- __props__ = ("kind" , "order" )
38+ __props__ = ("kind" ,)
1939
20- def __init__ (self , kind , order = None ):
40+ def __init__ (self , kind : KIND ):
2141 self .kind = kind
22- self .order = order
23-
24- def __str__ (self ):
25- return self .__class__ .__name__ + f"{{{ self .kind } , { self .order } }}"
2642
2743 def make_node (self , input , axis = - 1 ):
2844 input = as_tensor_variable (input )
@@ -33,7 +49,7 @@ def make_node(self, input, axis=-1):
3349 def perform (self , node , inputs , output_storage ):
3450 a , axis = inputs
3551 z = output_storage [0 ]
36- z [0 ] = np .sort (a , int (axis ), self .kind , self . order )
52+ z [0 ] = np .sort (a , int (axis ), self .kind )
3753
3854 def infer_shape (self , fgraph , node , inputs_shapes ):
3955 assert node .inputs [0 ].ndim == node .outputs [0 ].ndim
@@ -75,9 +91,9 @@ def __get_argsort_indices(self, a, axis):
7591
7692 # The goal is to get gradient wrt input from gradient
7793 # wrt sort(input, axis)
78- idx = argsort (a , axis , kind = self .kind , order = self . order )
94+ idx = argsort (a , axis , kind = self .kind )
7995 # rev_idx is the reverse of previous argsort operation
80- rev_idx = argsort (idx , axis , kind = self .kind , order = self . order )
96+ rev_idx = argsort (idx , axis , kind = self .kind )
8197 indices = []
8298 axis_data = switch (ge (axis .data , 0 ), axis .data , a .ndim + axis .data )
8399 for i in range (a .ndim ):
@@ -101,7 +117,9 @@ def R_op(self, inputs, eval_points):
101117 """
102118
103119
104- def sort (a , axis = - 1 , kind = "quicksort" , order = None ):
120+ def sort (
121+ a , axis = - 1 , kind : KIND | None = None , order = None , * , stable : bool | None = None
122+ ):
105123 """
106124
107125 Parameters
@@ -111,23 +129,25 @@ def sort(a, axis=-1, kind="quicksort", order=None):
111129 axis: TensorVariable
112130 Axis along which to sort. If None, the array is flattened before
113131 sorting.
114- kind: {'quicksort', 'mergesort', 'heapsort'}, optional
115- Sorting algorithm. Default is 'quicksort'.
132+ kind: {'quicksort', 'mergesort', 'heapsort' 'stable' }, optional
133+ Sorting algorithm. Default is 'quicksort' unless stable is defined .
116134 order: list, optional
117- When `a` is a structured array, this argument specifies which
118- fields to compare first, second, and so on. This list does not
119- need to include all of the fields.
135+ For compatibility with numpy sort signature. Cannot be specified.
136+ stable: bool, optional
137+ Same as specifying kind = 'stable'. Cannot be specified at the same time as kind
120138
121139 Returns
122140 -------
123141 array
124142 A sorted copy of an array.
125143
126144 """
145+ kind = _parse_sort_args (kind , order , stable )
146+
127147 if axis is None :
128148 a = a .flatten ()
129149 axis = 0
130- return SortOp (kind , order )(a , axis )
150+ return SortOp (kind )(a , axis )
131151
132152
133153class ArgSortOp (Op ):
@@ -136,14 +156,10 @@ class ArgSortOp(Op):
136156
137157 """
138158
139- __props__ = ("kind" , "order" )
159+ __props__ = ("kind" ,)
140160
141- def __init__ (self , kind , order = None ):
161+ def __init__ (self , kind : KIND ):
142162 self .kind = kind
143- self .order = order
144-
145- def __str__ (self ):
146- return self .__class__ .__name__ + f"{{{ self .kind } , { self .order } }}"
147163
148164 def make_node (self , input , axis = - 1 ):
149165 input = as_tensor_variable (input )
@@ -158,7 +174,7 @@ def perform(self, node, inputs, output_storage):
158174 a , axis = inputs
159175 z = output_storage [0 ]
160176 z [0 ] = _asarray (
161- np .argsort (a , int (axis ), self .kind , self . order ),
177+ np .argsort (a , int (axis ), self .kind ),
162178 dtype = node .outputs [0 ].dtype ,
163179 )
164180
@@ -192,7 +208,9 @@ def R_op(self, inputs, eval_points):
192208 """
193209
194210
195- def argsort (a , axis = - 1 , kind = "quicksort" , order = None ):
211+ def argsort (
212+ a , axis = - 1 , kind : KIND | None = None , order = None , stable : bool | None = None
213+ ):
196214 """
197215 Returns the indices that would sort an array.
198216
@@ -202,7 +220,8 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
202220 order.
203221
204222 """
223+ kind = _parse_sort_args (kind , order , stable )
205224 if axis is None :
206225 a = a .flatten ()
207226 axis = 0
208- return ArgSortOp (kind , order )(a , axis )
227+ return ArgSortOp (kind )(a , axis )
0 commit comments