99from types import ModuleType
1010from typing import ClassVar , cast
1111
12- from ._utils ._compat import array_namespace , is_jax_array , is_writeable_array
12+ from ._utils ._compat import (
13+ array_namespace ,
14+ is_dask_array ,
15+ is_jax_array ,
16+ is_writeable_array ,
17+ )
1318from ._utils ._typing import Array , Index
1419
1520
@@ -141,6 +146,25 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
141146 not explicitly covered by ``array-api-compat``, are not supported by update
142147 methods.
143148
149+ Boolean masks are supported on Dask and jitted JAX arrays exclusively
150+ when `idx` has the same shape as `x` and `y` is 0-dimensional.
151+ Note that this support is not available in JAX's native
152+ ``x.at[mask].set(y)``.
153+
154+ This pattern::
155+
156+ >>> mask = m(x)
157+ >>> x[mask] = f(x[mask])
158+
159+ Can't be replaced by `at`, as it won't work on Dask and JAX inside jax.jit::
160+
161+ >>> mask = m(x)
162+ >>> x = xpx.at(x, mask).set(f(x[mask]) # Crash on Dask and jax.jit
163+
164+ You should instead use::
165+
166+ >>> x = xp.where(m(x), f(x), x)
167+
144168 Examples
145169 --------
146170 Given either of these equivalent expressions::
@@ -189,6 +213,7 @@ def _op(
189213 self ,
190214 at_op : _AtOp ,
191215 in_place_op : Callable [[Array , Array | object ], Array ] | None ,
216+ out_of_place_op : Callable [[Array , Array ], Array ] | None ,
192217 y : Array | object ,
193218 / ,
194219 copy : bool | None ,
@@ -210,6 +235,16 @@ def _op(
210235
211236 x[idx] = y
212237
238+ out_of_place_op : Callable[[Array, Array], Array] | None
239+ Out-of-place operation to apply when idx is a boolean mask and the backend
240+ doesn't support in-place updates::
241+
242+ x = xp.where(idx, out_of_place_op(x, y), x)
243+
244+ If None::
245+
246+ x = xp.where(idx, y, x)
247+
213248 y : array or object
214249 Right-hand side of the operation.
215250 copy : bool or None
@@ -223,6 +258,7 @@ def _op(
223258 Updated `x`.
224259 """
225260 x , idx = self ._x , self ._idx
261+ xp = array_namespace (x , y ) if xp is None else xp
226262
227263 if idx is _undef :
228264 msg = (
@@ -247,15 +283,41 @@ def _op(
247283 else :
248284 writeable = is_writeable_array (x )
249285
286+ # JAX inside jax.jit and Dask don't support in-place updates with boolean
287+ # mask. However we can handle the common special case of 0-dimensional y
288+ # with where(idx, y, x) instead.
289+ if (
290+ (is_dask_array (idx ) or is_jax_array (idx ))
291+ and idx .dtype == xp .bool
292+ and idx .shape == x .shape
293+ ):
294+ y_xp = xp .asarray (y , dtype = x .dtype )
295+ if y_xp .ndim == 0 :
296+ if out_of_place_op :
297+ # FIXME: suppress inf warnings on dask with lazywhere
298+ out = xp .where (idx , out_of_place_op (x , y_xp ), x )
299+ # Undo int->float promotion on JAX after _AtOp.DIVIDE
300+ out = xp .astype (out , x .dtype , copy = False )
301+ else :
302+ out = xp .where (idx , y_xp , x )
303+
304+ if copy :
305+ return out
306+ x [()] = out
307+ return x
308+ # else: this will work on eager JAX and crash on jax.jit and Dask
309+
250310 if copy :
251311 if is_jax_array (x ):
252312 # Use JAX's at[]
253313 func = cast (Callable [[Array ], Array ], getattr (x .at [idx ], at_op .value ))
254- return func (y )
314+ out = func (y )
315+ # Undo int->float promotion on JAX after _AtOp.DIVIDE
316+ return xp .astype (out , x .dtype , copy = False )
317+
255318 # Emulate at[] behaviour for non-JAX arrays
256319 # with a copy followed by an update
257- if xp is None :
258- xp = array_namespace (x )
320+
259321 x = xp .asarray (x , copy = True )
260322 if writeable is False :
261323 # A copy of a read-only numpy array is writeable
@@ -283,7 +345,7 @@ def set(
283345 xp : ModuleType | None = None ,
284346 ) -> Array : # numpydoc ignore=PR01,RT01
285347 """Apply ``x[idx] = y`` and return the update array."""
286- return self ._op (_AtOp .SET , None , y , copy = copy , xp = xp )
348+ return self ._op (_AtOp .SET , None , None , y , copy = copy , xp = xp )
287349
288350 def add (
289351 self ,
@@ -297,7 +359,7 @@ def add(
297359 # Note for this and all other methods based on _iop:
298360 # operator.iadd and operator.add subtly differ in behaviour, as
299361 # only iadd will trigger exceptions when y has an incompatible dtype.
300- return self ._op (_AtOp .ADD , operator .iadd , y , copy = copy , xp = xp )
362+ return self ._op (_AtOp .ADD , operator .iadd , operator . add , y , copy = copy , xp = xp )
301363
302364 def subtract (
303365 self ,
@@ -307,7 +369,9 @@ def subtract(
307369 xp : ModuleType | None = None ,
308370 ) -> Array : # numpydoc ignore=PR01,RT01
309371 """Apply ``x[idx] -= y`` and return the updated array."""
310- return self ._op (_AtOp .SUBTRACT , operator .isub , y , copy = copy , xp = xp )
372+ return self ._op (
373+ _AtOp .SUBTRACT , operator .isub , operator .sub , y , copy = copy , xp = xp
374+ )
311375
312376 def multiply (
313377 self ,
@@ -317,7 +381,9 @@ def multiply(
317381 xp : ModuleType | None = None ,
318382 ) -> Array : # numpydoc ignore=PR01,RT01
319383 """Apply ``x[idx] *= y`` and return the updated array."""
320- return self ._op (_AtOp .MULTIPLY , operator .imul , y , copy = copy , xp = xp )
384+ return self ._op (
385+ _AtOp .MULTIPLY , operator .imul , operator .mul , y , copy = copy , xp = xp
386+ )
321387
322388 def divide (
323389 self ,
@@ -327,7 +393,9 @@ def divide(
327393 xp : ModuleType | None = None ,
328394 ) -> Array : # numpydoc ignore=PR01,RT01
329395 """Apply ``x[idx] /= y`` and return the updated array."""
330- return self ._op (_AtOp .DIVIDE , operator .itruediv , y , copy = copy , xp = xp )
396+ return self ._op (
397+ _AtOp .DIVIDE , operator .itruediv , operator .truediv , y , copy = copy , xp = xp
398+ )
331399
332400 def power (
333401 self ,
@@ -337,7 +405,7 @@ def power(
337405 xp : ModuleType | None = None ,
338406 ) -> Array : # numpydoc ignore=PR01,RT01
339407 """Apply ``x[idx] **= y`` and return the updated array."""
340- return self ._op (_AtOp .POWER , operator .ipow , y , copy = copy , xp = xp )
408+ return self ._op (_AtOp .POWER , operator .ipow , operator . pow , y , copy = copy , xp = xp )
341409
342410 def min (
343411 self ,
@@ -349,7 +417,7 @@ def min(
349417 """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
350418 xp = array_namespace (self ._x ) if xp is None else xp
351419 y = xp .asarray (y )
352- return self ._op (_AtOp .MIN , xp .minimum , y , copy = copy , xp = xp )
420+ return self ._op (_AtOp .MIN , xp .minimum , xp . minimum , y , copy = copy , xp = xp )
353421
354422 def max (
355423 self ,
@@ -361,4 +429,4 @@ def max(
361429 """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
362430 xp = array_namespace (self ._x ) if xp is None else xp
363431 y = xp .asarray (y )
364- return self ._op (_AtOp .MAX , xp .maximum , y , copy = copy , xp = xp )
432+ return self ._op (_AtOp .MAX , xp .maximum , xp . maximum , y , copy = copy , xp = xp )
0 commit comments