Skip to content

Commit 3ed908b

Browse files
committed
Allow svd(compute_uv=False) in numba
1 parent 429ba6c commit 3ed908b

File tree

2 files changed

+7
-18
lines changed

2 files changed

+7
-18
lines changed

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,20 @@
2525
def numba_funcify_SVD(op, node, **kwargs):
2626
full_matrices = op.full_matrices
2727
compute_uv = op.compute_uv
28+
out_dtype = np.dtype(node.outputs[0].dtype)
2829

29-
if not compute_uv:
30-
31-
warnings.warn(
32-
(
33-
"Numba will use object mode to allow the "
34-
"`compute_uv` argument to `numpy.linalg.svd`."
35-
),
36-
UserWarning,
37-
)
30+
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
3831

39-
ret_sig = get_numba_type(node.outputs[0].type)
32+
if not compute_uv:
4033

41-
@numba_basic.numba_njit
34+
@numba_basic.numba_njit()
4235
def svd(x):
43-
with numba.objmode(ret=ret_sig):
44-
ret = np.linalg.svd(x, full_matrices, compute_uv)
36+
_, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
4537
return ret
4638

4739
else:
4840

49-
out_dtype = node.outputs[0].type.numpy_dtype
50-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
51-
52-
@numba_basic.numba_njit(inline="always")
41+
@numba_basic.numba_njit()
5342
def svd(x):
5443
return np.linalg.svd(inputs_cast(x), full_matrices)
5544

tests/link/numba/test_nlinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ def test_QRFull(x, mode, exc):
477477
),
478478
True,
479479
False,
480-
UserWarning,
480+
None,
481481
),
482482
],
483483
)

0 commit comments

Comments
 (0)