@@ -18,15 +18,13 @@ def squeeze(x, /, axis):
1818 ...
1919
2020"""
21- from collections import defaultdict
2221from copy import copy
2322from inspect import Parameter , Signature , signature
2423from itertools import chain
2524from types import FunctionType
26- from typing import Any , Callable , DefaultDict , Dict , List , Literal , Sequence , get_args
25+ from typing import Any , Callable , Dict , List , Literal , Sequence , get_args
2726
2827import pytest
29- from hypothesis import given , note
3028from hypothesis import strategies as st
3129
3230from . import dtype_helpers as dh
@@ -35,7 +33,7 @@ def squeeze(x, /, axis):
3533from ._array_module import _UndefinedStub
3634from ._array_module import mod as xp
3735from .stubs import array_methods , category_to_funcs , extension_to_funcs
38- from .typing import DataType , Shape
36+ from .typing import DataType
3937
4038pytestmark = pytest .mark .ci
4139
@@ -53,7 +51,7 @@ def squeeze(x, /, axis):
5351 Parameter .POSITIONAL_ONLY : "pos-only argument" ,
5452 Parameter .KEYWORD_ONLY : "keyword-only argument" ,
5553 Parameter .VAR_POSITIONAL : "star-args (i.e. *args) argument" ,
56- Parameter .VAR_KEYWORD : "star-kwargs (i.e. **kwargs ) argument" ,
54+ Parameter .VAR_KEYWORD : "star-kwonly (i.e. **kwonly ) argument" ,
5755}
5856
5957
@@ -63,14 +61,13 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
6361 # We're not interested if the array module has additional arguments, so we
6462 # only iterate through the arguments listed in the spec.
6563 for i , stub_param in enumerate (stub_params ):
66- if sig is not None :
67- assert (
68- len (params ) >= i + 1
69- ), f"Argument '{ stub_param .name } ' missing from signature"
70- param = params [i ]
64+ assert (
65+ len (params ) >= i + 1
66+ ), f"Argument '{ stub_param .name } ' missing from signature"
67+ param = params [i ]
7168
7269 # We're not interested in the name if it isn't actually used
73- if sig is not None and stub_param .kind not in [
70+ if stub_param .kind not in [
7471 Parameter .POSITIONAL_ONLY ,
7572 * VAR_KINDS ,
7673 ]:
@@ -80,50 +77,17 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
8077
8178 f_stub_kind = kind_to_str [stub_param .kind ]
8279 if stub_param .kind in [Parameter .POSITIONAL_OR_KEYWORD , * VAR_KINDS ]:
83- if sig is not None :
84- assert param .kind == stub_param .kind , (
85- f"{ param .name } is a { kind_to_str [param .kind ]} , "
86- f"but should be a { f_stub_kind } "
87- )
88- else :
89- pass
80+ assert param .kind == stub_param .kind , (
81+ f"{ param .name } is a { kind_to_str [param .kind ]} , "
82+ f"but should be a { f_stub_kind } "
83+ )
9084 else :
9185 # TODO: allow for kw-only args to be out-of-order
92- if sig is not None :
93- assert param .kind in [
94- stub_param .kind ,
95- Parameter .POSITIONAL_OR_KEYWORD ,
96- ], (
97- f"{ param .name } is a { kind_to_str [param .kind ]} , "
98- f"but should be a { f_stub_kind } "
99- f"(or at least a { kind_to_str [ParameterKind .POSITIONAL_OR_KEYWORD ]} )"
100- )
101- else :
102- pass
103-
104-
105- def shapes (** kw ) -> st .SearchStrategy [Shape ]:
106- if "min_side" not in kw .keys ():
107- kw ["min_side" ] = 1
108- return hh .shapes (** kw )
109-
110-
111- matrixy_funcs : List [str ] = [
112- f .__name__
113- for f in chain (category_to_funcs ["linear_algebra" ], extension_to_funcs ["linalg" ])
114- ]
115- matrixy_funcs += ["__matmul__" , "triu" , "tril" ]
116- func_to_shapes : DefaultDict [str , st .SearchStrategy [Shape ]] = defaultdict (
117- shapes ,
118- {
119- ** {k : st .just (()) for k in ["__bool__" , "__int__" , "__index__" , "__float__" ]},
120- "sort" : shapes (min_dims = 1 ), # for axis=-1,
121- ** {k : shapes (min_dims = 2 ) for k in matrixy_funcs },
122- # Overwrite min_dims=2 shapes for some matrixy functions
123- "cross" : shapes (min_side = 3 , max_side = 3 , min_dims = 3 , max_dims = 3 ),
124- "outer" : shapes (min_dims = 1 , max_dims = 1 ),
125- },
126- )
86+ assert param .kind in [stub_param .kind , Parameter .POSITIONAL_OR_KEYWORD ,], (
87+ f"{ param .name } is a { kind_to_str [param .kind ]} , "
88+ f"but should be a { f_stub_kind } "
89+ f"(or at least a { kind_to_str [ParameterKind .POSITIONAL_OR_KEYWORD ]} )"
90+ )
12791
12892
12993def get_dtypes_strategy (func_name : str ) -> st .SearchStrategy [DataType ]:
@@ -136,97 +100,93 @@ def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
136100 return xps .scalar_dtypes ()
137101
138102
139- func_to_example_values : Dict [str , Dict [ParameterKind , Dict [str , Any ]]] = {
140- "broadcast_to" : {
141- Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([0 , 1 ])},
142- Parameter .POSITIONAL_OR_KEYWORD : {"shape" : (1 , 2 )},
143- },
144- "cholesky" : {
145- Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]])}
146- },
147- "inv" : {Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]])}},
148- }
149-
150-
151- def make_pretty_func (func_name : str , args : Sequence [Any ], kwargs : Dict [str , Any ]):
103+ def make_pretty_func (func_name : str , args : Sequence [Any ], kwonly : Dict [str , Any ]):
152104 f_sig = f"{ func_name } ("
153105 f_sig += ", " .join (str (a ) for a in args )
154- if len (kwargs ) != 0 :
106+ if len (kwonly ) != 0 :
155107 if len (args ) != 0 :
156108 f_sig += ", "
157- f_sig += ", " .join (f"{ k } ={ v } " for k , v in kwargs .items ())
109+ f_sig += ", " .join (f"{ k } ={ v } " for k , v in kwonly .items ())
158110 f_sig += ")"
159111 return f_sig
160112
161113
162- @given (data = st .data ())
163- def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature , data ):
164- example_values : Dict [ParameterKind , Dict [str , Any ]] = func_to_example_values .get (
165- func_name , {}
166- )
167- for kind in ALL_KINDS :
168- example_values .setdefault (kind , {})
114+ matrixy_funcs : List [str ] = [
115+ f .__name__
116+ for f in chain (category_to_funcs ["linear_algebra" ], extension_to_funcs ["linalg" ])
117+ ]
118+ matrixy_funcs += ["__matmul__" , "triu" , "tril" ]
169119
170- for param in stub_sig .parameters .values ():
171- for name_to_value in example_values .values ():
172- if param .name in name_to_value .keys ():
173- continue
174120
175- if param .default != Parameter .empty :
176- example_value = param .default
121+ def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature ):
122+ skip_msg = (
123+ f"Signature for { func_name } () is not inspectable "
124+ "and is too troublesome to test for otherwise"
125+ )
126+ if func_name in [
127+ "__bool__" ,
128+ "__int__" ,
129+ "__index__" ,
130+ "__float__" ,
131+ "pow" ,
132+ "bitwise_left_shift" ,
133+ "bitwise_right_shift" ,
134+ "broadcast_to" ,
135+ "permute_dims" ,
136+ "sort" ,
137+ * matrixy_funcs ,
138+ ]:
139+ pytest .skip (skip_msg )
140+
141+ param_to_value : Dict [Parameter , Any ] = {}
142+ for param in stub_sig .parameters .values ():
143+ if param .kind in VAR_KINDS :
144+ pytest .skip (skip_msg )
145+ elif param .default != Parameter .empty :
146+ value = param .default
177147 elif param .name in ["x" , "x1" ]:
178148 dtypes = get_dtypes_strategy (func_name )
179- shapes = func_to_shapes [func_name ]
180- example_value = data .draw (
181- xps .arrays (dtype = dtypes , shape = shapes ), label = param .name
182- )
149+ value = xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )).example ()
183150 elif param .name == "x2" :
184151 # sanity check
185- assert "x1" in example_values [Parameter .POSITIONAL_ONLY ].keys ()
186- x1 = example_values [Parameter .POSITIONAL_ONLY ]["x1" ]
187- example_value = data .draw (
188- xps .arrays (dtype = x1 .dtype , shape = x1 .shape ), label = "x2"
189- )
190- elif param .name == "axes" :
191- example_value = ()
192- elif param .name == "shape" :
193- example_value = ()
152+ assert "x1" in [p .name for p in param_to_value .keys ()]
153+ x1 = next (v for p , v in param_to_value .items () if p .name == "x1" )
154+ value = xps .arrays (dtype = x1 .dtype , shape = x1 .shape ).example ()
194155 else :
195- pytest .skip (f"No example value for argument '{ param .name } '" )
196-
197- if param .kind in VAR_KINDS :
198- pytest .skip ("TODO" )
199- example_values [param .kind ][param .name ] = example_value
200-
201- if len (example_values [Parameter .POSITIONAL_OR_KEYWORD ]) == 0 :
202- f_func = make_pretty_func (
203- func_name ,
204- example_values [Parameter .POSITIONAL_ONLY ].values (),
205- example_values [Parameter .KEYWORD_ONLY ],
206- )
207- note (f"trying { f_func } " )
208- func (
209- * example_values [Parameter .POSITIONAL_ONLY ].values (),
210- ** example_values [Parameter .KEYWORD_ONLY ],
211- )
156+ pytest .skip (skip_msg )
157+ param_to_value [param ] = value
158+
159+ posonly : List [Any ] = [
160+ v for p , v in param_to_value .items () if p .kind == Parameter .POSITIONAL_ONLY
161+ ]
162+ kwonly : Dict [str , Any ] = {
163+ p .name : v for p , v in param_to_value .items () if p .kind == Parameter .KEYWORD_ONLY
164+ }
165+ if (
166+ sum (p .kind == Parameter .POSITIONAL_OR_KEYWORD for p in param_to_value .keys ())
167+ == 0
168+ ):
169+ f_func = make_pretty_func (func_name , posonly , kwonly )
170+ print (f"trying { f_func } " )
171+ func (* posonly , ** kwonly )
212172 else :
213173 either_argname_value_pairs = list (
214- example_values [Parameter .POSITIONAL_OR_KEYWORD ].items ()
174+ (p .name , v )
175+ for p , v in param_to_value .items ()
176+ if p .kind == Parameter .POSITIONAL_OR_KEYWORD
215177 )
216178 n_either_args = len (either_argname_value_pairs )
217179 for n_extra_args in reversed (range (n_either_args + 1 )):
218- extra_args = [v for _ , v in either_argname_value_pairs [:n_extra_args ]]
180+ extra_posargs = [v for _ , v in either_argname_value_pairs [:n_extra_args ]]
219181 if n_extra_args < n_either_args :
220182 extra_kwargs = dict (either_argname_value_pairs [n_extra_args :])
221183 else :
222184 extra_kwargs = {}
223- args = list (example_values [Parameter .POSITIONAL_ONLY ].values ())
224- args += extra_args
225- kwargs = copy (example_values [Parameter .KEYWORD_ONLY ])
226- if len (extra_kwargs ) != 0 :
227- kwargs .update (extra_kwargs )
185+ args = copy (posonly )
186+ args += extra_posargs
187+ kwargs = {** kwonly , ** extra_kwargs }
228188 f_func = make_pretty_func (func_name , args , kwargs )
229- note (f"trying { f_func } " )
189+ print (f"trying { f_func } " )
230190 func (* args , ** kwargs )
231191
232192
@@ -279,11 +239,9 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
279239
280240
281241@pytest .mark .parametrize ("stub" , array_methods , ids = lambda f : f .__name__ )
282- @given (data = st .data ())
283- def test_array_method_signature (stub : FunctionType , data ):
242+ def test_array_method_signature (stub : FunctionType ):
284243 dtypes = get_dtypes_strategy (stub .__name__ )
285- shapes = func_to_shapes [stub .__name__ ]
286- x = data .draw (xps .arrays (dtype = dtypes , shape = shapes ), label = "x" )
244+ x = xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )).example ()
287245 assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
288246 method = getattr (x , stub .__name__ )
289247 # Ignore 'self' arg in stub, which won't be present in instantiated objects.
0 commit comments