@@ -174,10 +174,24 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
174174 return OnewayBroadcastableShapes (input_shape , result_shape )
175175
176176
177+ # Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
178+ # ARRAY_API_TESTS_SKIP_DTYPES
179+ all_dtypes = sampled_from (_sorted_dtypes )
180+ int_dtypes = sampled_from (dh .int_dtypes )
181+ uint_dtypes = sampled_from (dh .uint_dtypes )
182+ real_dtypes = sampled_from (dh .real_dtypes )
183+ # Warning: The hypothesis "floating_dtypes" is what we call
184+ # "real_floating_dtypes"
185+ floating_dtypes = sampled_from (dh .all_float_dtypes )
186+ real_floating_dtypes = sampled_from (dh .real_float_dtypes )
187+ numeric_dtypes = sampled_from (dh .numeric_dtypes )
188+ # Note: this always returns complex dtypes, even if api_version < 2022.12
189+ complex_dtypes = sampled_from (dh .complex_dtypes )
190+
177191def all_floating_dtypes () -> SearchStrategy [DataType ]:
178- strat = xps . floating_dtypes ()
192+ strat = floating_dtypes
179193 if api_version >= "2022.12" :
180- strat |= xps . complex_dtypes ()
194+ strat |= complex_dtypes
181195 return strat
182196
183197
@@ -236,7 +250,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
236250
237251@composite
238252def finite_matrices (draw , shape = matrix_shapes ()):
239- return draw (arrays (dtype = xps . floating_dtypes () ,
253+ return draw (arrays (dtype = floating_dtypes ,
240254 shape = shape ,
241255 elements = dict (allow_nan = False ,
242256 allow_infinity = False )))
@@ -245,7 +259,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
245259# Should we set a max_value here?
246260_rtol_float_kw = dict (allow_nan = False , allow_infinity = False , min_value = 0 )
247261rtols = one_of (floats (** _rtol_float_kw ),
248- arrays (dtype = xps . floating_dtypes () ,
262+ arrays (dtype = real_floating_dtypes ,
249263 shape = rtol_shared_matrix_shapes .map (lambda shape : shape [:- 2 ]),
250264 elements = _rtol_float_kw ))
251265
@@ -280,9 +294,9 @@ def mutually_broadcastable_shapes(
280294
281295two_mutually_broadcastable_shapes = mutually_broadcastable_shapes (2 )
282296
283- # Note: This should become hermitian_matrices when complex dtypes are added
297+ # TODO: Add support for complex Hermitian matrices
284298@composite
285- def symmetric_matrices (draw , dtypes = xps . floating_dtypes () , finite = True , bound = 10. ):
299+ def symmetric_matrices (draw , dtypes = real_floating_dtypes , finite = True , bound = 10. ):
286300 shape = draw (square_matrix_shapes )
287301 dtype = draw (dtypes )
288302 if not isinstance (finite , bool ):
@@ -297,7 +311,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10
297311 return H
298312
299313@composite
300- def positive_definite_matrices (draw , dtypes = xps . floating_dtypes () ):
314+ def positive_definite_matrices (draw , dtypes = floating_dtypes ):
301315 # For now just generate stacks of identity matrices
302316 # TODO: Generate arbitrary positive definite matrices, for instance, by
303317 # using something like
@@ -310,7 +324,7 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
310324 return broadcast_to (eye (n , dtype = dtype ), shape )
311325
312326@composite
313- def invertible_matrices (draw , dtypes = xps . floating_dtypes () , stack_shapes = shapes ()):
327+ def invertible_matrices (draw , dtypes = floating_dtypes , stack_shapes = shapes ()):
314328 # For now, just generate stacks of diagonal matrices.
315329 stack_shape = draw (stack_shapes )
316330 n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE // max (math .prod (stack_shape ), 1 )),)
@@ -344,7 +358,7 @@ def two_broadcastable_shapes(draw):
344358sqrt_sizes = integers (0 , SQRT_MAX_ARRAY_SIZE )
345359
346360numeric_arrays = arrays (
347- dtype = shared (xps . floating_dtypes () , key = 'dtypes' ),
361+ dtype = shared (floating_dtypes , key = 'dtypes' ),
348362 shape = shared (xps .array_shapes (), key = 'shapes' ),
349363)
350364
@@ -388,7 +402,7 @@ def python_integer_indices(draw, sizes):
388402def integer_indices (draw , sizes ):
389403 # Return either a Python integer or a 0-D array with some integer dtype
390404 idx = draw (python_integer_indices (sizes ))
391- dtype = draw (xps . integer_dtypes () | xps . unsigned_integer_dtypes () )
405+ dtype = draw (int_dtypes | uint_dtypes )
392406 m , M = dh .dtype_ranges [dtype ]
393407 if m <= idx <= M :
394408 return draw (one_of (just (idx ),
0 commit comments