@@ -130,44 +130,15 @@ def test_prod(x, data):
130130 out = xp .prod (x , ** kw )
131131
132132 dtype = kw .get ("dtype" , None )
133- if dtype is None :
134- if dh .is_int_dtype (x .dtype ):
135- if x .dtype in dh .uint_dtypes :
136- default_dtype = dh .default_uint
137- else :
138- default_dtype = dh .default_int
139- if default_dtype is None :
140- _dtype = None
141- else :
142- m , M = dh .dtype_ranges [x .dtype ]
143- d_m , d_M = dh .dtype_ranges [default_dtype ]
144- if m < d_m or M > d_M :
145- _dtype = x .dtype
146- else :
147- _dtype = default_dtype
148- elif dh .is_float_dtype (x .dtype , include_complex = False ):
149- if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
150- _dtype = x .dtype
151- else :
152- _dtype = dh .default_float
153- elif api_version > "2021.12" :
154- # Complex dtype
155- if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_complex ]:
156- _dtype = x .dtype
157- else :
158- _dtype = dh .default_complex
159- else :
160- raise RuntimeError ("Unexpected dtype. This indicates a bug in the test suite." )
161- else :
162- _dtype = dtype
163- if _dtype is None :
133+ expected_dtype = dh .accumulation_result_dtype (x .dtype , dtype )
134+ if expected_dtype is None :
164135 # If a default uint cannot exist (i.e. in PyTorch which doesn't support
165136 # uint32 or uint64), we skip testing the output dtype.
166137 # See https://github.com/data-apis/array-api-tests/issues/106
167138 if x .dtype in dh .uint_dtypes :
168139 assert dh .is_int_dtype (out .dtype ) # sanity check
169140 else :
170- ph .assert_dtype ("prod" , in_dtype = x .dtype , out_dtype = out .dtype , expected = _dtype )
141+ ph .assert_dtype ("prod" , in_dtype = x .dtype , out_dtype = out .dtype , expected = expected_dtype )
171142 _axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
172143 ph .assert_keepdimable_shape (
173144 "prod" , in_shape = x .shape , out_shape = out .shape , axes = _axes , keepdims = keepdims , kw = kw
@@ -246,44 +217,15 @@ def test_sum(x, data):
246217 out = xp .sum (x , ** kw )
247218
248219 dtype = kw .get ("dtype" , None )
249- if dtype is None :
250- if dh .is_int_dtype (x .dtype ):
251- if x .dtype in dh .uint_dtypes :
252- default_dtype = dh .default_uint
253- else :
254- default_dtype = dh .default_int
255- if default_dtype is None :
256- _dtype = None
257- else :
258- m , M = dh .dtype_ranges [x .dtype ]
259- d_m , d_M = dh .dtype_ranges [default_dtype ]
260- if m < d_m or M > d_M :
261- _dtype = x .dtype
262- else :
263- _dtype = default_dtype
264- elif dh .is_float_dtype (x .dtype , include_complex = False ):
265- if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
266- _dtype = x .dtype
267- else :
268- _dtype = dh .default_float
269- elif api_version > "2021.12" :
270- # Complex dtype
271- if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_complex ]:
272- _dtype = x .dtype
273- else :
274- _dtype = dh .default_complex
275- else :
276- raise RuntimeError ("Unexpected dtype. This indicates a bug in the test suite." )
277- else :
278- _dtype = dtype
279- if _dtype is None :
220+ expected_dtype = dh .accumulation_result_dtype (x .dtype , dtype )
221+ if expected_dtype is None :
280222 # If a default uint cannot exist (i.e. in PyTorch which doesn't support
281223 # uint32 or uint64), we skip testing the output dtype.
282224 # See https://github.com/data-apis/array-api-tests/issues/160
283225 if x .dtype in dh .uint_dtypes :
284226 assert dh .is_int_dtype (out .dtype ) # sanity check
285227 else :
286- ph .assert_dtype ("sum" , in_dtype = x .dtype , out_dtype = out .dtype , expected = _dtype )
228+ ph .assert_dtype ("sum" , in_dtype = x .dtype , out_dtype = out .dtype , expected = expected_dtype )
287229 _axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
288230 ph .assert_keepdimable_shape (
289231 "sum" , in_shape = x .shape , out_shape = out .shape , axes = _axes , keepdims = keepdims , kw = kw
0 commit comments