We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c3ad006 commit 2a26413Copy full SHA for 2a26413
varipeps/utils/svd.py
@@ -87,7 +87,7 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
87
elif only_u_or_vt == "U":
88
dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
89
elif only_u_or_vt == "Vt":
90
- dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)) + 0.5 * dUdV_diag)
+ dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)) + dUdV_diag)
91
92
m, n = A.shape[-2:]
93
if m > n and (only_u_or_vt is None or only_u_or_vt == "U"):
0 commit comments