@@ -147,9 +147,13 @@ def test_prod(x, data):
147147 _dtype = dh .default_float
148148 else :
149149 _dtype = dtype
150- # We ignore asserting the out dtype if what we expect is undefined
151- # See https://github.com/data-apis/array-api-tests/issues/106
152- if not isinstance (_dtype , _UndefinedStub ):
150+ if isinstance (_dtype , _UndefinedStub ):
151+ # If a default uint cannot exist (i.e. in PyTorch which doesn't support
152+ # uint32 or uint64), we skip testing the output dtype.
153+ # See https://github.com/data-apis/array-api-tests/issues/106
154+ if _dtype in dh .uint_dtypes :
155+ assert dh .is_int_dtype (out .dtype ) # sanity check
156+ else :
153157 ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
154158 _axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
155159 ph .assert_keepdimable_shape (
@@ -248,7 +252,14 @@ def test_sum(x, data):
248252 _dtype = dh .default_float
249253 else :
250254 _dtype = dtype
251- ph .assert_dtype ("sum" , x .dtype , out .dtype , _dtype )
255+ if isinstance (_dtype , _UndefinedStub ):
256+ # If a default uint cannot exist (i.e. in PyTorch which doesn't support
257+ # uint32 or uint64), we skip testing the output dtype.
258+ # See https://github.com/data-apis/array-api-tests/issues/160
259+ if _dtype in dh .uint_dtypes :
260+ assert dh .is_int_dtype (out .dtype ) # sanity check
261+ else :
262+ ph .assert_dtype ("sum" , x .dtype , out .dtype , _dtype )
252263 _axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
253264 ph .assert_keepdimable_shape (
254265 "sum" , x .shape , out .shape , _axes , kw .get ("keepdims" , False ), ** kw
0 commit comments