@@ -185,22 +185,42 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
185185 raise ValueError (msg )
186186 return at (self ._x , idx )
187187
188- def _update_common (
188+ def _op (
189189 self ,
190190 at_op : _AtOp ,
191- y : Array ,
191+ in_place_op : Callable [[Array , Array | object ], Array ] | None ,
192+ y : Array | object ,
192193 / ,
193194 copy : bool | None ,
194195 xp : ModuleType | None ,
195- ) -> tuple [ Array , None ] | tuple [ None , Array ]: # numpydoc ignore=PR01
196+ ) -> Array :
196197 """
197- Perform common prepocessing to all update operations.
198+ Implement all update operations.
199+
200+ Parameters
201+ ----------
202+ at_op : _AtOp
203+ Method of JAX's Array.at[].
204+ in_place_op : Callable[[Array, Array | object], Array] | None
205+ In-place operation to apply on mutable backends::
206+
207+ x[idx] = in_place_op(x[idx], y)
208+
209+ If None::
210+
211+ x[idx] = y
212+
213+ y : array or object
214+ Right-hand side of the operation.
215+ copy : bool or None
216+ Whether to copy the input array. See the class docstring for details.
217+ xp : array_namespace or None
218+ The array namespace for the input array.
198219
199220 Returns
200221 -------
201- tuple
202- If the operation can be resolved by ``at[]``, ``(return value, None)``
203- Otherwise, ``(None, preprocessed x)``.
222+ Array
223+ Updated `x`.
204224 """
205225 x , idx = self ._x , self ._idx
206226
@@ -231,7 +251,7 @@ def _update_common(
231251 if is_jax_array (x ):
232252 # Use JAX's at[]
233253 func = cast (Callable [[Array ], Array ], getattr (x .at [idx ], at_op .value ))
234- return func (y ), None
254+ return func (y )
235255 # Emulate at[] behaviour for non-JAX arrays
236256 # with a copy followed by an update
237257 if xp is None :
@@ -249,52 +269,25 @@ def _update_common(
249269 msg = f"Can't update read-only array { x } "
250270 raise ValueError (msg )
251271
252- return None , x
272+ if in_place_op :
273+ x [self ._idx ] = in_place_op (x [self ._idx ], y )
274+ else : # set()
275+ x [self ._idx ] = y
276+ return x
253277
254278 def set (
255279 self ,
256- y : Array ,
280+ y : Array | object ,
257281 / ,
258282 copy : bool | None = None ,
259283 xp : ModuleType | None = None ,
260284 ) -> Array : # numpydoc ignore=PR01,RT01
261285 """Apply ``x[idx] = y`` and return the update array."""
262- res , x = self ._update_common (_AtOp .SET , y , copy = copy , xp = xp )
263- if res is not None :
264- return res
265- assert x is not None
266- x [self ._idx ] = y
267- return x
268-
269- def _iop (
270- self ,
271- at_op : _AtOp ,
272- elwise_op : Callable [[Array , Array ], Array ],
273- y : Array ,
274- / ,
275- copy : bool | None ,
276- xp : ModuleType | None ,
277- ) -> Array : # numpydoc ignore=PR01,RT01
278- """
279- ``x[idx] += y`` or equivalent in-place operation on a subset of x.
280-
281- which is the same as saying
282- x[idx] = x[idx] + y
283- Note that this is not the same as
284- operator.iadd(x[idx], y)
285- Consider for example when x is a numpy array and idx is a fancy index, which
286- triggers a deep copy on __getitem__.
287- """
288- res , x = self ._update_common (at_op , y , copy = copy , xp = xp )
289- if res is not None :
290- return res
291- assert x is not None
292- x [self ._idx ] = elwise_op (x [self ._idx ], y )
293- return x
286+ return self ._op (_AtOp .SET , None , y , copy = copy , xp = xp )
294287
295288 def add (
296289 self ,
297- y : Array ,
290+ y : Array | object ,
298291 / ,
299292 copy : bool | None = None ,
300293 xp : ModuleType | None = None ,
@@ -304,70 +297,68 @@ def add(
304297 # Note for this and all other methods based on _iop:
305298 # operator.iadd and operator.add subtly differ in behaviour, as
306299 # only iadd will trigger exceptions when y has an incompatible dtype.
307- return self ._iop (_AtOp .ADD , operator .iadd , y , copy = copy , xp = xp )
300+ return self ._op (_AtOp .ADD , operator .iadd , y , copy = copy , xp = xp )
308301
309302 def subtract (
310303 self ,
311- y : Array ,
304+ y : Array | object ,
312305 / ,
313306 copy : bool | None = None ,
314307 xp : ModuleType | None = None ,
315308 ) -> Array : # numpydoc ignore=PR01,RT01
316309 """Apply ``x[idx] -= y`` and return the updated array."""
317- return self ._iop (_AtOp .SUBTRACT , operator .isub , y , copy = copy , xp = xp )
310+ return self ._op (_AtOp .SUBTRACT , operator .isub , y , copy = copy , xp = xp )
318311
319312 def multiply (
320313 self ,
321- y : Array ,
314+ y : Array | object ,
322315 / ,
323316 copy : bool | None = None ,
324317 xp : ModuleType | None = None ,
325318 ) -> Array : # numpydoc ignore=PR01,RT01
326319 """Apply ``x[idx] *= y`` and return the updated array."""
327- return self ._iop (_AtOp .MULTIPLY , operator .imul , y , copy = copy , xp = xp )
320+ return self ._op (_AtOp .MULTIPLY , operator .imul , y , copy = copy , xp = xp )
328321
329322 def divide (
330323 self ,
331- y : Array ,
324+ y : Array | object ,
332325 / ,
333326 copy : bool | None = None ,
334327 xp : ModuleType | None = None ,
335328 ) -> Array : # numpydoc ignore=PR01,RT01
336329 """Apply ``x[idx] /= y`` and return the updated array."""
337- return self ._iop (_AtOp .DIVIDE , operator .itruediv , y , copy = copy , xp = xp )
330+ return self ._op (_AtOp .DIVIDE , operator .itruediv , y , copy = copy , xp = xp )
338331
339332 def power (
340333 self ,
341- y : Array ,
334+ y : Array | object ,
342335 / ,
343336 copy : bool | None = None ,
344337 xp : ModuleType | None = None ,
345338 ) -> Array : # numpydoc ignore=PR01,RT01
346339 """Apply ``x[idx] **= y`` and return the updated array."""
347- return self ._iop (_AtOp .POWER , operator .ipow , y , copy = copy , xp = xp )
340+ return self ._op (_AtOp .POWER , operator .ipow , y , copy = copy , xp = xp )
348341
349342 def min (
350343 self ,
351- y : Array ,
344+ y : Array | object ,
352345 / ,
353346 copy : bool | None = None ,
354347 xp : ModuleType | None = None ,
355348 ) -> Array : # numpydoc ignore=PR01,RT01
356349 """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
357- if xp is None :
358- xp = array_namespace (self ._x )
350+ xp = array_namespace (self ._x ) if xp is None else xp
359351 y = xp .asarray (y )
360- return self ._iop (_AtOp .MIN , xp .minimum , y , copy = copy , xp = xp )
352+ return self ._op (_AtOp .MIN , xp .minimum , y , copy = copy , xp = xp )
361353
362354 def max (
363355 self ,
364- y : Array ,
356+ y : Array | object ,
365357 / ,
366358 copy : bool | None = None ,
367359 xp : ModuleType | None = None ,
368360 ) -> Array : # numpydoc ignore=PR01,RT01
369361 """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
370- if xp is None :
371- xp = array_namespace (self ._x )
362+ xp = array_namespace (self ._x ) if xp is None else xp
372363 y = xp .asarray (y )
373- return self ._iop (_AtOp .MAX , xp .maximum , y , copy = copy , xp = xp )
364+ return self ._op (_AtOp .MAX , xp .maximum , y , copy = copy , xp = xp )
0 commit comments