22
33import operator
44import warnings
5- from collections .abc import Callable
6- from typing import Any
75
8- if typing .TYPE_CHECKING :
9- from ._lib ._typing import Array , ModuleType
6+ # https://github.com/pylint-dev/pylint/issues/10112
7+ from collections .abc import Callable # pylint: disable=import-error
8+ from typing import ClassVar
109
1110from ._lib import _utils
1211from ._lib ._compat import (
1514 is_dask_array ,
1615 is_writeable_array ,
1716)
18- from ._lib ._typing import Array , ModuleType
17+ from ._lib ._typing import Array , Index , ModuleType , Untyped
1918
2019__all__ = [
2120 "at" ,
@@ -562,7 +561,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
562561_undef = object ()
563562
564563
565- class at : # noqa: N801
564+ class at : # pylint: disable=invalid-name
566565 """
567566 Update operations for read-only arrays.
568567
@@ -654,14 +653,14 @@ class at: # noqa: N801
654653 """
655654
656655 x : Array
657- idx : Any
658- __slots__ = ("idx" , "x" )
656+ idx : Index
657+ __slots__ : ClassVar [ tuple [ str , str ]] = ("idx" , "x" )
659658
660- def __init__ (self , x : Array , idx : Any = _undef , / ):
659+ def __init__ (self , x : Array , idx : Index = _undef , / ):
661660 self .x = x
662661 self .idx = idx
663662
664- def __getitem__ (self , idx : Any ) -> Any :
663+ def __getitem__ (self , idx : Index ) -> at :
665664 """Allow for the alternate syntax ``at(x)[start:stop:step]``,
666665 which looks prettier than ``at(x, slice(start, stop, step))``
667666 and feels more intuitive coming from the JAX documentation.
@@ -680,8 +679,8 @@ def _common(
680679 copy : bool | None = True ,
681680 xp : ModuleType | None = None ,
682681 _is_update : bool = True ,
683- ** kwargs : Any ,
684- ) -> tuple [Any , None ] | tuple [None , Array ]:
682+ ** kwargs : Untyped ,
683+ ) -> tuple [Untyped , None ] | tuple [None , Array ]:
685684 """Perform common prepocessing.
686685
687686 Returns
@@ -709,11 +708,11 @@ def _common(
709708 if not writeable :
710709 msg = "Cannot modify parameter in place"
711710 raise ValueError (msg )
712- elif copy is None :
711+ elif copy is None : # type: ignore[redundant-expr]
713712 writeable = is_writeable_array (x )
714713 copy = _is_update and not writeable
715714 else :
716- msg = f"Invalid value for copy: { copy !r} " # type: ignore[unreachable]
715+ msg = f"Invalid value for copy: { copy !r} " # type: ignore[unreachable] # pyright: ignore[reportUnreachable]
717716 raise ValueError (msg )
718717
719718 if copy :
@@ -744,7 +743,7 @@ def _common(
744743
745744 return None , x
746745
747- def get (self , ** kwargs : Any ) -> Any :
746+ def get (self , ** kwargs : Untyped ) -> Untyped :
748747 """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
749748 that the output is either a copy or a view; it also allows passing
750749 keyword arguments to the backend.
@@ -769,7 +768,7 @@ def get(self, **kwargs: Any) -> Any:
769768 assert x is not None
770769 return x [self .idx ]
771770
772- def set (self , y : Array , / , ** kwargs : Any ) -> Array :
771+ def set (self , y : Array , / , ** kwargs : Untyped ) -> Array :
773772 """Apply ``x[idx] = y`` and return the update array"""
774773 res , x = self ._common ("set" , y , ** kwargs )
775774 if res is not None :
@@ -784,7 +783,7 @@ def _iop(
784783 elwise_op : Callable [[Array , Array ], Array ],
785784 y : Array ,
786785 / ,
787- ** kwargs : Any ,
786+ ** kwargs : Untyped ,
788787 ) -> Array :
789788 """x[idx] += y or equivalent in-place operation on a subset of x
790789
@@ -802,33 +801,33 @@ def _iop(
802801 x [self .idx ] = elwise_op (x [self .idx ], y )
803802 return x
804803
805- def add (self , y : Array , / , ** kwargs : Any ) -> Array :
804+ def add (self , y : Array , / , ** kwargs : Untyped ) -> Array :
806805 """Apply ``x[idx] += y`` and return the updated array"""
807806 return self ._iop ("add" , operator .add , y , ** kwargs )
808807
809- def subtract (self , y : Array , / , ** kwargs : Any ) -> Array :
808+ def subtract (self , y : Array , / , ** kwargs : Untyped ) -> Array :
810809 """Apply ``x[idx] -= y`` and return the updated array"""
811810 return self ._iop ("subtract" , operator .sub , y , ** kwargs )
812811
813- def multiply (self , y : Array , / , ** kwargs : Any ) -> Array :
812+ def multiply (self , y : Array , / , ** kwargs : Untyped ) -> Array :
814813 """Apply ``x[idx] *= y`` and return the updated array"""
815814 return self ._iop ("multiply" , operator .mul , y , ** kwargs )
816815
817- def divide (self , y : Array , / , ** kwargs : Any ) -> Array :
816+ def divide (self , y : Array , / , ** kwargs : Untyped ) -> Array :
818817 """Apply ``x[idx] /= y`` and return the updated array"""
819818 return self ._iop ("divide" , operator .truediv , y , ** kwargs )
820819
821- def power (self , y : Array , / , ** kwargs : Any ) -> Array :
820+ def power (self , y : Array , / , ** kwargs : Untyped ) -> Array :
822821 """Apply ``x[idx] **= y`` and return the updated array"""
823822 return self ._iop ("power" , operator .pow , y , ** kwargs )
824823
825- def min (self , y : Array , / , ** kwargs : Any ) -> Array :
824+ def min (self , y : Array , / , ** kwargs : Untyped ) -> Array :
826825 """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
827826 xp = array_namespace (self .x )
828827 y = xp .asarray (y )
829828 return self ._iop ("min" , xp .minimum , y , ** kwargs )
830829
831- def max (self , y : Array , / , ** kwargs : Any ) -> Array :
830+ def max (self , y : Array , / , ** kwargs : Untyped ) -> Array :
832831 """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
833832 xp = array_namespace (self .x )
834833 y = xp .asarray (y )
0 commit comments