11import re
22import itertools
33from contextlib import contextmanager
4- from functools import reduce
5- from math import sqrt
4+ from functools import reduce , wraps
5+ import math
66from operator import mul
7- from typing import Any , List , NamedTuple , Optional , Sequence , Tuple , Union
7+ import struct
8+ from typing import Any , List , Mapping , NamedTuple , Optional , Sequence , Tuple , Union
89
910from hypothesis import assume , reject
1011from hypothesis .strategies import (SearchStrategy , booleans , composite , floats ,
1112 integers , just , lists , none , one_of ,
12- sampled_from , shared )
13+ sampled_from , shared , builds )
1314
1415from . import _array_module as xp , api_version
1516from . import dtype_helpers as dh
2021from ._array_module import broadcast_to , eye , float32 , float64 , full
2122from .stubs import category_to_funcs
2223from .pytest_helpers import nargs
23- from .typing import Array , DataType , Shape
24-
25- # Set this to True to not fail tests just because a dtype isn't implemented.
26- # If no compatible dtype is implemented for a given test, the test will fail
27- # with a hypothesis health check error. Note that this functionality will not
28- # work for floating point dtypes as those are assumed to be defined in other
29- # places in the tests.
30- FILTER_UNDEFINED_DTYPES = True
31- # TODO: currently we assume this to be true - we probably can remove this completely
32- assert FILTER_UNDEFINED_DTYPES
33-
34- integer_dtypes = xps .integer_dtypes () | xps .unsigned_integer_dtypes ()
35- floating_dtypes = xps .floating_dtypes ()
36- numeric_dtypes = xps .numeric_dtypes ()
37- integer_or_boolean_dtypes = xps .boolean_dtypes () | integer_dtypes
38- boolean_dtypes = xps .boolean_dtypes ()
39- dtypes = xps .scalar_dtypes ()
40-
41- shared_dtypes = shared (dtypes , key = "dtype" )
42- shared_floating_dtypes = shared (floating_dtypes , key = "dtype" )
24+ from .typing import Array , DataType , Scalar , Shape
25+
26+
27+ def _float32ify (n : Union [int , float ]) -> float :
28+ n = float (n )
29+ return struct .unpack ("!f" , struct .pack ("!f" , n ))[0 ]
30+
31+
32+ @wraps (xps .from_dtype )
33+ def from_dtype (dtype , ** kwargs ) -> SearchStrategy [Scalar ]:
34+ """xps.from_dtype() without the crazy large numbers."""
35+ if dtype == xp .bool :
36+ return xps .from_dtype (dtype , ** kwargs )
37+
38+ if dtype in dh .complex_dtypes :
39+ component_dtype = dh .dtype_components [dtype ]
40+ else :
41+ component_dtype = dtype
42+
43+ min_ , max_ = dh .dtype_ranges [component_dtype ]
44+
45+ if "min_value" not in kwargs .keys () and min_ != 0 :
46+ assert min_ < 0 # sanity check
47+ min_value = - 1 * math .floor (math .sqrt (abs (min_ )))
48+ if component_dtype == xp .float32 :
49+ min_value = _float32ify (min_value )
50+ kwargs ["min_value" ] = min_value
51+ if "max_value" not in kwargs .keys ():
52+ assert max_ > 0 # sanity check
53+ max_value = math .floor (math .sqrt (max_ ))
54+ if component_dtype == xp .float32 :
55+ max_value = _float32ify (max_value )
56+ kwargs ["max_value" ] = max_value
57+
58+ if dtype in dh .complex_dtypes :
59+ component_strat = xps .from_dtype (dh .dtype_components [dtype ], ** kwargs )
60+ return builds (complex , component_strat , component_strat )
61+ else :
62+ return xps .from_dtype (dtype , ** kwargs )
63+
64+
65+ @wraps (xps .arrays )
66+ def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
67+ """xps.arrays() without the crazy large numbers."""
68+ if isinstance (dtype , SearchStrategy ):
69+ return dtype .flatmap (lambda d : arrays (d , * args , elements = elements , ** kwargs ))
70+
71+ if elements is None :
72+ elements = from_dtype (dtype )
73+ elif isinstance (elements , Mapping ):
74+ elements = from_dtype (dtype , ** elements )
75+
76+ return xps .arrays (dtype , * args , elements = elements , ** kwargs )
77+
4378
4479_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
4580_sorted_dtypes = [d for category in _dtype_categories for d in category ]
@@ -62,21 +97,19 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
6297 return key
6398
6499_promotable_dtypes = list (dh .promotion_table .keys ())
65- if FILTER_UNDEFINED_DTYPES :
66- _promotable_dtypes = [
67- (d1 , d2 ) for d1 , d2 in _promotable_dtypes
68- if not isinstance (d1 , _UndefinedStub ) or not isinstance (d2 , _UndefinedStub )
69- ]
100+ _promotable_dtypes = [
101+ (d1 , d2 ) for d1 , d2 in _promotable_dtypes
102+ if not isinstance (d1 , _UndefinedStub ) or not isinstance (d2 , _UndefinedStub )
103+ ]
70104promotable_dtypes : List [Tuple [DataType , DataType ]] = sorted (_promotable_dtypes , key = _dtypes_sorter )
71105
72106def mutually_promotable_dtypes (
73107 max_size : Optional [int ] = 2 ,
74108 * ,
75109 dtypes : Sequence [DataType ] = dh .all_dtypes ,
76110) -> SearchStrategy [Tuple [DataType , ...]]:
77- if FILTER_UNDEFINED_DTYPES :
78- dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
79- assert len (dtypes ) > 0 , "all dtypes undefined" # sanity check
111+ dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
112+ assert len (dtypes ) > 0 , "all dtypes undefined" # sanity check
80113 if max_size == 2 :
81114 return sampled_from (
82115 [(i , j ) for i , j in promotable_dtypes if i in dtypes and j in dtypes ]
@@ -166,7 +199,7 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
166199# Limit the total size of an array shape
167200MAX_ARRAY_SIZE = 10000
168201# Size to use for 2-dim arrays
169- SQRT_MAX_ARRAY_SIZE = int (sqrt (MAX_ARRAY_SIZE ))
202+ SQRT_MAX_ARRAY_SIZE = int (math . sqrt (MAX_ARRAY_SIZE ))
170203
171204# np.prod and others have overflow and math.prod is Python 3.8+ only
172205def prod (seq ):
@@ -202,7 +235,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
202235
203236@composite
204237def finite_matrices (draw , shape = matrix_shapes ()):
205- return draw (xps . arrays (dtype = xps .floating_dtypes (),
238+ return draw (arrays (dtype = xps .floating_dtypes (),
206239 shape = shape ,
207240 elements = dict (allow_nan = False ,
208241 allow_infinity = False )))
@@ -211,7 +244,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
211244# Should we set a max_value here?
212245_rtol_float_kw = dict (allow_nan = False , allow_infinity = False , min_value = 0 )
213246rtols = one_of (floats (** _rtol_float_kw ),
214- xps . arrays (dtype = xps .floating_dtypes (),
247+ arrays (dtype = xps .floating_dtypes (),
215248 shape = rtol_shared_matrix_shapes .map (lambda shape : shape [:- 2 ]),
216249 elements = _rtol_float_kw ))
217250
@@ -254,7 +287,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
254287 if not isinstance (finite , bool ):
255288 finite = draw (finite )
256289 elements = {'allow_nan' : False , 'allow_infinity' : False } if finite else None
257- a = draw (xps . arrays (dtype = dtype , shape = shape , elements = elements ))
290+ a = draw (arrays (dtype = dtype , shape = shape , elements = elements ))
258291 upper = xp .triu (a )
259292 lower = xp .triu (a , k = 1 ).mT
260293 return upper + lower
@@ -277,7 +310,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
277310 n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ),)
278311 stack_shape = draw (stack_shapes )
279312 shape = stack_shape + (n , n )
280- d = draw (xps . arrays (dtypes , shape = n * prod (stack_shape ),
313+ d = draw (arrays (dtypes , shape = n * prod (stack_shape ),
281314 elements = dict (allow_nan = False , allow_infinity = False )))
282315 # Functions that require invertible matrices may do anything when it is
283316 # singular, including raising an exception, so we make sure the diagonals
@@ -303,7 +336,7 @@ def two_broadcastable_shapes(draw):
303336sizes = integers (0 , MAX_ARRAY_SIZE )
304337sqrt_sizes = integers (0 , SQRT_MAX_ARRAY_SIZE )
305338
306- numeric_arrays = xps . arrays (
339+ numeric_arrays = arrays (
307340 dtype = shared (xps .floating_dtypes (), key = 'dtypes' ),
308341 shape = shared (xps .array_shapes (), key = 'shapes' ),
309342)
@@ -348,7 +381,7 @@ def python_integer_indices(draw, sizes):
348381def integer_indices (draw , sizes ):
349382 # Return either a Python integer or a 0-D array with some integer dtype
350383 idx = draw (python_integer_indices (sizes ))
351- dtype = draw (integer_dtypes )
384+ dtype = draw (xps . integer_dtypes () | xps . unsigned_integer_dtypes () )
352385 m , M = dh .dtype_ranges [dtype ]
353386 if m <= idx <= M :
354387 return draw (one_of (just (idx ),
@@ -424,16 +457,15 @@ def two_mutual_arrays(
424457) -> Tuple [SearchStrategy [Array ], SearchStrategy [Array ]]:
425458 if not isinstance (dtypes , Sequence ):
426459 raise TypeError (f"{ dtypes = } not a sequence" )
427- if FILTER_UNDEFINED_DTYPES :
428- dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
429- assert len (dtypes ) > 0 # sanity check
460+ dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
461+ assert len (dtypes ) > 0 # sanity check
430462 mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtypes ))
431463 mutual_shapes = shared (two_shapes )
432- arrays1 = xps . arrays (
464+ arrays1 = arrays (
433465 dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
434466 shape = mutual_shapes .map (lambda pair : pair [0 ]),
435467 )
436- arrays2 = xps . arrays (
468+ arrays2 = arrays (
437469 dtype = mutual_dtypes .map (lambda pair : pair [1 ]),
438470 shape = mutual_shapes .map (lambda pair : pair [1 ]),
439471 )
0 commit comments