@@ -19,13 +19,14 @@ def squeeze(x, /, axis):
1919
2020"""
2121from collections import defaultdict
22+ from copy import copy
2223from inspect import Parameter , Signature , signature
2324from itertools import chain
2425from types import FunctionType
25- from typing import Callable , DefaultDict , Dict , List
26+ from typing import Any , Callable , DefaultDict , Dict , List , Literal , Sequence , get_args
2627
2728import pytest
28- from hypothesis import given
29+ from hypothesis import given , note
2930from hypothesis import strategies as st
3031
3132from . import dtype_helpers as dh
@@ -38,17 +39,23 @@ def squeeze(x, /, axis):
3839
3940pytestmark = pytest .mark .ci
4041
41-
42- kind_to_str : Dict [Parameter , str ] = {
42+ ParameterKind = Literal [
43+ Parameter .POSITIONAL_ONLY ,
44+ Parameter .VAR_POSITIONAL ,
45+ Parameter .POSITIONAL_OR_KEYWORD ,
46+ Parameter .KEYWORD_ONLY ,
47+ Parameter .VAR_KEYWORD ,
48+ ]
49+ ALL_KINDS = get_args (ParameterKind )
50+ VAR_KINDS = (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD )
51+ kind_to_str : Dict [ParameterKind , str ] = {
4352 Parameter .POSITIONAL_OR_KEYWORD : "normal argument" ,
4453 Parameter .POSITIONAL_ONLY : "pos-only argument" ,
4554 Parameter .KEYWORD_ONLY : "keyword-only argument" ,
4655 Parameter .VAR_POSITIONAL : "star-args (i.e. *args) argument" ,
4756 Parameter .VAR_KEYWORD : "star-kwargs (i.e. **kwargs) argument" ,
4857}
4958
50- VAR_KINDS = (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD )
51-
5259
5360def _test_inspectable_func (sig : Signature , stub_sig : Signature ):
5461 params = list (sig .parameters .values ())
@@ -89,11 +96,12 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
8996 ], (
9097 f"{ param .name } is a { kind_to_str [param .kind ]} , "
9198 f"but should be a { f_stub_kind } "
92- f"(or at least a { kind_to_str [Parameter .POSITIONAL_OR_KEYWORD ]} )"
99+ f"(or at least a { kind_to_str [ParameterKind .POSITIONAL_OR_KEYWORD ]} )"
93100 )
94101 else :
95102 pass
96103
104+
97105def shapes (** kw ) -> st .SearchStrategy [Shape ]:
98106 if "min_side" not in kw .keys ():
99107 kw ["min_side" ] = 1
@@ -111,7 +119,7 @@ def shapes(**kw) -> st.SearchStrategy[Shape]:
111119 ** {k : st .just (()) for k in ["__bool__" , "__int__" , "__index__" , "__float__" ]},
112120 "sort" : shapes (min_dims = 1 ), # for axis=-1,
113121 ** {k : shapes (min_dims = 2 ) for k in matrixy_funcs },
114- # Override for some matrixy functions
122+ # Overwrite min_dims=2 shapes for some matrixy functions
115123 "cross" : shapes (min_side = 3 , max_side = 3 , min_dims = 3 , max_dims = 3 ),
116124 "outer" : shapes (min_dims = 1 , max_dims = 1 ),
117125 },
@@ -128,50 +136,98 @@ def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
128136 return xps .scalar_dtypes ()
129137
130138
139+ func_to_example_values : Dict [str , Dict [ParameterKind , Dict [str , Any ]]] = {
140+ "broadcast_to" : {
141+ Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([0 , 1 ])},
142+ Parameter .POSITIONAL_OR_KEYWORD : {"shape" : (1 , 2 )},
143+ },
144+ "cholesky" : {
145+ Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]])}
146+ },
147+ "inv" : {Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]])}},
148+ }
149+
150+
151+ def make_pretty_func (func_name : str , args : Sequence [Any ], kwargs : Dict [str , Any ]):
152+ f_sig = f"{ func_name } ("
153+ f_sig += ", " .join (str (a ) for a in args )
154+ if len (kwargs ) != 0 :
155+ if len (args ) != 0 :
156+ f_sig += ", "
157+ f_sig += ", " .join (f"{ k } ={ v } " for k , v in kwargs .items ())
158+ f_sig += ")"
159+ return f_sig
160+
161+
131162@given (data = st .data ())
132163def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature , data ):
133- if func_name in ["cholesky" , "inv" ]:
134- func (xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]]))
135- return
136- elif func_name == "solve" :
137- func (xp .asarray ([[1.0 , 2.0 ], [3.0 , 5.0 ]]), xp .asarray ([1.0 , 2.0 ]))
138- return
139-
140- pos_argname_to_example_value = {}
141- normal_argname_to_example_value = {}
142- kw_argname_to_example_value = {}
143- for stub_param in stub_sig .parameters .values ():
144- if stub_param .name in ["x" , "x1" ]:
164+ example_values : Dict [ParameterKind , Dict [str , Any ]] = func_to_example_values .get (
165+ func_name , {}
166+ )
167+ for kind in ALL_KINDS :
168+ example_values .setdefault (kind , {})
169+
170+ for param in stub_sig .parameters .values ():
171+ for name_to_value in example_values .values ():
172+ if param .name in name_to_value .keys ():
173+ continue
174+
175+ if param .default != Parameter .empty :
176+ example_value = param .default
177+ elif param .name in ["x" , "x1" ]:
145178 dtypes = get_dtypes_strategy (func_name )
146179 shapes = func_to_shapes [func_name ]
147180 example_value = data .draw (
148- xps .arrays (dtype = dtypes , shape = shapes ), label = stub_param .name
181+ xps .arrays (dtype = dtypes , shape = shapes ), label = param .name
149182 )
150- elif stub_param .name == "x2" :
151- assert "x1" in pos_argname_to_example_value .keys () # sanity check
152- x1 = pos_argname_to_example_value ["x1" ]
183+ elif param .name == "x2" :
184+ # sanity check
185+ assert "x1" in example_values [Parameter .POSITIONAL_ONLY ].keys ()
186+ x1 = example_values [Parameter .POSITIONAL_ONLY ]["x1" ]
153187 example_value = data .draw (
154188 xps .arrays (dtype = x1 .dtype , shape = x1 .shape ), label = "x2"
155189 )
190+ elif param .name == "axes" :
191+ example_value = ()
192+ elif param .name == "shape" :
193+ example_value = ()
156194 else :
157- if stub_param .default != Parameter .empty :
158- example_value = stub_param .default
159- else :
160- pytest .skip (f"No example value for argument '{ stub_param .name } '" )
161-
162- if stub_param .kind == Parameter .POSITIONAL_ONLY :
163- pos_argname_to_example_value [stub_param .name ] = example_value
164- elif stub_param .kind == Parameter .POSITIONAL_OR_KEYWORD :
165- normal_argname_to_example_value [stub_param .name ] = example_value
166- elif stub_param .kind == Parameter .KEYWORD_ONLY :
167- kw_argname_to_example_value [stub_param .name ] = example_value
168- else :
169- pytest .skip ()
195+ pytest .skip (f"No example value for argument '{ param .name } '" )
170196
171- if len (normal_argname_to_example_value ) == 0 :
172- func (* pos_argname_to_example_value .values (), ** kw_argname_to_example_value )
197+ if param .kind in VAR_KINDS :
198+ pytest .skip ("TODO" )
199+ example_values [param .kind ][param .name ] = example_value
200+
201+ if len (example_values [Parameter .POSITIONAL_OR_KEYWORD ]) == 0 :
202+ f_func = make_pretty_func (
203+ func_name ,
204+ example_values [Parameter .POSITIONAL_ONLY ].values (),
205+ example_values [Parameter .KEYWORD_ONLY ],
206+ )
207+ note (f"trying { f_func } " )
208+ func (
209+ * example_values [Parameter .POSITIONAL_ONLY ].values (),
210+ ** example_values [Parameter .KEYWORD_ONLY ],
211+ )
173212 else :
174- pass # TODO
213+ either_argname_value_pairs = list (
214+ example_values [Parameter .POSITIONAL_OR_KEYWORD ].items ()
215+ )
216+ n_either_args = len (either_argname_value_pairs )
217+ for n_extra_args in reversed (range (n_either_args + 1 )):
218+ extra_args = [v for _ , v in either_argname_value_pairs [:n_extra_args ]]
219+ if n_extra_args < n_either_args :
220+ extra_kwargs = dict (either_argname_value_pairs [n_extra_args :])
221+ else :
222+ extra_kwargs = {}
223+ args = list (example_values [Parameter .POSITIONAL_ONLY ].values ())
224+ args += extra_args
225+ kwargs = copy (example_values [Parameter .KEYWORD_ONLY ])
226+ if len (extra_kwargs ) != 0 :
227+ kwargs .update (extra_kwargs )
228+ f_func = make_pretty_func (func_name , args , kwargs )
229+ note (f"trying { f_func } " )
230+ func (* args , ** kwargs )
175231
176232
177233def _test_func_signature (
0 commit comments