55from pytensor .configdefaults import config
66from pytensor .graph .fg import FunctionGraph
77from pytensor .tensor import subtensor as pt_subtensor
8+ from pytensor .tensor import tensor
89from pytensor .tensor .rewriting .jax import (
910 boolean_indexing_set_or_inc ,
1011 boolean_indexing_sum ,
1314
1415
1516def test_jax_Subtensor_constant ():
17+ shape = (3 , 4 , 5 )
18+ x_pt = tensor ("x" , shape = shape , dtype = "int" )
19+ x_np = np .arange (np .prod (shape )).reshape (shape )
20+
1621 # Basic indices
17- x_pt = pt .as_tensor (np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )))
1822 out_pt = x_pt [1 , 2 , 0 ]
1923 assert isinstance (out_pt .owner .op , pt_subtensor .Subtensor )
20- out_fg = FunctionGraph ([], [out_pt ])
21- compare_jax_and_py (out_fg , [])
24+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
25+ compare_jax_and_py (out_fg , [x_np ])
2226
2327 out_pt = x_pt [1 :, 1 , :]
2428 assert isinstance (out_pt .owner .op , pt_subtensor .Subtensor )
25- out_fg = FunctionGraph ([], [out_pt ])
26- compare_jax_and_py (out_fg , [])
29+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
30+ compare_jax_and_py (out_fg , [x_np ])
2731
2832 out_pt = x_pt [:2 , 1 , :]
2933 assert isinstance (out_pt .owner .op , pt_subtensor .Subtensor )
30- out_fg = FunctionGraph ([], [out_pt ])
31- compare_jax_and_py (out_fg , [])
34+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
35+ compare_jax_and_py (out_fg , [x_np ])
3236
3337 out_pt = x_pt [1 :2 , 1 , :]
3438 assert isinstance (out_pt .owner .op , pt_subtensor .Subtensor )
35- out_fg = FunctionGraph ([], [out_pt ])
36- compare_jax_and_py (out_fg , [])
39+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
40+ compare_jax_and_py (out_fg , [x_np ])
3741
3842 # Advanced indexing
3943 out_pt = pt_subtensor .advanced_subtensor1 (x_pt , [1 , 2 ])
4044 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor1 )
41- out_fg = FunctionGraph ([], [out_pt ])
42- compare_jax_and_py (out_fg , [])
45+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
46+ compare_jax_and_py (out_fg , [x_np ])
4347
4448 out_pt = x_pt [[1 , 2 ], [2 , 3 ]]
4549 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
46- out_fg = FunctionGraph ([], [out_pt ])
47- compare_jax_and_py (out_fg , [])
50+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
51+ compare_jax_and_py (out_fg , [x_np ])
4852
4953 # Advanced and basic indexing
5054 out_pt = x_pt [[1 , 2 ], :]
5155 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
52- out_fg = FunctionGraph ([], [out_pt ])
53- compare_jax_and_py (out_fg , [])
56+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
57+ compare_jax_and_py (out_fg , [x_np ])
5458
5559 out_pt = x_pt [[1 , 2 ], :, [3 , 4 ]]
5660 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
57- out_fg = FunctionGraph ([], [out_pt ])
58- compare_jax_and_py (out_fg , [])
61+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
62+ compare_jax_and_py (out_fg , [x_np ])
5963
6064 # Flipping
6165 out_pt = x_pt [::- 1 ]
62- out_fg = FunctionGraph ([], [out_pt ])
63- compare_jax_and_py (out_fg , [])
66+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
67+ compare_jax_and_py (out_fg , [x_np ])
68+
69+ # Boolean indexing should work if indexes are constant
70+ out_pt = x_pt [np .random .binomial (1 , 0.5 , size = (3 , 4 , 5 ))]
71+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
72+ compare_jax_and_py (out_fg , [x_np ])
6473
6574
6675@pytest .mark .xfail (reason = "`a` should be specified as static when JIT-compiling" )
@@ -73,16 +82,18 @@ def test_jax_Subtensor_dynamic():
7382 compare_jax_and_py (out_fg , [1 ])
7483
7584
76- def test_jax_Subtensor_boolean_mask ():
77- """JAX does not support resizing arrays with boolean masks."""
85+ def test_jax_Subtensor_dynamic_boolean_mask ():
86+ """JAX does not support resizing arrays with dynamic boolean masks."""
87+ from jax .errors import NonConcreteBooleanIndexError
88+
7889 x_pt = pt .vector ("x" , dtype = "float64" )
7990 out_pt = x_pt [x_pt < 0 ]
8091 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
8192
8293 out_fg = FunctionGraph ([x_pt ], [out_pt ])
8394
8495 x_pt_test = np .arange (- 5 , 5 )
85- with pytest .raises (NotImplementedError , match = "resizing arrays with boolean" ):
96+ with pytest .raises (NonConcreteBooleanIndexError ):
8697 compare_jax_and_py (out_fg , [x_pt_test ])
8798
8899
0 commit comments