44
55from functools import wraps
66from inspect import signature
7+ from typing import TYPE_CHECKING
78
8- def get_xp (xp ):
9+ __all__ = ["get_xp" ]
10+
11+ if TYPE_CHECKING :
12+ from collections .abc import Callable
13+ from types import ModuleType
14+ from typing import TypeVar
15+
16+ _T = TypeVar ("_T" )
17+
18+
19+ def get_xp (xp : "ModuleType" ) -> "Callable[[Callable[..., _T]], Callable[..., _T]]" :
920 """
1021 Decorator to automatically replace xp with the corresponding array module.
1122
@@ -22,14 +33,14 @@ def func(x, /, xp, kwarg=None):
2233
2334 """
2435
25- def inner (f ) :
36+ def inner (f : "Callable[..., _T]" , / ) -> "Callable[..., _T]" :
2637 @wraps (f )
27- def wrapped_f (* args , ** kwargs ) :
38+ def wrapped_f (* args : object , ** kwargs : object ) -> object :
2839 return f (* args , xp = xp , ** kwargs )
2940
3041 sig = signature (f )
3142 new_sig = sig .replace (
32- parameters = [sig . parameters [ i ] for i in sig .parameters if i != "xp" ]
43+ parameters = [par for i , par in sig .parameters . items () if i != "xp" ]
3344 )
3445
3546 if wrapped_f .__doc__ is None :
@@ -40,7 +51,7 @@ def wrapped_f(*args, **kwargs):
4051specification for more details.
4152
4253"""
43- wrapped_f .__signature__ = new_sig
44- return wrapped_f
54+ wrapped_f .__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
55+ return wrapped_f # pyright: ignore[reportReturnType]
4556
4657 return inner
0 commit comments