77"""
88from __future__ import annotations
99
10+ import operator
1011from typing import TYPE_CHECKING
1112
1213if TYPE_CHECKING :
13- from typing import Optional , Union , Any
14+ from typing import Callable , Optional , Union , Any
1415 from ._typing import Array , Device
1516
1617import sys
@@ -811,9 +812,22 @@ def is_writeable_array(x):
811812 return False
812813 return True
813814
815+ def _parse_copy_param (x , copy : bool | None ) -> bool :
816+ """Preprocess and validate a copy parameter, in line with the same
817+ parameter in np.asarray(), np.astype(), etc.
818+ """
819+ if copy is None :
820+ return not is_writeable_array (x )
821+ if copy is False :
822+ if not is_writeable_array (x ):
823+ raise ValueError ("Cannot avoid modifying parameter in place" )
824+ elif copy is not True :
825+ raise ValueError (f"Invalid value for copy: { copy !r} " )
826+ return copy
827+
814828_undef = object ()
815829
816- def at ( x , idx = _undef , / ) :
830+ class at :
817831 """
818832 Update operations for read-only arrays.
819833
@@ -823,12 +837,22 @@ def at(x, idx=_undef, /):
823837 Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
824838 quietly ignored for backends that don't support them.
825839
840+ Additionally, this introduces support for the `copy` keyword for all backends:
841+
842+ None
843+ x *may* be modified in place if it is possible and beneficial
844+ for performance. You should not use x after calling this function.
845+ True
846+ Ensure that the inputs are not modified. This is the default.
847+ False
848+ Raise ValueError if a copy cannot be avoided.
849+
826850 Examples
827851 --------
828852 Given either of these equivalent expressions::
829853
830- x = at(x)[1].add(2)
831- x = at(x, 1).add(2)
854+ x = at(x)[1].add(2, copy=None )
855+ x = at(x, 1).add(2, copy=None )
832856
833857 If x is a JAX array, they are the same as::
834858
@@ -845,16 +869,17 @@ def at(x, idx=_undef, /):
845869
846870 Warning
847871 -------
848- You should always immediately overwrite the parameter array::
872+ When you use copy=None, you should always immediately overwrite
873+ the parameter array::
849874
850- x = at(x, 0).set(2)
875+ x = at(x, 0).set(2, copy=None )
851876
852877 The anti-pattern below must be avoided, as it will result in different behaviour
853878 on read-only versus writeable arrays:
854879
855880 x = xp.asarray([0, 0, 0])
856- y = at(x, 0).set(2)
857- z = at(x, 1).set(3)
881+ y = at(x, 0).set(2, copy=None )
882+ z = at(x, 1).set(3, copy=None )
858883
859884 In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
860885 whereas y == z == [2, 3, 0] when x is writeable!
@@ -863,18 +888,6 @@ def at(x, idx=_undef, /):
863888 --------
864889 https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
865890 """
866- if is_jax_array (x ):
867- return x .at
868- if is_numpy_array (x ) and not x .flags .writeable :
869- x = x .copy ()
870- return _InPlaceAt (x , idx )
871-
872- class _InPlaceAt :
873- """Helper of at().
874-
875- Trivially implement jax.numpy.ndarray.at for other backends.
876- x is updated in place.
877- """
878891 __slots__ = ("x" , "idx" )
879892
880893 def __init__ (self , x , idx = _undef ):
@@ -890,7 +903,16 @@ def __getitem__(self, idx):
890903 self .idx = idx
891904 return self
892905
893- def _check_args (self , mode = "promise_in_bounds" , ** kwargs ):
906+ def _common (self , at_op , y = _undef , mode : str = "promise_in_bounds" , ** kwargs ):
907+ """Validate kwargs and perform common prepocessing.
908+
909+ Returns
910+ -------
911+ If the operation can be resolved by at[],
912+ (return value, None)
913+ Otherwise,
914+ (None, preprocessed x)
915+ """
894916 if self .idx is _undef :
895917 raise TypeError (
896918 "Index has not been set.\n "
@@ -900,74 +922,136 @@ def _check_args(self, mode="promise_in_bounds", **kwargs):
900922 " at(x)[idx].set(value)\n "
901923 "(same for all other methods)."
902924 )
903- if mode != "promise_in_bounds" :
925+ if mode != "promise_in_bounds" and not is_jax_array ( self . x ) :
904926 xp = array_namespace (self .x )
905927 raise NotImplementedError (
906- f"mode='{ mode } ' is not supported for backend { xp .__name__ } "
928+ f"mode='{ mode !r} ' is not supported for backend { xp .__name__ } "
929+ )
930+
931+ copy = _parse_copy_param (self .x , copy )
932+
933+ if copy and is_jax_array (self .x ):
934+ # Use JAX's at[]
935+ at_ = self .x .at [self .idx ]
936+ args = (y , ) if y is not _undef else ()
937+ return getattr (at_ , at_op )(* args , mode = mode , ** kwargs ), None
938+
939+ # Emulate at[] behaviour for non-JAX arrays
940+ x = self .x .copy () if copy else self .x
941+ return None , x
942+
943+ def get (self , copy : bool | None = True , ** kwargs ):
944+ """Return x[idx]. In addition to plain __getitem__, this allows ensuring
945+ that the output is (not) a copy and kwargs are passed to the backend."""
946+ # Special case when xp=numpy and idx is a fancy index
947+ # If copy is not False, avoid an unnecessary double copy.
948+ # if copy is forced to False, raise.
949+ if (
950+ is_numpy_array (self .x )
951+ and (
952+ isinstance (self .idx , (list , tuple ))
953+ or (is_numpy_array (self .idx ) and self .idx .dtype .kind in "biu" )
907954 )
955+ ):
956+ if copy is True :
957+ copy = None
958+ elif copy is False :
959+ raise ValueError (
960+ "Indexing a numpy array with a fancy index always "
961+ "results in a copy"
962+ )
963+
964+ res , x = self ._common ("get" , copy = copy , ** kwargs )
965+ if res is not None :
966+ return res
967+ return x [self .idx ]
908968
909969 def set (self , y , / , ** kwargs ):
910- self ._check_args (** kwargs )
911- self .x [self .idx ] = y
912- return self .x
970+ """x[idx] = y"""
971+ res , x = self ._common ("set" , y , ** kwargs )
972+ if res is not None :
973+ return res
974+ x [self .idx ] = y
975+ return x
976+
977+ def apply (self , ufunc , / , ** kwargs ):
978+ """ufunc.at(x, idx)"""
979+ res , x = self ._common ("apply" , ufunc , ** kwargs )
980+ if res is not None :
981+ return res
982+ ufunc .at (x , self .idx )
983+ return x
984+
985+ def _iop (self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs ):
986+ """x[idx] += y or equivalent in-place operation on a subset of x
987+
988+ which is the same as saying
989+ x[idx] = x[idx] + y
990+ Note that this is not the same as
991+ operator.iadd(x[idx], y)
992+ Consider for example when x is a numpy array and idx is a fancy index, which
993+ triggers a deep copy on __getitem__.
994+ """
995+ res , x = self ._common (at_op , y , ** kwargs )
996+ if res is not None :
997+ return res
998+ x [self .idx ] = elwise_op (x [self .idx ], y )
999+ return x
9131000
9141001 def add (self , y , / , ** kwargs ):
915- self ._check_args (** kwargs )
916- self .x [self .idx ] += y
917- return self .x
918-
1002+ """x[idx] += y"""
1003+ return self ._iop ("add" , operator .add , y , ** kwargs )
1004+
9191005 def subtract (self , y , / , ** kwargs ):
920- self ._check_args (** kwargs )
921- self .x [self .idx ] -= y
922- return self .x
1006+ """x[idx] -= y"""
1007+ return self ._iop ("subtract" , operator .sub , y , ** kwargs )
9231008
9241009 def multiply (self , y , / , ** kwargs ):
925- self ._check_args (** kwargs )
926- self .x [self .idx ] *= y
927- return self .x
1010+ """x[idx] *= y"""
1011+ return self ._iop ("multiply" , operator .mul , y , ** kwargs )
9281012
9291013 def divide (self , y , / , ** kwargs ):
930- self ._check_args (** kwargs )
931- self .x [self .idx ] /= y
932- return self .x
933-
1014+ """x[idx] /= y"""
1015+ return self ._iop ("divide" , operator .truediv , y , ** kwargs )
1016+
9341017 def power (self , y , / , ** kwargs ):
935- self ._check_args (** kwargs )
936- self .x [self .idx ] **= y
937- return self .x
1018+ """x[idx] **= y"""
1019+ return self ._iop ("power" , operator .pow , y , ** kwargs )
9381020
9391021 def min (self , y , / , ** kwargs ):
940- self ._check_args (** kwargs )
941- xp = array_namespace (self .x , y )
942- self .x [self .idx ] = xp .minimum (self .x [self .idx ], y )
943- return self .x
1022+ """x[idx] = minimum(x[idx], y)"""
1023+ xp = array_namespace (self .x )
1024+ return self ._iop ("min" , xp .minimum , y , ** kwargs )
9441025
9451026 def max (self , y , / , ** kwargs ):
946- self ._check_args (** kwargs )
947- xp = array_namespace (self .x , y )
948- self .x [self .idx ] = xp .maximum (self .x [self .idx ], y )
949- return self .x
950-
951- def apply (self , ufunc , / , ** kwargs ):
952- self ._check_args (** kwargs )
953- ufunc .at (self .x , self .idx )
954- return self .x
955-
956- def get (self , ** kwargs ):
957- self ._check_args (** kwargs )
958- return self .x [self .idx ]
959-
960- def iwhere (condition , x , y , / ):
961- """Variant of xp.where(condition, x, y) which may or may not update
962- x in place, if it's possible and beneficial for performance.
1027+ """x[idx] = maximum(x[idx], y)"""
1028+ xp = array_namespace (self .x )
1029+ return self ._iop ("max" , xp .maximum , y , ** kwargs )
1030+
1031+ def where (condition , x , y , / , copy : bool | None = True ):
1032+ """Return elements from x when condition is True and from y when
1033+ it is False.
1034+
1035+ This is a wrapper around xp.where that adds the copy parameter:
1036+
1037+ None
1038+ x *may* be modified in place if it is possible and beneficial
1039+ for performance. You should not use x after calling this function.
1040+ True
1041+ Ensure that the inputs are not modified.
1042+ This is the default, in line with np.where.
1043+ False
1044+ Raise ValueError if a copy cannot be avoided.
9631045 """
1046+ copy = _parse_copy_param (x , copy )
9641047 xp = array_namespace (condition , x , y )
965- if is_writeable_array (x ):
1048+ if copy :
1049+ return xp .where (condition , x , y )
1050+ else :
9661051 condition , x , y = xp .broadcast_arrays (condition , x , y )
9671052 x [condition ] = y [condition ]
9681053 return x
969- else :
970- return xp .where (condition , x , y )
1054+
9711055
9721056__all__ = [
9731057 "array_namespace" ,
@@ -993,7 +1077,7 @@ def iwhere(condition, x, y, /):
9931077 "size" ,
9941078 "to_device" ,
9951079 "at" ,
996- "iwhere " ,
1080+ "where " ,
9971081]
9981082
9991083_all_ignore = ['sys' , 'math' , 'inspect' , 'warnings' ]
0 commit comments