22
33import jax .numpy as jnp
44from jax .lax import scan
5+ from jax .lax .linalg import svd as lax_svd
56from jax import jit , custom_jvp , lax
67
8+ from jax ._src .numpy .util import promote_dtypes_inexact , check_arraylike
9+
710from peps_ad import peps_ad_config
811
912from typing import Tuple
@@ -19,18 +22,20 @@ def _H(x):
1922
2023@custom_jvp
2124def svd_wrapper (a ):
22- return jnp .linalg .svd (a , full_matrices = False , compute_uv = True )
25+ check_arraylike ("jnp.linalg.svd" , a )
26+ (a ,) = promote_dtypes_inexact (jnp .asarray (a ))
27+ return lax_svd (a , full_matrices = False , compute_uv = True )
2328
2429
2530@svd_wrapper .defjvp
2631def _svd_jvp_rule (primals , tangents ):
2732 (A ,) = primals
2833 (dA ,) = tangents
29- U , s , Vt = jnp . linalg . svd (A , full_matrices = False , compute_uv = True )
34+ U , s , Vt = lax_svd (A , full_matrices = False , compute_uv = True )
3035
3136 Ut , V = _H (U ), _H (Vt )
3237 s_dim = s [..., None , :]
33- dS = jnp . matmul ( jnp . matmul ( Ut , dA ), V )
38+ dS = Ut @ dA @ V
3439 ds = jnp .real (jnp .diagonal (dS , 0 , - 2 , - 1 ))
3540
3641 s_diffs = (s_dim + _T (s_dim )) * (s_dim - _T (s_dim ))
@@ -39,23 +44,23 @@ def _svd_jvp_rule(primals, tangents):
3944 ) # is 1. where s_diffs is 0. and is 0. everywhere else
4045 s_diffs_zeros = lax .expand_dims (s_diffs_zeros , range (s_diffs .ndim - 2 ))
4146 F = 1 / (s_diffs + s_diffs_zeros ) - s_diffs_zeros
42- dSS = s_dim * dS # dS.dot(jnp.diag(s))
43- SdS = _T (s_dim ) * dS # jnp.diag(s).dot(dS)
47+ dSS = s_dim . astype ( A . dtype ) * dS # dS.dot(jnp.diag(s))
48+ SdS = _T (s_dim . astype ( A . dtype ) ) * dS # jnp.diag(s).dot(dS)
4449
45- s_zeros = jnp . ones ((), dtype = A . dtype ) * ( s == 0.0 )
50+ s_zeros = ( s == 0 ). astype ( s . dtype )
4651 s_inv = 1 / (s + s_zeros ) - s_zeros
4752 s_inv_mat = jnp .vectorize (jnp .diag , signature = "(k)->(k,k)" )(s_inv )
48- dUdV_diag = 0.5 * (dS - _H (dS )) * s_inv_mat
49- dU = jnp . matmul ( U , F * (dSS + _H (dSS )) + dUdV_diag )
50- dV = jnp . matmul ( V , F * (SdS + _H (SdS )))
53+ dUdV_diag = 0.5 * (dS - _H (dS )) * s_inv_mat . astype ( A . dtype )
54+ dU = U @ ( F . astype ( A . dtype ) * (dSS + _H (dSS )) + dUdV_diag )
55+ dV = V @ ( F . astype ( A . dtype ) * (SdS + _H (SdS )))
5156
5257 m , n = A .shape [- 2 :]
5358 if m > n :
54- I = lax . expand_dims ( jnp . eye ( m , dtype = A . dtype ), range ( U . ndim - 2 ))
55- dU = dU + jnp . matmul ( I - jnp . matmul ( U , Ut ), jnp . matmul ( dA , V )) / s_dim
59+ dAV = dA @ V
60+ dU = dU + ( dAV - U @ ( Ut @ dAV )) / s_dim . astype ( A . dtype )
5661 if n > m :
57- I = lax . expand_dims ( jnp . eye ( n , dtype = A . dtype ), range ( V . ndim - 2 ))
58- dV = dV + jnp . matmul ( I - jnp . matmul ( V , Vt ), jnp . matmul ( _H ( dA ), U )) / s_dim
62+ dAHU = _H ( dA ) @ U
63+ dV = dV + ( dAHU - V @ ( Vt @ dAHU )) / s_dim . astype ( A . dtype )
5964
6065 return (U , s , Vt ), (dU , ds , _H (dV ))
6166
0 commit comments