@@ -313,12 +313,18 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
313313 # For now, just generate stacks of diagonal matrices.
314314 n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ),)
315315 stack_shape = draw (stack_shapes )
316- d = draw (arrays (dtypes , shape = (* stack_shape , 1 , n ),
317- elements = dict (allow_nan = False , allow_infinity = False )))
316+ dtype = draw (dtypes )
317+ elements = one_of (
318+ from_dtype (dtype , min_value = 0.5 , allow_nan = False , allow_infinity = False ),
319+ from_dtype (dtype , max_value = - 0.5 , allow_nan = False , allow_infinity = False ),
320+ )
321+ d = draw (arrays (dtype , shape = (* stack_shape , 1 , n ), elements = elements ))
322+
318323 # Functions that require invertible matrices may do anything when it is
319324 # singular, including raising an exception, so we make sure the diagonals
320325 # are sufficiently nonzero to avoid any numerical issues.
321- assume (xp .all (xp .abs (d ) > 0.5 ))
326+ assert xp .all (xp .abs (d ) >= 0.5 )
327+
322328 diag_mask = xp .arange (n ) == xp .reshape (xp .arange (n ), (n , 1 ))
323329 return xp .where (diag_mask , d , xp .zeros_like (d ))
324330
0 commit comments