55from pytensor .configdefaults import config
66from pytensor .graph .fg import FunctionGraph
77from pytensor .tensor import subtensor as at_subtensor
8- from pytensor .tensor .rewriting .jax import boolean_indexing_sum
8+ from pytensor .tensor .rewriting .jax import (
9+ boolean_indexing_set_or_inc ,
10+ boolean_indexing_sum ,
11+ )
912from tests .link .jax .test_basic import compare_jax_and_py
1013
1114
@@ -216,7 +219,7 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
216219
217220 This test ensures that `AdvancedIncSubtensor` `Op`s with boolean indexing is
218221 replaced with an equivalent `Switch` `Op`, using the
219- `jax_boolean_indexing_set_of_inc ` rewrite.
222+ `boolean_indexing_set_of_inc ` rewrite.
220223
221224 JAX forces users to re-express this logic manually, so this is an
222225 improvement over its user interface.
@@ -237,3 +240,12 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
237240 assert isinstance (out_at .owner .op , at_subtensor .AdvancedIncSubtensor )
238241 out_fg = FunctionGraph ([x_at ], [out_at ])
239242 compare_jax_and_py (out_fg , [x_np ])
243+
244+
245+ def test_boolean_indexing_set_or_inc_not_applicable ():
246+ """Test that `boolean_indexing_set_or_inc` does not return an invalid replacement in cases where it doesn't apply."""
247+ x = at .vector ("x" )
248+ mask = at .as_tensor (x ) > 0
249+ out = at_subtensor .set_subtensor (x [mask ], [0 , 1 , 2 ])
250+ fg = FunctionGraph ([x ], [out ])
251+ assert boolean_indexing_set_or_inc .transform (fg , fg .outputs [0 ].owner ) is None
0 commit comments