@@ -20,22 +20,17 @@ def squeeze(x, /, axis):
2020 ...
2121
2222"""
23+ from collections import defaultdict
2324from inspect import Parameter , Signature , signature
2425from types import FunctionType
25- from typing import Any , Callable , Dict , List , Literal , get_args
26+ from typing import Any , Callable , Dict , Literal , get_args
27+ from warnings import warn
2628
2729import pytest
28- from hypothesis import given , note , settings
29- from hypothesis import strategies as st
30- from hypothesis .strategies import DataObject
3130
3231from . import dtype_helpers as dh
33- from . import hypothesis_helpers as hh
34- from . import xps
35- from ._array_module import _UndefinedStub
3632from ._array_module import mod as xp
37- from .stubs import array_methods , category_to_funcs , extension_to_funcs
38- from .typing import Array , DataType
33+ from .stubs import array_methods , category_to_funcs , extension_to_funcs , name_to_func
3934
4035pytestmark = pytest .mark .ci
4136
@@ -101,17 +96,7 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
10196 )
10297
10398
104- def get_dtypes_strategy (func_name : str ) -> st .SearchStrategy [DataType ]:
105- if func_name in dh .func_in_dtypes .keys ():
106- dtypes = dh .func_in_dtypes [func_name ]
107- if hh .FILTER_UNDEFINED_DTYPES :
108- dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
109- return st .sampled_from (dtypes )
110- else :
111- return xps .scalar_dtypes ()
112-
113-
114- def make_pretty_func (func_name : str , * args : Any , ** kwargs : Any ):
99+ def make_pretty_func (func_name : str , * args : Any , ** kwargs : Any ) -> str :
115100 f_sig = f"{ func_name } ("
116101 f_sig += ", " .join (str (a ) for a in args )
117102 if len (kwargs ) != 0 :
@@ -122,96 +107,161 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
122107 return f_sig
123108
124109
125- matrixy_funcs : List [FunctionType ] = [
126- * category_to_funcs ["linear_algebra" ],
127- * extension_to_funcs ["linalg" ],
110+ # We test uninspectable signatures by passing valid, manually-defined arguments
111+ # to the signature's function/method.
112+ #
113+ # Arguments which require use of the array module are specified as string
114+ # expressions to be eval()'d on runtime. This is as opposed to just using the
115+ # array module whilst setting up the tests, which is prone to halt the entire
116+ # test suite if an array module doesn't support a given expression.
117+ func_to_specified_args = defaultdict (
118+ dict ,
119+ {
120+ "permute_dims" : {"axes" : 0 },
121+ "reshape" : {"shape" : (1 , 5 )},
122+ "broadcast_to" : {"shape" : (1 , 5 )},
123+ "asarray" : {"obj" : [0 , 1 , 2 , 3 , 4 ]},
124+ "full_like" : {"fill_value" : 42 },
125+ "matrix_power" : {"n" : 2 },
126+ },
127+ )
128+ func_to_specified_arg_exprs = defaultdict (
129+ dict ,
130+ {
131+ "stack" : {"arrays" : "[xp.ones((5,)), xp.ones((5,))]" },
132+ "iinfo" : {"type" : "xp.int64" },
133+ "finfo" : {"type" : "xp.float64" },
134+ "logaddexp" : {a : "xp.ones((5,), dtype=xp.float64)" for a in ["x1" , "x2" ]},
135+ },
136+ )
137+ # We default most array arguments heuristically. As functions/methods work only
138+ # with arrays of certain dtypes and shapes, we specify only supported arrays
139+ # respective to the function.
140+ casty_names = ["__bool__" , "__int__" , "__float__" , "__complex__" , "__index__" ]
141+ matrixy_names = [
142+ f .__name__
143+ for f in category_to_funcs ["linear_algebra" ] + extension_to_funcs ["linalg" ]
128144]
129- matrixy_names : List [str ] = [f .__name__ for f in matrixy_funcs ]
130145matrixy_names += ["__matmul__" , "triu" , "tril" ]
146+ for func_name , func in name_to_func .items ():
147+ stub_sig = signature (func )
148+ array_argnames = set (stub_sig .parameters .keys ()) & {"x" , "x1" , "x2" , "other" }
149+ if func in array_methods :
150+ array_argnames .add ("self" )
151+ array_argnames -= set (func_to_specified_arg_exprs [func_name ].keys ())
152+ if len (array_argnames ) > 0 :
153+ in_dtypes = dh .func_in_dtypes [func_name ]
154+ for dtype_name in ["float64" , "bool" , "int64" , "complex128" ]:
155+ # We try float64 first because uninspectable numerical functions
156+ # tend to support float inputs first-and-foremost (i.e. PyTorch)
157+ try :
158+ dtype = getattr (xp , dtype_name )
159+ except AttributeError :
160+ pass
161+ else :
162+ if dtype in in_dtypes :
163+ if func_name in casty_names :
164+ shape = ()
165+ elif func_name in matrixy_names :
166+ shape = (2 , 2 )
167+ else :
168+ shape = (5 ,)
169+ fallback_array_expr = f"xp.ones({ shape } , dtype=xp.{ dtype_name } )"
170+ break
171+ else :
172+ warn (
173+ f"{ dh .func_in_dtypes ['{func_name}' ]} ={ in_dtypes } seemingly does "
174+ "not contain any assumed dtypes, so skipping specifying fallback array."
175+ )
176+ continue
177+ for argname in array_argnames :
178+ func_to_specified_arg_exprs [func_name ][argname ] = fallback_array_expr
179+
180+
181+ def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature ):
182+ if func_name in matrixy_names :
183+ pytest .xfail ("TODO" )
131184
185+ params = list (stub_sig .parameters .values ())
132186
133- @given (data = st .data ())
134- @settings (max_examples = 1 )
135- def _test_uninspectable_func (
136- func_name : str , func : Callable , stub_sig : Signature , array : Array , data : DataObject
137- ):
138- skip_msg = (
139- f"Signature for { func_name } () is not inspectable "
140- "and is too troublesome to test for otherwise"
187+ if len (params ) == 0 :
188+ func ()
189+ return
190+
191+ uninspectable_msg = (
192+ f"Note { func_name } () is not inspectable so arguments are passed "
193+ "manually to test the signature."
141194 )
142- if func_name in [
143- # 0d shapes
144- "__bool__" ,
145- "__int__" ,
146- "__index__" ,
147- "__float__" ,
148- # x2 elements must be >=0
149- "pow" ,
150- "bitwise_left_shift" ,
151- "bitwise_right_shift" ,
152- # axis default invalid with 0d shapes
153- "sort" ,
154- # shape requirements
155- * matrixy_names ,
156- ]:
157- pytest .skip (skip_msg )
158-
159- param_to_value : Dict [Parameter , Any ] = {}
160- for param in stub_sig .parameters .values ():
161- if param .kind in [Parameter .POSITIONAL_OR_KEYWORD , * VAR_KINDS ]:
195+
196+ argname_to_arg = func_to_specified_args [func_name ]
197+ argname_to_expr = func_to_specified_arg_exprs [func_name ]
198+ for argname , expr in argname_to_expr .items ():
199+ assert argname not in argname_to_arg .keys () # sanity check
200+ try :
201+ argname_to_arg [argname ] = eval (expr , {"xp" : xp })
202+ except Exception as e :
162203 pytest .skip (
163- skip_msg + f" (because '{ param .name } ' is a { kind_to_str [param .kind ]} )"
164- )
165- elif param .default != Parameter .empty :
166- value = param .default
167- elif param .name in ["x" , "x1" ]:
168- dtypes = get_dtypes_strategy (func_name )
169- value = data .draw (
170- xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = param .name
204+ f"Exception occured when evaluating { argname } ={ expr } : { e } \n "
205+ f"{ uninspectable_msg } "
171206 )
172- elif param .name in ["x2" , "other" ]:
173- if param .name == "x2" :
174- assert "x1" in [p .name for p in param_to_value .keys ()] # sanity check
175- orig = next (v for p , v in param_to_value .items () if p .name == "x1" )
207+
208+ posargs = []
209+ posorkw_args = {}
210+ kwargs = {}
211+ no_arg_msg = (
212+ "We have no argument specified for '{}'. Please ensure you're using "
213+ "the latest version of array-api-tests, then open an issue if one "
214+ f"doesn't already exist. { uninspectable_msg } "
215+ )
216+ for param in params :
217+ if param .kind == Parameter .POSITIONAL_ONLY :
218+ try :
219+ posargs .append (argname_to_arg [param .name ])
220+ except KeyError :
221+ pytest .skip (no_arg_msg .format (param .name ))
222+ elif param .kind == Parameter .POSITIONAL_OR_KEYWORD :
223+ if param .default == Parameter .empty :
224+ try :
225+ posorkw_args [param .name ] = argname_to_arg [param .name ]
226+ except KeyError :
227+ pytest .skip (no_arg_msg .format (param .name ))
176228 else :
177- assert array is not None # sanity check
178- orig = array
179- value = data . draw (
180- xps . arrays ( dtype = orig . dtype , shape = orig . shape ), label = param . name
181- )
229+ assert argname_to_arg [ param . name ]
230+ posorkw_args [ param . name ] = param . default
231+ elif param . kind == Parameter . KEYWORD_ONLY :
232+ assert param . default != Parameter . empty # sanity check
233+ kwargs [ param . name ] = param . default
182234 else :
183- pytest .skip (
184- skip_msg + f" (because no default was found for argument { param .name } )"
185- )
186- param_to_value [param ] = value
187-
188- args : List [Any ] = [
189- v for p , v in param_to_value .items () if p .kind == Parameter .POSITIONAL_ONLY
190- ]
191- kwargs : Dict [str , Any ] = {
192- p .name : v for p , v in param_to_value .items () if p .kind == Parameter .KEYWORD_ONLY
193- }
194- f_func = make_pretty_func (func_name , * args , ** kwargs )
195- note (f"trying { f_func } " )
196- func (* args , ** kwargs )
235+ assert param .kind in VAR_KINDS # sanity check
236+ pytest .skip (no_arg_msg .format (param .name ))
237+ if len (posorkw_args ) == 0 :
238+ func (* posargs , ** kwargs )
239+ else :
240+ func (* posargs , ** posorkw_args , ** kwargs )
241+ # TODO: test all positional and keyword permutations of pos-or-kw args
197242
198243
199- def _test_func_signature (func : Callable , stub : FunctionType , array = None ):
244+ def _test_func_signature (func : Callable , stub : FunctionType , is_method = False ):
200245 stub_sig = signature (stub )
201246 # If testing against array, ignore 'self' arg in stub as it won't be present
202247 # in func (which should be a method).
203- if array is not None :
248+ if is_method :
204249 stub_params = list (stub_sig .parameters .values ())
205- del stub_params [0 ]
250+ if stub_params [0 ].name == "self" :
251+ del stub_params [0 ]
206252 stub_sig = Signature (
207253 parameters = stub_params , return_annotation = stub_sig .return_annotation
208254 )
209255
210256 try :
211257 sig = signature (func )
212- _test_inspectable_func (sig , stub_sig )
213258 except ValueError :
214- _test_uninspectable_func (stub .__name__ , func , stub_sig , array )
259+ try :
260+ _test_uninspectable_func (stub .__name__ , func , stub_sig )
261+ except Exception as e :
262+ raise e from None # suppress parent exception for cleaner pytest output
263+ else :
264+ _test_inspectable_func (sig , stub_sig )
215265
216266
217267@pytest .mark .parametrize (
@@ -245,11 +295,12 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
245295
246296
247297@pytest .mark .parametrize ("stub" , array_methods , ids = lambda f : f .__name__ )
248- @given (st .data ())
249- @settings (max_examples = 1 )
250- def test_array_method_signature (stub : FunctionType , data : DataObject ):
251- dtypes = get_dtypes_strategy (stub .__name__ )
252- x = data .draw (xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = "x" )
298+ def test_array_method_signature (stub : FunctionType ):
299+ x_expr = func_to_specified_arg_exprs [stub .__name__ ]["self" ]
300+ try :
301+ x = eval (x_expr , {"xp" : xp })
302+ except Exception as e :
303+ pytest .skip (f"Exception occured when evaluating x={ x_expr } : { e } " )
253304 assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
254305 method = getattr (x , stub .__name__ )
255- _test_func_signature (method , stub , array = x )
306+ _test_func_signature (method , stub , is_method = True )
0 commit comments