11"""
22https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
33"""
4- import math
54from collections import defaultdict
6- from typing import Tuple , Union , List
5+ from typing import List , Tuple , Union
76
87import pytest
98from hypothesis import assume , given , reject
1413from . import hypothesis_helpers as hh
1514from . import pytest_helpers as ph
1615from . import xps
17- from .typing import DataType , ScalarType , Param
1816from .function_stubs import elementwise_functions
19-
17+ from . typing import DataType , Param , ScalarType
2018
2119# TODO: move tests not covering elementwise funcs/ops into standalone tests
2220# result_type, meshgrid, tensordor, vecdot
@@ -28,29 +26,6 @@ def test_result_type(dtypes):
2826 ph .assert_dtype ("result_type" , dtypes , out , repr_name = "out" )
2927
3028
31- # The number and size of generated arrays is arbitrarily limited to prevent
32- # meshgrid() running out of memory.
33- @given (
34- dtypes = hh .mutually_promotable_dtypes (5 , dtypes = dh .numeric_dtypes ),
35- data = st .data (),
36- )
37- def test_meshgrid (dtypes , data ):
38- arrays = []
39- shapes = data .draw (
40- hh .mutually_broadcastable_shapes (
41- len (dtypes ), min_dims = 1 , max_dims = 1 , max_side = 5
42- ),
43- label = "shapes" ,
44- )
45- for i , (dtype , shape ) in enumerate (zip (dtypes , shapes ), 1 ):
46- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f"x{ i } " )
47- arrays .append (x )
48- assert math .prod (x .size for x in arrays ) <= hh .MAX_ARRAY_SIZE # sanity check
49- out = xp .meshgrid (* arrays )
50- for i , x in enumerate (out ):
51- ph .assert_dtype ("meshgrid" , dtypes , x .dtype , repr_name = f"out[{ i } ].dtype" )
52-
53-
5429bitwise_shift_funcs = [
5530 "bitwise_left_shift" ,
5631 "bitwise_right_shift" ,
@@ -78,6 +53,14 @@ def make_id(
7853 return f"{ func_name } ({ f_args } ) -> { f_out_dtype } "
7954
8055
56+ def mark_stubbed_dtypes (* dtypes ):
57+ for dtype in dtypes :
58+ if isinstance (dtype , xp ._UndefinedStub ):
59+ return pytest .mark .skip (reason = f"xp.{ dtype .name } not defined" )
60+ else :
61+ return ()
62+
63+
8164func_params : List [Param [str , Tuple [DataType , ...], DataType ]] = []
8265for func_name in elementwise_functions .__all__ :
8366 valid_in_dtypes = dh .func_in_dtypes [func_name ]
@@ -90,6 +73,7 @@ def make_id(
9073 (in_dtype ,),
9174 out_dtype ,
9275 id = make_id (func_name , (in_dtype ,), out_dtype ),
76+ marks = mark_stubbed_dtypes (in_dtype , out_dtype ),
9377 )
9478 func_params .append (p )
9579 elif ndtypes == 2 :
@@ -103,6 +87,7 @@ def make_id(
10387 (in_dtype1 , in_dtype2 ),
10488 out_dtype ,
10589 id = make_id (func_name , (in_dtype1 , in_dtype2 ), out_dtype ),
90+ marks = mark_stubbed_dtypes (in_dtype1 , in_dtype2 , out_dtype ),
10691 )
10792 func_params .append (p )
10893 else :
@@ -143,6 +128,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
143128 (dtype1 , dtype2 ),
144129 promoted_dtype ,
145130 id = make_id ("" , (dtype1 , dtype2 ), promoted_dtype ),
131+ marks = mark_stubbed_dtypes (dtype1 , dtype2 , promoted_dtype ),
146132 )
147133 promotion_params .append (p )
148134
@@ -194,6 +180,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
194180 (in_dtype ,),
195181 out_dtype ,
196182 id = make_id (op , (in_dtype ,), out_dtype ),
183+ marks = mark_stubbed_dtypes (in_dtype , out_dtype ),
197184 )
198185 op_params .append (p )
199186 else :
@@ -206,6 +193,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
206193 (in_dtype1 , in_dtype2 ),
207194 out_dtype ,
208195 id = make_id (op , (in_dtype1 , in_dtype2 ), out_dtype ),
196+ marks = mark_stubbed_dtypes (in_dtype1 , in_dtype2 , out_dtype ),
209197 )
210198 op_params .append (p )
211199# We generate params for abs seperately as it does not have an associated symbol
@@ -216,6 +204,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
216204 (in_dtype ,),
217205 in_dtype ,
218206 id = make_id ("__abs__" , (in_dtype ,), in_dtype ),
207+ marks = mark_stubbed_dtypes (in_dtype ),
219208 )
220209 op_params .append (p )
221210
@@ -263,6 +252,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
263252 (in_dtype1 , in_dtype2 ),
264253 promoted_dtype ,
265254 id = make_id (op , (in_dtype1 , in_dtype2 ), promoted_dtype ),
255+ marks = mark_stubbed_dtypes (in_dtype1 , in_dtype2 , promoted_dtype ),
266256 )
267257 inplace_params .append (p )
268258
@@ -301,6 +291,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
301291 in_stype ,
302292 out_dtype ,
303293 id = make_id (op , (in_dtype , in_stype ), out_dtype ),
294+ marks = mark_stubbed_dtypes (in_dtype , out_dtype ),
304295 )
305296 op_scalar_params .append (p )
306297
@@ -333,6 +324,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
333324 dtype ,
334325 in_stype ,
335326 id = make_id (op , (dtype , in_stype ), dtype ),
327+ marks = mark_stubbed_dtypes (dtype ),
336328 )
337329 inplace_scalar_params .append (p )
338330
0 commit comments