Skip to content

Commit 6dac8e4

Browse files
committed
Implement Lorentz broadening in SVD ad rule
1 parent 74e0441 commit 6dac8e4

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

varipeps/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ class VariPEPS_Config:
128128
triangular CTMRG.
129129
svd_sign_fix_eps (:obj:`float`):
130130
Value for numerical stability threshold in sign-fixed SVD.
131+
svd_ad_use_lorentz_broadening (:obj:`bool`):
132+
Enable Lorentz broadening in the AD rule for the SVD.
133+
svd_ad_lorentz_broadening_eps (:obj:`float`):
134+
Numerical stabilization constant in the Lorentz broadening in the
135+
AD rule for the SVD.
131136
optimizer_method (:obj:`Optimizing_Methods`):
132137
Method used for variational optimization of the PEPS network.
133138
optimizer_max_steps (:obj:`int`):
@@ -244,6 +249,8 @@ class VariPEPS_Config:
244249

245250
# SVD
246251
svd_sign_fix_eps: float = 1e-1
252+
svd_ad_use_lorentz_broadening: bool = False
253+
svd_ad_lorentz_broadening_eps: float = 1e-13
247254

248255
# Optimizer
249256
optimizer_method: Optimizing_Methods = Optimizing_Methods.BFGS

varipeps/utils/svd.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,16 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
7272
s_sums = s_dim + _T(s_dim)
7373
s_sums = jnp.where(s_sums > 0, s_sums, 1)
7474
s_diffs = s_dim - _T(s_dim)
75-
s_diffs = jnp.where(jnp.abs(s_diffs / s[0]) >= 1e-12, s_diffs, 0)
76-
s_diffs_zeros = jnp.ones((), dtype=A.dtype) * (
77-
s_diffs == 0.0
78-
) # is 1. where s_diffs is 0. and is 0. everywhere else
79-
s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
80-
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
75+
76+
if varipeps_config.svd_ad_use_lorentz_broadening:
77+
F = s_diffs / (s_diffs**2 + varipeps_config.svd_ad_lorentz_broadening_eps)
78+
else:
79+
s_diffs = jnp.where(jnp.abs(s_diffs / s[0]) >= 1e-12, s_diffs, 0)
80+
s_diffs_zeros = jnp.ones((), dtype=A.dtype) * (
81+
s_diffs == 0.0
82+
) # is 1. where s_diffs is 0. and is 0. everywhere else
83+
s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
84+
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
8185

8286
if only_u_or_vt is None or only_u_or_vt == "U":
8387
dSS = dS * (s_dim / s_sums).astype(A.dtype) # dS.dot(s_j / (s_i + s_j))

0 commit comments

Comments
 (0)