11import inspect
2+ from itertools import chain
23
34import pytest
45
56from ._array_module import mod , mod_name , ones , eye , float64 , bool , int64 , _UndefinedStub
67from .pytest_helpers import raises , doesnt_raise
78from . import dtype_helpers as dh
89
9- from . import function_stubs
1010from . import stubs
1111
1212
13- def stub_module (name ):
14- for m in stubs .extensions :
15- if name in getattr (function_stubs , m ).__all__ :
16- return m
13+ def extension_module (name ) -> bool :
14+ for funcs in stubs .extension_to_funcs .values ():
15+ for func in funcs :
16+ if name == func .__name__ :
17+ return True
18+ else :
19+ return False
1720
18- def extension_module (name ):
19- return name in stubs .extensions and name in function_stubs .__all__
2021
21- extension_module_names = []
22- for n in function_stubs .__all__ :
23- if extension_module (n ):
24- extension_module_names .extend ([f'{ n } .{ i } ' for i in getattr (function_stubs , n ).__all__ ])
22+ params = []
23+ for name in [f .__name__ for funcs in stubs .category_to_funcs .values () for f in funcs ]:
24+ if name in ["where" , "expand_dims" , "reshape" ]:
25+ params .append (pytest .param (name , marks = pytest .mark .skip (reason = "faulty test" )))
26+ else :
27+ params .append (name )
2528
2629
27- params = []
28- for name in function_stubs .__all__ :
29- marks = []
30- if extension_module (name ):
31- marks .append (pytest .mark .xp_extension (name ))
32- params .append (pytest .param (name , marks = marks ))
33- for name in extension_module_names :
34- ext = name .split ('.' )[0 ]
35- mark = pytest .mark .xp_extension (ext )
36- params .append (pytest .param (name , marks = [mark ]))
30+ for ext , name in [(ext , f .__name__ ) for ext , funcs in stubs .extension_to_funcs .items () for f in funcs ]:
31+ params .append (pytest .param (name , marks = pytest .mark .xp_extension (ext )))
3732
3833
39- def array_method (name ):
40- return stub_module ( name ) == 'array_object'
34+ def array_method (name ) -> bool :
35+ return name in [ f . __name__ for f in stubs . array_methods ]
4136
42- def function_category (name ):
43- return stub_module (name ).rsplit ('_' , 1 )[0 ].replace ('_' , ' ' )
37+ def function_category (name ) -> str :
38+ for category , funcs in chain (stubs .category_to_funcs .items (), stubs .extension_to_funcs .items ()):
39+ for func in funcs :
40+ if name == func .__name__ :
41+ return category
4442
4543def example_argument (arg , func_name , dtype ):
4644 """
@@ -138,7 +136,7 @@ def example_argument(arg, func_name, dtype):
138136 return ones ((3 ,), dtype = dtype )
139137 # Linear algebra functions tend to error if the input isn't "nice" as
140138 # a matrix
141- elif arg .startswith ('x' ) and func_name in function_stubs . linalg . __all__ :
139+ elif arg .startswith ('x' ) and func_name in [ f . __name__ for f in stubs . extension_to_funcs [ "linalg" ]] :
142140 return eye (3 )
143141 return known_args [arg ]
144142 else :
@@ -147,13 +145,15 @@ def example_argument(arg, func_name, dtype):
147145@pytest .mark .parametrize ('name' , params )
148146def test_has_names (name ):
149147 if extension_module (name ):
150- assert hasattr (mod , name ), f'{ mod_name } is missing the { name } extension'
151- elif '.' in name :
152- extension_mod , name = name .split ('.' )
153- assert hasattr (getattr (mod , extension_mod ), name ), f"{ mod_name } is missing the { function_category (name )} extension function { name } ()"
148+ ext = next (
149+ ext for ext , funcs in stubs .extension_to_funcs .items ()
150+ if name in [f .__name__ for f in funcs ]
151+ )
152+ ext_mod = getattr (mod , ext )
153+ assert hasattr (ext_mod , name ), f"{ mod_name } is missing the { function_category (name )} extension function { name } ()"
154154 elif array_method (name ):
155155 arr = ones ((1 , 1 ))
156- if getattr ( function_stubs . array_object , name ) is None :
156+ if name not in [ f . __name__ for f in stubs . array_methods ] :
157157 assert hasattr (arr , name ), f"The array object is missing the attribute { name } "
158158 else :
159159 assert hasattr (arr , name ), f"The array object is missing the method { name } ()"
@@ -192,14 +192,12 @@ def test_function_positional_args(name):
192192 _mod = ones ((), dtype = float64 )
193193 else :
194194 _mod = example_argument ('self' , name , dtype )
195- stub_func = getattr (function_stubs , name )
196195 elif '.' in name :
197196 extension_module_name , name = name .split ('.' )
198197 _mod = getattr (mod , extension_module_name )
199- stub_func = getattr (getattr (function_stubs , extension_module_name ), name )
200198 else :
201199 _mod = mod
202- stub_func = getattr ( function_stubs , name )
200+ stub_func = stubs . name_to_func [ name ]
203201
204202 if not hasattr (_mod , name ):
205203 pytest .skip (f"{ mod_name } does not have { name } (), skipping." )
@@ -245,14 +243,12 @@ def test_function_keyword_only_args(name):
245243
246244 if array_method (name ):
247245 _mod = ones ((1 , 1 ))
248- stub_func = getattr (function_stubs , name )
249246 elif '.' in name :
250247 extension_module_name , name = name .split ('.' )
251248 _mod = getattr (mod , extension_module_name )
252- stub_func = getattr (getattr (function_stubs , extension_module_name ), name )
253249 else :
254250 _mod = mod
255- stub_func = getattr ( function_stubs , name )
251+ stub_func = stubs . name_to_func [ name ]
256252
257253 if not hasattr (_mod , name ):
258254 pytest .skip (f"{ mod_name } does not have { name } (), skipping." )
0 commit comments