@@ -57,22 +57,33 @@ def _svd_jvp_rule(primals, tangents):
5757 dS = Ut @ dA @ V
5858 ds = jnp .real (jnp .diagonal (dS , 0 , - 2 , - 1 ))
5959
60- s_diffs = (s_dim + _T (s_dim )) * (s_dim - _T (s_dim ))
61- # s_diffs = jnp.where(s_diffs / (s[0] ** 2) >= 1e-12, s_diffs, 0)
60+ s_sums = s_dim + _T (s_dim )
61+ s_diffs = s_dim - _T (s_dim )
62+ s_diffs = jnp .where (jnp .abs (s_diffs / s [0 ]) >= 1e-12 , s_diffs , 0 )
6263 s_diffs_zeros = jnp .ones ((), dtype = A .dtype ) * (
6364 s_diffs == 0.0
6465 ) # is 1. where s_diffs is 0. and is 0. everywhere else
6566 s_diffs_zeros = lax .expand_dims (s_diffs_zeros , range (s_diffs .ndim - 2 ))
6667 F = 1 / (s_diffs + s_diffs_zeros ) - s_diffs_zeros
67- dSS = s_dim .astype (A .dtype ) * dS # dS.dot(jnp.diag(s))
68- SdS = _T (s_dim .astype (A .dtype )) * dS # jnp.diag(s).dot(dS)
68+ dSS = dS * (s_dim / s_sums ).astype (A .dtype ) # dS.dot(s_j / (s_i + s_j))
69+ SdS = (_T (s_dim ) / s_sums ).astype (A .dtype ) * dS # (s_i / (s_i + s_j)).dot(dS)
70+
71+ # s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
72+ # # s_diffs = jnp.where(s_diffs / (s[0] ** 2) >= 1e-12, s_diffs, 0)
73+ # s_diffs_zeros = jnp.ones((), dtype=A.dtype) * (
74+ # s_diffs == 0.0
75+ # ) # is 1. where s_diffs is 0. and is 0. everywhere else
76+ # s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
77+ # F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
78+ # dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s))
79+ # SdS = _T(s_dim.astype(A.dtype)) * dS # jnp.diag(s).dot(dS)
6980
7081 s_zeros = (s == 0 ).astype (s .dtype )
7182 s_inv = 1 / (s + s_zeros ) - s_zeros
7283 s_inv_mat = jnp .vectorize (jnp .diag , signature = "(k)->(k,k)" )(s_inv )
7384 dUdV_diag = 0.5 * (dS - _H (dS )) * s_inv_mat .astype (A .dtype )
74- dU = U @ (F .astype (A .dtype ) * (dSS + _H (dSS )) + dUdV_diag )
75- dV = V @ (F .astype (A .dtype ) * (SdS + _H (SdS )))
85+ dU = U @ (F .astype (A .dtype ) * (dSS + _H (dSS )) + 0.5 * dUdV_diag )
86+ dV = V @ (F .astype (A .dtype ) * (SdS + _H (SdS )) + 0.5 * dUdV_diag )
7687
7788 m , n = A .shape [- 2 :]
7889 if m > n :
0 commit comments