Skip to content

Commit d44d99c

Browse files
committed
Update dependencies and adapt code to new jax version
1 parent cbd9c43 commit d44d99c

File tree

8 files changed

+997
-1056
lines changed

8 files changed

+997
-1056
lines changed

peps_ad/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from .__version__ import __version__
1616

17-
from jax.config import config as jax_config
17+
from jax import config as jax_config
1818

1919
jax_config.update("jax_enable_x64", True)
2020

peps_ad/ctmrg/projectors.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,16 @@ def calc_left_projectors(
349349
top_right,
350350
bottom_left,
351351
bottom_right,
352-
config.ctmrg_truncation_eps
353-
if state.ctmrg_effective_truncation_eps is None
354-
else state.ctmrg_effective_truncation_eps,
355-
config.ctmrg_full_projector_method
356-
if state.ctmrg_projector_method is None
357-
else state.ctmrg_projector_method,
352+
(
353+
config.ctmrg_truncation_eps
354+
if state.ctmrg_effective_truncation_eps is None
355+
else state.ctmrg_effective_truncation_eps
356+
),
357+
(
358+
config.ctmrg_full_projector_method
359+
if state.ctmrg_projector_method is None
360+
else state.ctmrg_projector_method
361+
),
358362
chi,
359363
)
360364

@@ -456,12 +460,16 @@ def calc_right_projectors(
456460
top_right,
457461
bottom_left,
458462
bottom_right,
459-
config.ctmrg_truncation_eps
460-
if state.ctmrg_effective_truncation_eps is None
461-
else state.ctmrg_effective_truncation_eps,
462-
config.ctmrg_full_projector_method
463-
if state.ctmrg_projector_method is None
464-
else state.ctmrg_projector_method,
463+
(
464+
config.ctmrg_truncation_eps
465+
if state.ctmrg_effective_truncation_eps is None
466+
else state.ctmrg_effective_truncation_eps
467+
),
468+
(
469+
config.ctmrg_full_projector_method
470+
if state.ctmrg_projector_method is None
471+
else state.ctmrg_projector_method
472+
),
465473
chi,
466474
)
467475

@@ -563,12 +571,16 @@ def calc_top_projectors(
563571
top_right,
564572
bottom_left,
565573
bottom_right,
566-
config.ctmrg_truncation_eps
567-
if state.ctmrg_effective_truncation_eps is None
568-
else state.ctmrg_effective_truncation_eps,
569-
config.ctmrg_full_projector_method
570-
if state.ctmrg_projector_method is None
571-
else state.ctmrg_projector_method,
574+
(
575+
config.ctmrg_truncation_eps
576+
if state.ctmrg_effective_truncation_eps is None
577+
else state.ctmrg_effective_truncation_eps
578+
),
579+
(
580+
config.ctmrg_full_projector_method
581+
if state.ctmrg_projector_method is None
582+
else state.ctmrg_projector_method
583+
),
572584
chi,
573585
)
574586

@@ -670,11 +682,15 @@ def calc_bottom_projectors(
670682
top_right,
671683
bottom_left,
672684
bottom_right,
673-
config.ctmrg_truncation_eps
674-
if state.ctmrg_effective_truncation_eps is None
675-
else state.ctmrg_effective_truncation_eps,
676-
config.ctmrg_full_projector_method
677-
if state.ctmrg_projector_method is None
678-
else state.ctmrg_projector_method,
685+
(
686+
config.ctmrg_truncation_eps
687+
if state.ctmrg_effective_truncation_eps is None
688+
else state.ctmrg_effective_truncation_eps
689+
),
690+
(
691+
config.ctmrg_full_projector_method
692+
if state.ctmrg_projector_method is None
693+
else state.ctmrg_projector_method
694+
),
679695
chi,
680696
)

