@@ -20,22 +20,18 @@ def squeeze(x, /, axis):
2020 ...
2121
2222"""
23+ from collections import defaultdict
24+ from copy import copy
2325from inspect import Parameter , Signature , signature
2426from types import FunctionType
25- from typing import Any , Callable , Dict , List , Literal , get_args
27+ from typing import Any , Callable , Dict , Literal , get_args
28+ from warnings import warn
2629
2730import pytest
28- from hypothesis import given , note , settings
29- from hypothesis import strategies as st
30- from hypothesis .strategies import DataObject
3131
3232from . import dtype_helpers as dh
33- from . import hypothesis_helpers as hh
34- from . import xps
35- from ._array_module import _UndefinedStub
3633from ._array_module import mod as xp
37- from .stubs import array_methods , category_to_funcs , extension_to_funcs
38- from .typing import Array , DataType
34+ from .stubs import array_methods , category_to_funcs , extension_to_funcs , name_to_func
3935
4036pytestmark = pytest .mark .ci
4137
@@ -93,24 +89,15 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
9389 stub_param .name in sig .parameters .keys ()
9490 ), f"Argument '{ stub_param .name } ' missing from signature"
9591 param = next (p for p in params if p .name == stub_param .name )
92+ f_stub_kind = kind_to_str [stub_param .kind ]
9693 assert param .kind in [stub_param .kind , Parameter .POSITIONAL_OR_KEYWORD ,], (
9794 f"{ param .name } is a { kind_to_str [param .kind ]} , "
9895 f"but should be a { f_stub_kind } "
9996 f"(or at least a { kind_to_str [ParameterKind .POSITIONAL_OR_KEYWORD ]} )"
10097 )
10198
10299
103- def get_dtypes_strategy (func_name : str ) -> st .SearchStrategy [DataType ]:
104- if func_name in dh .func_in_dtypes .keys ():
105- dtypes = dh .func_in_dtypes [func_name ]
106- if hh .FILTER_UNDEFINED_DTYPES :
107- dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
108- return st .sampled_from (dtypes )
109- else :
110- return xps .scalar_dtypes ()
111-
112-
113- def make_pretty_func (func_name : str , * args : Any , ** kwargs : Any ):
100+ def make_pretty_func (func_name : str , * args : Any , ** kwargs : Any ) -> str :
114101 f_sig = f"{ func_name } ("
115102 f_sig += ", " .join (str (a ) for a in args )
116103 if len (kwargs ) != 0 :
@@ -121,96 +108,165 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
121108 return f_sig
122109
123110
124- matrixy_funcs : List [FunctionType ] = [
125- * category_to_funcs ["linear_algebra" ],
126- * extension_to_funcs ["linalg" ],
111+ # We test uninspectable signatures by passing valid, manually-defined arguments
112+ # to the signature's function/method.
113+ #
114+ # Arguments which require use of the array module are specified as string
115+ # expressions to be eval()'d on runtime. This is as opposed to just using the
116+ # array module whilst setting up the tests, which is prone to halt the entire
117+ # test suite if an array module doesn't support a given expression.
118+ func_to_specified_args = defaultdict (
119+ dict ,
120+ {
121+ "permute_dims" : {"axes" : 0 },
122+ "reshape" : {"shape" : (1 , 5 )},
123+ "broadcast_to" : {"shape" : (1 , 5 )},
124+ "asarray" : {"obj" : [0 , 1 , 2 , 3 , 4 ]},
125+ "full_like" : {"fill_value" : 42 },
126+ "matrix_power" : {"n" : 2 },
127+ },
128+ )
129+ func_to_specified_arg_exprs = defaultdict (
130+ dict ,
131+ {
132+ "stack" : {"arrays" : "[xp.ones((5,)), xp.ones((5,))]" },
133+ "iinfo" : {"type" : "xp.int64" },
134+ "finfo" : {"type" : "xp.float64" },
135+ "cholesky" : {"x" : "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)" },
136+ "inv" : {"x" : "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" },
137+ "solve" : {
138+ a : "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1" , "x2" ]
139+ },
140+ },
141+ )
142+ # We default most array arguments heuristically. As functions/methods work only
143+ # with arrays of certain dtypes and shapes, we specify only supported arrays
144+ # respective to the function.
145+ casty_names = ["__bool__" , "__int__" , "__float__" , "__complex__" , "__index__" ]
146+ matrixy_names = [
147+ f .__name__
148+ for f in category_to_funcs ["linear_algebra" ] + extension_to_funcs ["linalg" ]
127149]
128- matrixy_names : List [str ] = [f .__name__ for f in matrixy_funcs ]
129150matrixy_names += ["__matmul__" , "triu" , "tril" ]
151+ for func_name , func in name_to_func .items ():
152+ stub_sig = signature (func )
153+ array_argnames = set (stub_sig .parameters .keys ()) & {"x" , "x1" , "x2" , "other" }
154+ if func in array_methods :
155+ array_argnames .add ("self" )
156+ array_argnames -= set (func_to_specified_arg_exprs [func_name ].keys ())
157+ if len (array_argnames ) > 0 :
158+ in_dtypes = dh .func_in_dtypes [func_name ]
159+ for dtype_name in ["float64" , "bool" , "int64" , "complex128" ]:
160+ # We try float64 first because uninspectable numerical functions
161+ # tend to support float inputs first-and-foremost (i.e. PyTorch)
162+ try :
163+ dtype = getattr (xp , dtype_name )
164+ except AttributeError :
165+ pass
166+ else :
167+ if dtype in in_dtypes :
168+ if func_name in casty_names :
169+ shape = ()
170+ elif func_name in matrixy_names :
171+ shape = (3 , 3 )
172+ else :
173+ shape = (5 ,)
174+ fallback_array_expr = f"xp.ones({ shape } , dtype=xp.{ dtype_name } )"
175+ break
176+ else :
177+ warn (
178+ f"{ dh .func_in_dtypes ['{func_name}' ]} ={ in_dtypes } seemingly does "
179+ "not contain any assumed dtypes, so skipping specifying fallback array."
180+ )
181+ continue
182+ for argname in array_argnames :
183+ func_to_specified_arg_exprs [func_name ][argname ] = fallback_array_expr
184+
130185
186+ def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature ):
187+ params = list (stub_sig .parameters .values ())
131188
132- @given (data = st .data ())
133- @settings (max_examples = 1 )
134- def _test_uninspectable_func (
135- func_name : str , func : Callable , stub_sig : Signature , array : Array , data : DataObject
136- ):
137- skip_msg = (
138- f"Signature for { func_name } () is not inspectable "
139- "and is too troublesome to test for otherwise"
189+ if len (params ) == 0 :
190+ func ()
191+ return
192+
193+ uninspectable_msg = (
194+ f"Note { func_name } () is not inspectable so arguments are passed "
195+ "manually to test the signature."
140196 )
141- if func_name in [
142- # 0d shapes
143- "__bool__" ,
144- "__int__" ,
145- "__index__" ,
146- "__float__" ,
147- # x2 elements must be >=0
148- "pow" ,
149- "bitwise_left_shift" ,
150- "bitwise_right_shift" ,
151- # axis default invalid with 0d shapes
152- "sort" ,
153- # shape requirements
154- * matrixy_names ,
155- ]:
156- pytest .skip (skip_msg )
157-
158- param_to_value : Dict [Parameter , Any ] = {}
159- for param in stub_sig .parameters .values ():
160- if param .kind in [Parameter .POSITIONAL_OR_KEYWORD , * VAR_KINDS ]:
197+
198+ argname_to_arg = copy (func_to_specified_args [func_name ])
199+ argname_to_expr = func_to_specified_arg_exprs [func_name ]
200+ for argname , expr in argname_to_expr .items ():
201+ assert argname not in argname_to_arg .keys () # sanity check
202+ try :
203+ argname_to_arg [argname ] = eval (expr , {"xp" : xp })
204+ except Exception as e :
161205 pytest .skip (
162- skip_msg + f" (because '{ param .name } ' is a { kind_to_str [param .kind ]} )"
163- )
164- elif param .default != Parameter .empty :
165- value = param .default
166- elif param .name in ["x" , "x1" ]:
167- dtypes = get_dtypes_strategy (func_name )
168- value = data .draw (
169- xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = param .name
206+ f"Exception occured when evaluating { argname } ={ expr } : { e } \n "
207+ f"{ uninspectable_msg } "
170208 )
171- elif param .name in ["x2" , "other" ]:
172- if param .name == "x2" :
173- assert "x1" in [p .name for p in param_to_value .keys ()] # sanity check
174- orig = next (v for p , v in param_to_value .items () if p .name == "x1" )
209+
210+ posargs = []
211+ posorkw_args = {}
212+ kwargs = {}
213+ no_arg_msg = (
214+ "We have no argument specified for '{}'. Please ensure you're using "
215+ "the latest version of array-api-tests, then open an issue if one "
216+ f"doesn't already exist. { uninspectable_msg } "
217+ )
218+ for param in params :
219+ if param .kind == Parameter .POSITIONAL_ONLY :
220+ try :
221+ posargs .append (argname_to_arg [param .name ])
222+ except KeyError :
223+ pytest .skip (no_arg_msg .format (param .name ))
224+ elif param .kind == Parameter .POSITIONAL_OR_KEYWORD :
225+ if param .default == Parameter .empty :
226+ try :
227+ posorkw_args [param .name ] = argname_to_arg [param .name ]
228+ except KeyError :
229+ pytest .skip (no_arg_msg .format (param .name ))
175230 else :
176- assert array is not None # sanity check
177- orig = array
178- value = data . draw (
179- xps . arrays ( dtype = orig . dtype , shape = orig . shape ), label = param . name
180- )
231+ assert argname_to_arg [ param . name ]
232+ posorkw_args [ param . name ] = param . default
233+ elif param . kind == Parameter . KEYWORD_ONLY :
234+ assert param . default != Parameter . empty # sanity check
235+ kwargs [ param . name ] = param . default
181236 else :
182- pytest .skip (
183- skip_msg + f" (because no default was found for argument { param .name } )"
184- )
185- param_to_value [param ] = value
186-
187- args : List [Any ] = [
188- v for p , v in param_to_value .items () if p .kind == Parameter .POSITIONAL_ONLY
189- ]
190- kwargs : Dict [str , Any ] = {
191- p .name : v for p , v in param_to_value .items () if p .kind == Parameter .KEYWORD_ONLY
192- }
193- f_func = make_pretty_func (func_name , * args , ** kwargs )
194- note (f"trying { f_func } " )
195- func (* args , ** kwargs )
237+ assert param .kind in VAR_KINDS # sanity check
238+ pytest .skip (no_arg_msg .format (param .name ))
239+ if len (posorkw_args ) == 0 :
240+ func (* posargs , ** kwargs )
241+ else :
242+ posorkw_name_to_arg_pairs = list (posorkw_args .items ())
243+ for i in range (len (posorkw_name_to_arg_pairs ), - 1 , - 1 ):
244+ extra_posargs = [arg for _ , arg in posorkw_name_to_arg_pairs [:i ]]
245+ extra_kwargs = dict (posorkw_name_to_arg_pairs [i :])
246+ func (* posargs , * extra_posargs , ** kwargs , ** extra_kwargs )
196247
197248
198- def _test_func_signature (func : Callable , stub : FunctionType , array = None ):
249+ def _test_func_signature (func : Callable , stub : FunctionType , is_method = False ):
199250 stub_sig = signature (stub )
200251 # If testing against array, ignore 'self' arg in stub as it won't be present
201252 # in func (which should be a method).
202- if array is not None :
253+ if is_method :
203254 stub_params = list (stub_sig .parameters .values ())
204- del stub_params [0 ]
255+ if stub_params [0 ].name == "self" :
256+ del stub_params [0 ]
205257 stub_sig = Signature (
206258 parameters = stub_params , return_annotation = stub_sig .return_annotation
207259 )
208260
209261 try :
210262 sig = signature (func )
211- _test_inspectable_func (sig , stub_sig )
212263 except ValueError :
213- _test_uninspectable_func (stub .__name__ , func , stub_sig , array )
264+ try :
265+ _test_uninspectable_func (stub .__name__ , func , stub_sig )
266+ except Exception as e :
267+ raise e from None # suppress parent exception for cleaner pytest output
268+ else :
269+ _test_inspectable_func (sig , stub_sig )
214270
215271
216272@pytest .mark .parametrize (
@@ -244,11 +300,12 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
244300
245301
246302@pytest .mark .parametrize ("stub" , array_methods , ids = lambda f : f .__name__ )
247- @given (st .data ())
248- @settings (max_examples = 1 )
249- def test_array_method_signature (stub : FunctionType , data : DataObject ):
250- dtypes = get_dtypes_strategy (stub .__name__ )
251- x = data .draw (xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = "x" )
303+ def test_array_method_signature (stub : FunctionType ):
304+ x_expr = func_to_specified_arg_exprs [stub .__name__ ]["self" ]
305+ try :
306+ x = eval (x_expr , {"xp" : xp })
307+ except Exception as e :
308+ pytest .skip (f"Exception occured when evaluating x={ x_expr } : { e } " )
252309 assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
253310 method = getattr (x , stub .__name__ )
254- _test_func_signature (method , stub , array = x )
311+ _test_func_signature (method , stub , is_method = True )
0 commit comments