Skip to content

Commit af8498a

Browse files
Added logging to structure factor ctmrg
1 parent 2075f33 commit af8498a

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

varipeps/ctmrg/structure_factor_routine.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from jax import jit, custom_vjp, vjp, tree_util
66
from jax.lax import cond, while_loop
77
import jax.debug as jdebug
8+
import logging
9+
import time
10+
11+
logger = logging.getLogger("varipeps.ctmrg")
812

913
from varipeps import varipeps_config, varipeps_global_state
1014
from varipeps.peps import PEPS_Tensor, PEPS_Unit_Cell
@@ -125,8 +129,8 @@ def _ctmrg_body_func_structure_factor(carry):
125129
measure = jnp.linalg.norm(corner_svd - last_corner_svd)
126130
converged = measure < eps
127131

128-
if config.ctmrg_print_steps:
129-
debug_print("CTMRG: {}: {}", count, measure)
132+
if logger.isEnabledFor(logging.DEBUG):
133+
jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True)
130134
if config.ctmrg_verbose_output:
131135
for ti, ctm_enum_i, diff in verbose_data:
132136
debug_print(
@@ -244,6 +248,7 @@ def calc_ctmrg_env_structure_factor(
244248
norm_smallest_S = jnp.nan
245249
already_tried_chi = {working_unitcell[0, 0][0][0].chi}
246250

251+
t0 = time.perf_counter()
247252
while True:
248253
tmp_count = 0
249254
corner_singular_vals = None
@@ -304,6 +309,17 @@ def calc_ctmrg_env_structure_factor(
304309
)
305310
)
306311

312+
if not converged and logger.isEnabledFor(logging.WARNING):
313+
logger.warning(
314+
"CTMRG (SF): ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
315+
time.perf_counter() - t0, end_count, norm_smallest_S
316+
)
317+
elif logger.isEnabledFor(logging.INFO):
318+
logger.info(
319+
"CTMRG (SF): ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
320+
time.perf_counter() - t0, end_count, norm_smallest_S
321+
)
322+
307323
current_truncation_eps = (
308324
varipeps_config.ctmrg_truncation_eps
309325
if varipeps_global_state.ctmrg_effective_truncation_eps is None
@@ -326,15 +342,14 @@ def calc_ctmrg_env_structure_factor(
326342
working_unitcell = working_unitcell.change_chi(new_chi)
327343
initial_unitcell = initial_unitcell.change_chi(new_chi)
328344

329-
if varipeps_config.ctmrg_print_steps:
330-
debug_print(
331-
"CTMRG: Increasing chi to {} since smallest SVD Norm was {}.",
345+
if logger.isEnabledFor(logging.INFO):
346+
logger.info(
347+
"CTMRG (SF): Increasing chi to %d since smallest SVD Norm was %.3e.",
332348
new_chi,
333349
norm_smallest_S,
334350
)
335351

336352
already_tried_chi.add(new_chi)
337-
338353
continue
339354
elif (
340355
varipeps_config.ctmrg_heuristic_decrease_chi
@@ -351,15 +366,14 @@ def calc_ctmrg_env_structure_factor(
351366
if not new_chi in already_tried_chi:
352367
working_unitcell = working_unitcell.change_chi(new_chi)
353368

354-
if varipeps_config.ctmrg_print_steps:
355-
debug_print(
356-
"CTMRG: Decreasing chi to {} since smallest SVD Norm was {}.",
369+
if logger.isEnabledFor(logging.INFO):
370+
logger.info(
371+
"CTMRG (SF): Decreasing chi to %d since smallest SVD Norm was %.3e.",
357372
new_chi,
358373
norm_smallest_S,
359374
)
360375

361376
already_tried_chi.add(new_chi)
362-
363377
continue
364378

365379
if (
@@ -375,9 +389,9 @@ def calc_ctmrg_env_structure_factor(
375389
new_truncation_eps
376390
<= varipeps_config.ctmrg_increase_truncation_eps_max_value
377391
):
378-
if varipeps_config.ctmrg_print_steps:
379-
debug_print(
380-
"CTMRG: Increasing SVD truncation eps to {}.",
392+
if logger.isEnabledFor(logging.INFO):
393+
logger.info(
394+
"CTMRG (SF): Increasing SVD truncation eps to %g.",
381395
new_truncation_eps,
382396
)
383397
varipeps_global_state.ctmrg_effective_truncation_eps = (

0 commit comments

Comments
 (0)