@@ -24,25 +24,33 @@ def _H(x):
2424 return jnp .conj (_T (x ))
2525
2626
27- @custom_jvp
28- def svd_wrapper (a ):
27+ @partial ( custom_jvp , nondiff_argnums = ( 1 ,))
28+ def svd_wrapper (a , use_qr = False ):
2929 check_arraylike ("jnp.linalg.svd" , a )
3030 (a ,) = promote_dtypes_inexact (jnp .asarray (a ))
3131
32- result = lax_svd (a , full_matrices = False , compute_uv = True )
33-
34- result = lax .cond (
35- jnp .isnan (jnp .sum (result [1 ])),
36- lambda matrix , _ : lax_svd (
37- matrix ,
32+ if use_qr :
33+ result = lax_svd (
34+ a ,
3835 full_matrices = False ,
3936 compute_uv = True ,
4037 algorithm = lax .linalg .SvdAlgorithm .QR ,
41- ),
42- lambda _ , res : res ,
43- a ,
44- result ,
45- )
38+ )
39+ else :
40+ result = lax_svd (a , full_matrices = False , compute_uv = True )
41+
42+ result = lax .cond (
43+ jnp .isnan (jnp .sum (result [1 ])),
44+ lambda matrix , _ : lax_svd (
45+ matrix ,
46+ full_matrices = False ,
47+ compute_uv = True ,
48+ algorithm = lax .linalg .SvdAlgorithm .QR ,
49+ ),
50+ lambda _ , res : res ,
51+ a ,
52+ result ,
53+ )
4654
4755 return result
4856
@@ -51,10 +59,10 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
5159 (A ,) = primals
5260 (dA ,) = tangents
5361
54- if use_qr :
62+ if use_qr and only_u_or_vt is not None :
5563 U , s , Vt = _svd_only_u_vt_impl (A , u_or_vt = 2 , use_qr = True )
5664 else :
57- U , s , Vt = svd_wrapper (A )
65+ U , s , Vt = svd_wrapper (A , use_qr = use_qr )
5866
5967 Ut , V = _H (U ), _H (Vt )
6068 s_dim = s [..., None , :]
@@ -106,8 +114,8 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
106114
107115
108116@svd_wrapper .defjvp
109- def _svd_jvp_rule (primals , tangents ):
110- return _svd_jvp_rule_impl (primals , tangents )
117+ def _svd_jvp_rule (use_qr , primals , tangents ):
118+ return _svd_jvp_rule_impl (primals , tangents , use_qr = use_qr )
111119
112120
113121jax .ffi .register_ffi_target (
@@ -293,10 +301,11 @@ def _svd_only_vt_jvp_rule(use_qr, primals, tangents):
293301 return _svd_jvp_rule_impl (primals , tangents , only_u_or_vt = "Vt" , use_qr = use_qr )
294302
295303
296- @partial (jit , inline = True , static_argnums = (1 ,))
304+ @partial (jit , inline = True , static_argnums = (1 , 2 ))
297305def gauge_fixed_svd (
298306 matrix : jnp .ndarray ,
299307 only_u_or_vh = None ,
308+ use_qr = False ,
300309) -> Tuple [jnp .ndarray , jnp .ndarray , jnp .ndarray ]:
301310 """
302311 Calculate the gauge-fixed (also called sign-fixed) SVD. To this end, each
@@ -316,13 +325,13 @@ def gauge_fixed_svd(
316325 Tuple with sign-fixed U, S and Vh of the SVD.
317326 """
318327 if only_u_or_vh is None :
319- U , S , Vh = svd_wrapper (matrix )
328+ U , S , Vh = svd_wrapper (matrix , use_qr = use_qr )
320329 gauge_unitary = U
321330 elif only_u_or_vh == "U" :
322- U , S = svd_only_u (matrix )
331+ U , S = svd_only_u (matrix , use_qr = use_qr )
323332 gauge_unitary = U
324333 elif only_u_or_vh == "Vh" :
325- S , Vh = svd_only_vt (matrix )
334+ S , Vh = svd_only_vt (matrix , use_qr = use_qr )
326335 gauge_unitary = Vh .T .conj ()
327336 else :
328337 raise ValueError ("Invalid value for parameter 'only_u_or_vh'." )
0 commit comments