File tree Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Original file line number Diff line number Diff line change 22
33from tqdm_loggable .auto import tqdm
44
5+ import jax
56import jax .numpy as jnp
67from jax import jit
78from jax .flatten_util import ravel_pytree
@@ -1119,6 +1120,9 @@ def line_search(
11191120
11201121 count += 1
11211122
1123+ if new_unitcell [0 , 0 ][0 ][0 ].chi != unitcell [0 , 0 ][0 ][0 ].chi :
1124+ jax .clear_caches ()
1125+
11221126 if count == varipeps_config .line_search_max_steps :
11231127 raise NoSuitableStepSizeError (f"Count { count } , Last alpha { alpha } " )
11241128
Original file line number Diff line number Diff line change 88
99from tqdm_loggable .auto import tqdm
1010
11+ import jax
1112from jax import jit
1213import jax .numpy as jnp
1314from jax .lax import scan
@@ -405,6 +406,7 @@ def random_noise(a):
405406 while count < varipeps_config .optimizer_max_steps :
406407 runtime_start = time .perf_counter ()
407408
409+ chi_before_ctmrg = working_unitcell [0 , 0 ][0 ][0 ].chi
408410 try :
409411 if varipeps_config .ad_use_custom_vjp :
410412 (
@@ -498,6 +500,9 @@ def random_noise(a):
498500
499501 continue
500502
503+ if working_unitcell [0 , 0 ][0 ][0 ].chi != chi_before_ctmrg :
504+ jax .clear_caches ()
505+
501506 working_gradient = [elem .conj () for elem in working_gradient_seq ]
502507
503508 if signal_reset_descent_dir :
You can’t perform that action at this time.
0 commit comments