11import re
22import itertools
33from contextlib import contextmanager
4- from functools import reduce
5- from math import sqrt
4+ from functools import reduce , wraps
5+ import math
66from operator import mul
7- from typing import Any , List , NamedTuple , Optional , Sequence , Tuple , Union
7+ import struct
8+ from typing import Any , List , Mapping , NamedTuple , Optional , Sequence , Tuple , Union
89
910from hypothesis import assume , reject
1011from hypothesis .strategies import (SearchStrategy , booleans , composite , floats ,
1112 integers , just , lists , none , one_of ,
12- sampled_from , shared )
13+ sampled_from , shared , builds )
1314
1415from . import _array_module as xp , api_version
1516from . import dtype_helpers as dh
2021from ._array_module import broadcast_to , eye , float32 , float64 , full
2122from .stubs import category_to_funcs
2223from .pytest_helpers import nargs
23- from .typing import Array , DataType , Shape
24+ from .typing import Array , DataType , Scalar , Shape
25+
26+
27+ def _float32ify (n : Union [int , float ]) -> float :
28+ n = float (n )
29+ return struct .unpack ("!f" , struct .pack ("!f" , n ))[0 ]
30+
31+
32+ @wraps (xps .from_dtype )
33+ def from_dtype (dtype , ** kwargs ) -> SearchStrategy [Scalar ]:
34+ """xps.from_dtype() without the crazy large numbers."""
35+ if dtype == xp .bool :
36+ return xps .from_dtype (dtype , ** kwargs )
37+
38+ if dtype in dh .complex_dtypes :
39+ component_dtype = dh .dtype_components [dtype ]
40+ else :
41+ component_dtype = dtype
42+
43+ min_ , max_ = dh .dtype_ranges [component_dtype ]
44+
45+ if "min_value" not in kwargs .keys () and min_ != 0 :
46+ assert min_ < 0 # sanity check
47+ min_value = - 1 * math .floor (math .sqrt (abs (min_ )))
48+ if component_dtype == xp .float32 :
49+ min_value = _float32ify (min_value )
50+ kwargs ["min_value" ] = min_value
51+ if "max_value" not in kwargs .keys ():
52+ assert max_ > 0 # sanity check
53+ max_value = math .floor (math .sqrt (max_ ))
54+ if component_dtype == xp .float32 :
55+ max_value = _float32ify (max_value )
56+ kwargs ["max_value" ] = max_value
57+
58+ if dtype in dh .complex_dtypes :
59+ component_strat = xps .from_dtype (dh .dtype_components [dtype ], ** kwargs )
60+ return builds (complex , component_strat , component_strat )
61+ else :
62+ return xps .from_dtype (dtype , ** kwargs )
63+
64+
65+ @wraps (xps .arrays )
66+ def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
67+ """xps.arrays() without the crazy large numbers."""
68+ if isinstance (dtype , SearchStrategy ):
69+ return dtype .flatmap (lambda d : arrays (d , * args , elements = elements , ** kwargs ))
70+
71+ if elements is None :
72+ elements = from_dtype (dtype )
73+ elif isinstance (elements , Mapping ):
74+ elements = from_dtype (dtype , ** elements )
75+
76+ return xps .arrays (dtype , * args , elements = elements , ** kwargs )
77+
2478
2579_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
2680_sorted_dtypes = [d for category in _dtype_categories for d in category ]
@@ -145,7 +199,7 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
145199# Limit the total size of an array shape
146200MAX_ARRAY_SIZE = 10000
147201# Size to use for 2-dim arrays
148- SQRT_MAX_ARRAY_SIZE = int (sqrt (MAX_ARRAY_SIZE ))
202+ SQRT_MAX_ARRAY_SIZE = int (math . sqrt (MAX_ARRAY_SIZE ))
149203
150204# np.prod and others have overflow and math.prod is Python 3.8+ only
151205def prod (seq ):
@@ -181,7 +235,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
181235
182236@composite
183237def finite_matrices (draw , shape = matrix_shapes ()):
184- return draw (xps . arrays (dtype = xps .floating_dtypes (),
238+ return draw (arrays (dtype = xps .floating_dtypes (),
185239 shape = shape ,
186240 elements = dict (allow_nan = False ,
187241 allow_infinity = False )))
@@ -190,7 +244,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
190244# Should we set a max_value here?
191245_rtol_float_kw = dict (allow_nan = False , allow_infinity = False , min_value = 0 )
192246rtols = one_of (floats (** _rtol_float_kw ),
193- xps . arrays (dtype = xps .floating_dtypes (),
247+ arrays (dtype = xps .floating_dtypes (),
194248 shape = rtol_shared_matrix_shapes .map (lambda shape : shape [:- 2 ]),
195249 elements = _rtol_float_kw ))
196250
@@ -233,7 +287,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
233287 if not isinstance (finite , bool ):
234288 finite = draw (finite )
235289 elements = {'allow_nan' : False , 'allow_infinity' : False } if finite else None
236- a = draw (xps . arrays (dtype = dtype , shape = shape , elements = elements ))
290+ a = draw (arrays (dtype = dtype , shape = shape , elements = elements ))
237291 upper = xp .triu (a )
238292 lower = xp .triu (a , k = 1 ).mT
239293 return upper + lower
@@ -256,7 +310,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
256310 n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ),)
257311 stack_shape = draw (stack_shapes )
258312 shape = stack_shape + (n , n )
259- d = draw (xps . arrays (dtypes , shape = n * prod (stack_shape ),
313+ d = draw (arrays (dtypes , shape = n * prod (stack_shape ),
260314 elements = dict (allow_nan = False , allow_infinity = False )))
261315 # Functions that require invertible matrices may do anything when it is
262316 # singular, including raising an exception, so we make sure the diagonals
@@ -282,7 +336,7 @@ def two_broadcastable_shapes(draw):
282336sizes = integers (0 , MAX_ARRAY_SIZE )
283337sqrt_sizes = integers (0 , SQRT_MAX_ARRAY_SIZE )
284338
285- numeric_arrays = xps . arrays (
339+ numeric_arrays = arrays (
286340 dtype = shared (xps .floating_dtypes (), key = 'dtypes' ),
287341 shape = shared (xps .array_shapes (), key = 'shapes' ),
288342)
@@ -407,11 +461,11 @@ def two_mutual_arrays(
407461 assert len (dtypes ) > 0 # sanity check
408462 mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtypes ))
409463 mutual_shapes = shared (two_shapes )
410- arrays1 = xps . arrays (
464+ arrays1 = arrays (
411465 dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
412466 shape = mutual_shapes .map (lambda pair : pair [0 ]),
413467 )
414- arrays2 = xps . arrays (
468+ arrays2 = arrays (
415469 dtype = mutual_dtypes .map (lambda pair : pair [1 ]),
416470 shape = mutual_shapes .map (lambda pair : pair [1 ]),
417471 )
0 commit comments