@@ -33,7 +33,7 @@ def squeeze(x, /, axis):
3333from ._array_module import _UndefinedStub
3434from ._array_module import mod as xp
3535from .stubs import array_methods , category_to_funcs , extension_to_funcs
36- from .typing import DataType
36+ from .typing import Array , DataType
3737
3838pytestmark = pytest .mark .ci
3939
@@ -112,7 +112,8 @@ def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]
112112
113113
114114matrixy_funcs : List [FunctionType ] = [
115- * category_to_funcs ["linear_algebra" ], * extension_to_funcs ["linalg" ]
115+ * category_to_funcs ["linear_algebra" ],
116+ * extension_to_funcs ["linalg" ],
116117]
117118matrixy_names : List [str ] = [f .__name__ for f in matrixy_funcs ]
118119matrixy_names += ["__matmul__" , "triu" , "tril" ]
@@ -121,7 +122,7 @@ def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]
121122@given (data = st .data ())
122123@settings (max_examples = 1 )
123124def _test_uninspectable_func (
124- func_name : str , func : Callable , stub_sig : Signature , data : DataObject
125+ func_name : str , func : Callable , stub_sig : Signature , array : Array , data : DataObject
125126):
126127 skip_msg = (
127128 f"Signature for { func_name } () is not inspectable "
@@ -153,12 +154,15 @@ def _test_uninspectable_func(
153154 value = data .draw (
154155 xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = param .name
155156 )
156- elif param .name == "x2" :
157- # sanity check
158- assert "x1" in [p .name for p in param_to_value .keys ()]
159- x1 = next (v for p , v in param_to_value .items () if p .name == "x1" )
157+ elif param .name in ["x2" , "other" ]:
158+ if param .name == "x2" :
159+ assert "x1" in [p .name for p in param_to_value .keys ()] # sanity check
160+ orig = next (v for p , v in param_to_value .items () if p .name == "x1" )
161+ else :
162+ assert array is not None # sanity check
163+ orig = array
160164 value = data .draw (
161- xps .arrays (dtype = x1 .dtype , shape = x1 .shape ), label = param .name
165+ xps .arrays (dtype = orig .dtype , shape = orig .shape ), label = param .name
162166 )
163167 else :
164168 pytest .skip (
@@ -177,11 +181,11 @@ def _test_uninspectable_func(
177181 func (* args , ** kwargs )
178182
179183
180- def _test_func_signature (
181- func : Callable , stub : FunctionType , ignore_first_stub_param : bool = False
182- ):
184+ def _test_func_signature (func : Callable , stub : FunctionType , array = None ):
183185 stub_sig = signature (stub )
184- if ignore_first_stub_param :
186+ # If testing against array, ignore 'self' arg in stub as it won't be present
187+ # in func (which should be an array method).
188+ if array is not None :
185189 stub_params = list (stub_sig .parameters .values ())
186190 del stub_params [0 ]
187191 stub_sig = Signature (
@@ -192,7 +196,7 @@ def _test_func_signature(
192196 sig = signature (func )
193197 _test_inspectable_func (sig , stub_sig )
194198 except ValueError :
195- _test_uninspectable_func (stub .__name__ , func , stub_sig )
199+ _test_uninspectable_func (stub .__name__ , func , stub_sig , array )
196200
197201
198202@pytest .mark .parametrize (
@@ -233,5 +237,4 @@ def test_array_method_signature(stub: FunctionType, data: DataObject):
233237 x = data .draw (xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = "x" )
234238 assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
235239 method = getattr (x , stub .__name__ )
236- # Ignore 'self' arg in stub, which won't be present in instantiated objects.
237- _test_func_signature (method , stub , ignore_first_stub_param = True )
240+ _test_func_signature (method , stub , array = x )
0 commit comments