peps_ad/ctmrg/routine.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,11 @@ def calc_ctmrg_env(
413413
peps_tensors,
414414
working_unitcell,
415415
False,
416-
corner_singular_vals
417-
if corner_singular_vals is not None
418-
else init_corner_singular_vals,
416+
(
417+
corner_singular_vals
418+
if corner_singular_vals is not None
419+
else init_corner_singular_vals
420+
),
419421
eps,
420422
tmp_count,
421423
enforce_elementwise_convergence,

peps_ad/optimization/basinhopping.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ class PEPS_AD_Basinhopping:
5050
expectation_func: Expectation_Model
5151
convert_to_unitcell_func: Optional[Map_To_PEPS_Model] = None
5252
autosave_filename: PathLike = "data/autosave.hdf5"
53-
autosave_func: Callable[
54-
[PathLike, Sequence[jnp.ndarray], PEPS_Unit_Cell], None
55-
] = autosave_function
53+
autosave_func: Callable[[PathLike, Sequence[jnp.ndarray], PEPS_Unit_Cell], None] = (
54+
autosave_function
55+
)
5656

5757
def __post_init__(self):
5858
if isinstance(self.initial_guess, PEPS_Unit_Cell):

peps_ad/optimization/optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,9 @@ def random_noise(a):
300300
l_bfgs_grad_cache = deque(maxlen=peps_ad_config.optimizer_l_bfgs_maxlen + 1)
301301

302302
count = 0
303-
linesearch_step: Union[
304-
float, jnp.ndarray
305-
] = peps_ad_config.line_search_initial_step_size
303+
linesearch_step: Union[float, jnp.ndarray] = (
304+
peps_ad_config.line_search_initial_step_size
305+
)
306306
working_value: Union[float, jnp.ndarray]
307307
max_trunc_error_list = []
308308

peps_ad/peps/unitcell.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -891,10 +891,10 @@ def load_from_group(
891891
if grp["config"].attrs.get(config_attr) is not None
892892
}
893893
if config_dict.get("ctmrg_full_projector_method"):
894-
config_dict[
895-
"ctmrg_full_projector_method"
896-
] = peps_ad.config.Projector_Method(
897-
config_dict["ctmrg_full_projector_method"]
894+
config_dict["ctmrg_full_projector_method"] = (
895+
peps_ad.config.Projector_Method(
896+
config_dict["ctmrg_full_projector_method"]
897+
)
898898
)
899899
if config_dict.get("optimizer_method"):
900900
config_dict["optimizer_method"] = peps_ad.config.Optimizing_Methods(

peps_ad/utils/svd.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
import jax.numpy as jnp
44
from jax.lax import scan
5+
from jax.lax.linalg import svd as lax_svd
56
from jax import jit, custom_jvp, lax
67

8+
from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike
9+
710
from peps_ad import peps_ad_config
811

912
from typing import Tuple
@@ -19,18 +22,20 @@ def _H(x):
1922

2023
@custom_jvp
2124
def svd_wrapper(a):
22-
return jnp.linalg.svd(a, full_matrices=False, compute_uv=True)
25+
check_arraylike("jnp.linalg.svd", a)
26+
(a,) = promote_dtypes_inexact(jnp.asarray(a))
27+
return lax_svd(a, full_matrices=False, compute_uv=True)
2328

2429

2530
@svd_wrapper.defjvp
2631
def _svd_jvp_rule(primals, tangents):
2732
(A,) = primals
2833
(dA,) = tangents
29-
U, s, Vt = jnp.linalg.svd(A, full_matrices=False, compute_uv=True)
34+
U, s, Vt = lax_svd(A, full_matrices=False, compute_uv=True)
3035

3136
Ut, V = _H(U), _H(Vt)
3237
s_dim = s[..., None, :]
33-
dS = jnp.matmul(jnp.matmul(Ut, dA), V)
38+
dS = Ut @ dA @ V
3439
ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
3540

3641
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
@@ -39,23 +44,23 @@ def _svd_jvp_rule(primals, tangents):
3944
) # is 1. where s_diffs is 0. and is 0. everywhere else
4045
s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
4146
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
42-
dSS = s_dim * dS # dS.dot(jnp.diag(s))
43-
SdS = _T(s_dim) * dS # jnp.diag(s).dot(dS)
47+
dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s))
48+
SdS = _T(s_dim.astype(A.dtype)) * dS # jnp.diag(s).dot(dS)
4449

45-
s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.0)
50+
s_zeros = (s == 0).astype(s.dtype)
4651
s_inv = 1 / (s + s_zeros) - s_zeros
4752
s_inv_mat = jnp.vectorize(jnp.diag, signature="(k)->(k,k)")(s_inv)
48-
dUdV_diag = 0.5 * (dS - _H(dS)) * s_inv_mat
49-
dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
50-
dV = jnp.matmul(V, F * (SdS + _H(SdS)))
53+
dUdV_diag = 0.5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype)
54+
dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
55+
dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)))
5156

5257
m, n = A.shape[-2:]
5358
if m > n:
54-
I = lax.expand_dims(jnp.eye(m, dtype=A.dtype), range(U.ndim - 2))
55-
dU = dU + jnp.matmul(I - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
59+
dAV = dA @ V
60+
dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype)
5661
if n > m:
57-
I = lax.expand_dims(jnp.eye(n, dtype=A.dtype), range(V.ndim - 2))
58-
dV = dV + jnp.matmul(I - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim
62+
dAHU = _H(dA) @ U
63+
dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype)
5964

6065
return (U, s, Vt), (dU, ds, _H(dV))
6166

0 commit comments

Comments
 (0)