@@ -136,24 +136,27 @@ def test_prod(x, data):
136136 default_dtype = dh .default_uint
137137 else :
138138 default_dtype = dh .default_int
139- m , M = dh .dtype_ranges [x .dtype ]
140- d_m , d_M = dh .dtype_ranges [default_dtype ]
141- if m < d_m or M > d_M :
142- _dtype = x .dtype
139+ if default_dtype is None :
140+ _dtype = None
143141 else :
144- _dtype = default_dtype
142+ m , M = dh .dtype_ranges [x .dtype ]
143+ d_m , d_M = dh .dtype_ranges [default_dtype ]
144+ if m < d_m or M > d_M :
145+ _dtype = x .dtype
146+ else :
147+ _dtype = default_dtype
145148 else :
146149 if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
147150 _dtype = x .dtype
148151 else :
149152 _dtype = dh .default_float
150153 else :
151154 _dtype = dtype
152- if isinstance ( _dtype , _UndefinedStub ) :
155+ if _dtype is None :
153156 # If a default uint cannot exist (i.e. in PyTorch which doesn't support
154157 # uint32 or uint64), we skip testing the output dtype.
155158 # See https://github.com/data-apis/array-api-tests/issues/106
156- if _dtype in dh .uint_dtypes :
159+ if x . dtype in dh .uint_dtypes :
157160 assert dh .is_int_dtype (out .dtype ) # sanity check
158161 else :
159162 ph .assert_dtype ("prod" , in_dtype = x .dtype , out_dtype = out .dtype , expected = _dtype )
@@ -241,24 +244,27 @@ def test_sum(x, data):
241244 default_dtype = dh .default_uint
242245 else :
243246 default_dtype = dh .default_int
244- m , M = dh .dtype_ranges [x .dtype ]
245- d_m , d_M = dh .dtype_ranges [default_dtype ]
246- if m < d_m or M > d_M :
247- _dtype = x .dtype
247+ if default_dtype is None :
248+ _dtype = None
248249 else :
249- _dtype = default_dtype
250+ m , M = dh .dtype_ranges [x .dtype ]
251+ d_m , d_M = dh .dtype_ranges [default_dtype ]
252+ if m < d_m or M > d_M :
253+ _dtype = x .dtype
254+ else :
255+ _dtype = default_dtype
250256 else :
251257 if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
252258 _dtype = x .dtype
253259 else :
254260 _dtype = dh .default_float
255261 else :
256262 _dtype = dtype
257- if isinstance ( _dtype , _UndefinedStub ) :
263+ if _dtype is None :
258264 # If a default uint cannot exist (i.e. in PyTorch which doesn't support
259265 # uint32 or uint64), we skip testing the output dtype.
260266 # See https://github.com/data-apis/array-api-tests/issues/160
261- if _dtype in dh .uint_dtypes :
267+ if x . dtype in dh .uint_dtypes :
262268 assert dh .is_int_dtype (out .dtype ) # sanity check
263269 else :
264270 ph .assert_dtype ("sum" , in_dtype = x .dtype , out_dtype = out .dtype , expected = _dtype )
0 commit comments