|
1 | 1 | import re |
2 | | -import itertools |
3 | 2 | from contextlib import contextmanager |
4 | 3 | from functools import reduce, wraps |
5 | 4 | import math |
@@ -309,18 +308,14 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes( |
309 | 308 | # For now, just generate stacks of diagonal matrices. |
310 | 309 | n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),) |
311 | 310 | stack_shape = draw(stack_shapes) |
312 | | - shape = stack_shape + (n, n) |
313 | | - d = draw(arrays(dtypes, shape=n*prod(stack_shape), |
| 311 | + d = draw(arrays(dtypes, shape=(*stack_shape, 1, n), |
314 | 312 | elements=dict(allow_nan=False, allow_infinity=False))) |
315 | 313 | # Functions that require invertible matrices may do anything when it is |
316 | 314 | # singular, including raising an exception, so we make sure the diagonals |
317 | 315 | # are sufficiently nonzero to avoid any numerical issues. |
318 | 316 | assume(xp.all(xp.abs(d) > 0.5)) |
319 | | - |
320 | | - a = xp.zeros(shape) |
321 | | - for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))): |
322 | | - a[idx + (i, i)] = d[j] |
323 | | - return a |
| 317 | + diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1)) |
| 318 | + return xp.where(diag_mask, d, xp.zeros_like(d)) |
324 | 319 |
|
325 | 320 | # TODO: Better name |
326 | 321 | @composite |
|
0 commit comments