1+ """
2+ We're not interested in being 100% strict - instead we focus on areas which
3+ could affect interop, e.g. with
4+
5+ def add(x1, x2, /):
6+ ...
7+
8+ x1 and x2 don't need to be pos-only for the purposes of interoperability, but with
9+
10+ def squeeze(x, /, axis):
11+ ...
12+
13+ axis has to be pos-or-keyword to support both styles
14+
15+ >>> squeeze(x, 0)
16+ ...
17+ >>> squeeze(x, axis=0)
18+ ...
19+
20+ """
21+ from collections import defaultdict
122from inspect import Parameter , Signature , signature
23+ from itertools import chain
224from types import FunctionType
3- from typing import Callable , Dict
25+ from typing import Callable , DefaultDict , Dict , List
426
527import pytest
628from hypothesis import given
29+ from hypothesis import strategies as st
730
31+ from . import dtype_helpers as dh
832from . import hypothesis_helpers as hh
933from . import xps
34+ from ._array_module import _UndefinedStub
1035from ._array_module import mod as xp
1136from .stubs import array_methods , category_to_funcs , extension_to_funcs
37+ from .typing import DataType , Shape
1238
1339pytestmark = pytest .mark .ci
1440
41+
1542kind_to_str : Dict [Parameter , str ] = {
1643 Parameter .POSITIONAL_OR_KEYWORD : "normal argument" ,
1744 Parameter .POSITIONAL_ONLY : "pos-only argument" ,
2047 Parameter .VAR_KEYWORD : "star-kwargs (i.e. **kwargs) argument" ,
2148}
2249
50+ VAR_KINDS = (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD )
2351
24- def _test_signature (
25- func : Callable , stub : FunctionType , ignore_first_stub_param : bool = False
26- ):
27- """
28- Signature of function is correct enough to not affect interoperability
29-
30- We're not interested in being 100% strict - instead we focus on areas which
31- could affect interop, e.g. with
32-
33- def add(x1, x2, /):
34- ...
3552
36- x1 and x2 don't need to be pos-only for the purposes of interoperability, but with
37-
38- def squeeze(x, /, axis):
39- ...
40-
41- axis has to be pos-or-keyword to support both styles
42-
43- >>> squeeze(x, 0)
44- ...
45- >>> squeeze(x, axis=0)
46- ...
47-
48- """
49- try :
50- sig = signature (func )
51- except ValueError :
52- pytest .skip (
53- msg = f"type({ stub .__name__ } )={ type (func )} not supported by inspect.signature()"
54- )
53+ def _test_inspectable_func (sig : Signature , stub_sig : Signature ):
5554 params = list (sig .parameters .values ())
56-
57- stub_sig = signature (stub )
5855 stub_params = list (stub_sig .parameters .values ())
59- if ignore_first_stub_param :
60- stub_params = stub_params [1 :]
61- stub = Signature (
62- parameters = stub_params , return_annotation = stub_sig .return_annotation
63- )
64-
6556 # We're not interested if the array module has additional arguments, so we
6657 # only iterate through the arguments listed in the spec.
6758 for i , stub_param in enumerate (stub_params ):
68- assert (
69- len (params ) >= i + 1
70- ), f"Argument '{ stub_param .name } ' missing from signature"
71- param = params [i ]
59+ if sig is not None :
60+ assert (
61+ len (params ) >= i + 1
62+ ), f"Argument '{ stub_param .name } ' missing from signature"
63+ param = params [i ]
7264
7365 # We're not interested in the name if it isn't actually used
74- if stub_param .kind not in [
66+ if sig is not None and stub_param .kind not in [
7567 Parameter .POSITIONAL_ONLY ,
76- Parameter .VAR_POSITIONAL ,
77- Parameter .VAR_KEYWORD ,
68+ * VAR_KINDS ,
7869 ]:
7970 assert (
8071 param .name == stub_param .name
8172 ), f"Expected argument '{ param .name } ' to be named '{ stub_param .name } '"
8273
83- if (
84- stub_param .name in ["x" , "x1" , "x2" ]
85- and stub_param .kind != Parameter .POSITIONAL_ONLY
86- ):
87- pytest .skip (
88- f"faulty spec - argument { stub_param .name } should be a "
89- f"{ kind_to_str [Parameter .POSITIONAL_ONLY ]} "
90- )
91- f_kind = kind_to_str [param .kind ]
9274 f_stub_kind = kind_to_str [stub_param .kind ]
93- if stub_param .kind in [
94- Parameter . POSITIONAL_OR_KEYWORD ,
95- Parameter . VAR_POSITIONAL ,
96- Parameter . VAR_KEYWORD ,
97- ]:
98- assert (
99- param . kind == stub_param . kind
100- ), f" { param . name } is a { f_kind } , but should be a { f_stub_kind } "
75+ if stub_param .kind in [Parameter . POSITIONAL_OR_KEYWORD , * VAR_KINDS ]:
76+ if sig is not None :
77+ assert param . kind == stub_param . kind , (
78+ f" { param . name } is a { kind_to_str [ param . kind ] } , "
79+ f"but should be a { f_stub_kind } "
80+ )
81+ else :
82+ pass
10183 else :
10284 # TODO: allow for kw-only args to be out-of-order
103- assert param .kind in [stub_param .kind , Parameter .POSITIONAL_OR_KEYWORD ], (
104- f"{ param .name } is a { f_kind } , "
105- f"but should be a { f_stub_kind } "
106- f"(or at least a { kind_to_str [Parameter .POSITIONAL_OR_KEYWORD ]} )"
85+ if sig is not None :
86+ assert param .kind in [
87+ stub_param .kind ,
88+ Parameter .POSITIONAL_OR_KEYWORD ,
89+ ], (
90+ f"{ param .name } is a { kind_to_str [param .kind ]} , "
91+ f"but should be a { f_stub_kind } "
92+ f"(or at least a { kind_to_str [Parameter .POSITIONAL_OR_KEYWORD ]} )"
93+ )
94+ else :
95+ pass
96+
97+ def shapes (** kw ) -> st .SearchStrategy [Shape ]:
98+ if "min_side" not in kw .keys ():
99+ kw ["min_side" ] = 1
100+ return hh .shapes (** kw )
101+
102+
103+ matrixy_funcs : List [str ] = [
104+ f .__name__
105+ for f in chain (category_to_funcs ["linear_algebra" ], extension_to_funcs ["linalg" ])
106+ ]
107+ matrixy_funcs += ["__matmul__" , "triu" , "tril" ]
108+ func_to_shapes : DefaultDict [str , st .SearchStrategy [Shape ]] = defaultdict (
109+ shapes ,
110+ {
111+ ** {k : st .just (()) for k in ["__bool__" , "__int__" , "__index__" , "__float__" ]},
112+ "sort" : shapes (min_dims = 1 ), # for axis=-1,
113+ ** {k : shapes (min_dims = 2 ) for k in matrixy_funcs },
114+ # Override for some matrixy functions
115+ "cross" : shapes (min_side = 3 , max_side = 3 , min_dims = 3 , max_dims = 3 ),
116+ "outer" : shapes (min_dims = 1 , max_dims = 1 ),
117+ },
118+ )
119+
120+
121+ def get_dtypes_strategy (func_name : str ) -> st .SearchStrategy [DataType ]:
122+ if func_name in dh .func_in_dtypes .keys ():
123+ dtypes = dh .func_in_dtypes [func_name ]
124+ if hh .FILTER_UNDEFINED_DTYPES :
125+ dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
126+ return st .sampled_from (dtypes )
127+ else :
128+ return xps .scalar_dtypes ()
129+
130+
131+ @given (data = st .data ())
132+ def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature , data ):
133+ if func_name in ["cholesky" , "inv" ]:
134+ func (xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]]))
135+ return
136+ elif func_name == "solve" :
137+ func (xp .asarray ([[1.0 , 2.0 ], [3.0 , 5.0 ]]), xp .asarray ([1.0 , 2.0 ]))
138+ return
139+
140+ pos_argname_to_example_value = {}
141+ normal_argname_to_example_value = {}
142+ kw_argname_to_example_value = {}
143+ for stub_param in stub_sig .parameters .values ():
144+ if stub_param .name in ["x" , "x1" ]:
145+ dtypes = get_dtypes_strategy (func_name )
146+ shapes = func_to_shapes [func_name ]
147+ example_value = data .draw (
148+ xps .arrays (dtype = dtypes , shape = shapes ), label = stub_param .name
107149 )
150+ elif stub_param .name == "x2" :
151+ assert "x1" in pos_argname_to_example_value .keys () # sanity check
152+ x1 = pos_argname_to_example_value ["x1" ]
153+ example_value = data .draw (
154+ xps .arrays (dtype = x1 .dtype , shape = x1 .shape ), label = "x2"
155+ )
156+ else :
157+ if stub_param .default != Parameter .empty :
158+ example_value = stub_param .default
159+ else :
160+ pytest .skip (f"No example value for argument '{ stub_param .name } '" )
161+
162+ if stub_param .kind == Parameter .POSITIONAL_ONLY :
163+ pos_argname_to_example_value [stub_param .name ] = example_value
164+ elif stub_param .kind == Parameter .POSITIONAL_OR_KEYWORD :
165+ normal_argname_to_example_value [stub_param .name ] = example_value
166+ elif stub_param .kind == Parameter .KEYWORD_ONLY :
167+ kw_argname_to_example_value [stub_param .name ] = example_value
168+ else :
169+ pytest .skip ()
170+
171+ if len (normal_argname_to_example_value ) == 0 :
172+ func (* pos_argname_to_example_value .values (), ** kw_argname_to_example_value )
173+ else :
174+ pass # TODO
175+
176+
177+ def _test_func_signature (
178+ func : Callable , stub : FunctionType , ignore_first_stub_param : bool = False
179+ ):
180+ stub_sig = signature (stub )
181+ if ignore_first_stub_param :
182+ stub_params = list (stub_sig .parameters .values ())
183+ del stub_params [0 ]
184+ stub_sig = Signature (
185+ parameters = stub_params , return_annotation = stub_sig .return_annotation
186+ )
187+
188+ try :
189+ sig = signature (func )
190+ _test_inspectable_func (sig , stub_sig )
191+ except ValueError :
192+ _test_uninspectable_func (stub .__name__ , func , stub_sig )
108193
109194
110195@pytest .mark .parametrize (
@@ -115,7 +200,7 @@ def squeeze(x, /, axis):
115200def test_func_signature (stub : FunctionType ):
116201 assert hasattr (xp , stub .__name__ ), f"{ stub .__name__ } not found in array module"
117202 func = getattr (xp , stub .__name__ )
118- _test_signature (func , stub )
203+ _test_func_signature (func , stub )
119204
120205
121206extension_and_stub_params = []
@@ -134,13 +219,16 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
134219 mod , stub .__name__
135220 ), f"{ stub .__name__ } not found in { extension } extension"
136221 func = getattr (mod , stub .__name__ )
137- _test_signature (func , stub )
222+ _test_func_signature (func , stub )
138223
139224
140225@pytest .mark .parametrize ("stub" , array_methods , ids = lambda f : f .__name__ )
141- @given (x = xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes ()))
142- def test_array_method_signature (stub : FunctionType , x ):
226+ @given (data = st .data ())
227+ def test_array_method_signature (stub : FunctionType , data ):
228+ dtypes = get_dtypes_strategy (stub .__name__ )
229+ shapes = func_to_shapes [stub .__name__ ]
230+ x = data .draw (xps .arrays (dtype = dtypes , shape = shapes ), label = "x" )
143231 assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
144232 method = getattr (x , stub .__name__ )
145233 # Ignore 'self' arg in stub, which won't be present in instantiated objects.
146- _test_signature (method , stub , ignore_first_stub_param = True )
234+ _test_func_signature (method , stub , ignore_first_stub_param = True )
0 commit comments