File tree Expand file tree Collapse file tree 2 files changed +15
-4
lines changed
pytensor/link/jax/dispatch Expand file tree Collapse file tree 2 files changed +15
-4
lines changed Original file line number Diff line number Diff line change 11from jax import numpy as jnp
22
33from pytensor .link .jax .dispatch import jax_funcify
4- from pytensor .tensor .sort import SortOp
4+ from pytensor .tensor .sort import ArgSortOp , SortOp
55
66
77@jax_funcify .register (SortOp )
@@ -12,3 +12,13 @@ def sort(arr, axis):
1212 return jnp .sort (arr , axis = axis , stable = stable )
1313
1414 return sort
15+
16+
17+ @jax_funcify .register (ArgSortOp )
18+ def jax_funcify_ArgSort (op , ** kwargs ):
19+ stable = op .kind == "stable"
20+
21+ def argsort (arr , axis ):
22+ return jnp .argsort (arr , axis = axis , stable = stable )
23+
24+ return argsort
Original file line number Diff line number Diff line change 33
44from pytensor .graph import FunctionGraph
55from pytensor .tensor import matrix
6- from pytensor .tensor .sort import sort
6+ from pytensor .tensor .sort import argsort , sort
77from tests .link .jax .test_basic import compare_jax_and_py
88
99
1010@pytest .mark .parametrize ("axis" , [None , - 1 ])
11- def test_sort (axis ):
11+ @pytest .mark .parametrize ("func" , (sort , argsort ))
12+ def test_sort (func , axis ):
1213 x = matrix ("x" , shape = (2 , 2 ), dtype = "float64" )
13- out = sort (x , axis = axis )
14+ out = func (x , axis = axis )
1415 fgraph = FunctionGraph ([x ], [out ])
1516 arr = np .array ([[1.0 , 4.0 ], [5.0 , 2.0 ]])
1617 compare_jax_and_py (fgraph , [arr ])
You can’t perform that action at this time.
0 commit comments