@@ -182,16 +182,17 @@ def test_jax_split_not_supported(self):
182182 UserWarning , match = "Split node does not have constant split positions."
183183 ):
184184 fn = pytensor .function ([a ], a_splits , mode = "JAX" )
185- # It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it
186- with pytest .raises (AttributeError ):
185+ # This test used to raise AttributeError in previous versions of JAX.
186+ # Now it raises `TracerIntegerConversionError`.
187+ # We accept both errors for backwards compatibility.
188+ with pytest .raises ((AttributeError , errors .TracerIntegerConversionError )):
187189 fn (np .zeros ((6 , 4 ), dtype = pytensor .config .floatX ))
188190
189191 split_axis = iscalar ("split_axis" )
190192 a_splits = ptb .split (a , splits_size = [2 , 4 ], n_splits = 2 , axis = split_axis )
191193 with pytest .warns (UserWarning , match = "Split node does not have constant axis." ):
192194 fn = pytensor .function ([a , split_axis ], a_splits , mode = "JAX" )
193- # Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
194- # Both errors are included for backwards compatibility
195+ # Same reasoning as above to accept both errors.
195196 with pytest .raises ((AttributeError , errors .TracerIntegerConversionError )):
196197 fn (np .zeros ((6 , 6 ), dtype = pytensor .config .floatX ), 0 )
197198
0 commit comments