55from jax import jit , custom_vjp , vjp , tree_util
66from jax .lax import cond , while_loop
77import jax .debug as jdebug
8+ import logging
9+ import time
10+
11+ logger = logging .getLogger ("varipeps.ctmrg" )
812
913from varipeps import varipeps_config , varipeps_global_state
1014from varipeps .peps import PEPS_Tensor , PEPS_Unit_Cell
@@ -125,8 +129,8 @@ def _ctmrg_body_func_structure_factor(carry):
125129 measure = jnp .linalg .norm (corner_svd - last_corner_svd )
126130 converged = measure < eps
127131
128- if config . ctmrg_print_steps :
129- debug_print ( "CTMRG: { }: {}" , count , measure )
132+ if logger . isEnabledFor ( logging . DEBUG ) :
133+ jax . debug . callback ( lambda cnt , msr : logger . debug ( f "CTMRG: Step { cnt } : { msr } " ) , count , measure , ordered = True )
130134 if config .ctmrg_verbose_output :
131135 for ti , ctm_enum_i , diff in verbose_data :
132136 debug_print (
@@ -244,6 +248,7 @@ def calc_ctmrg_env_structure_factor(
244248 norm_smallest_S = jnp .nan
245249 already_tried_chi = {working_unitcell [0 , 0 ][0 ][0 ].chi }
246250
251+ t0 = time .perf_counter ()
247252 while True :
248253 tmp_count = 0
249254 corner_singular_vals = None
@@ -304,6 +309,17 @@ def calc_ctmrg_env_structure_factor(
304309 )
305310 )
306311
312+ if not converged and logger .isEnabledFor (logging .WARNING ):
313+ logger .warning (
314+ "CTMRG (SF): ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)" ,
315+ time .perf_counter () - t0 , end_count , norm_smallest_S
316+ )
317+ elif logger .isEnabledFor (logging .INFO ):
318+ logger .info (
319+ "CTMRG (SF): ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)" ,
320+ time .perf_counter () - t0 , end_count , norm_smallest_S
321+ )
322+
307323 current_truncation_eps = (
308324 varipeps_config .ctmrg_truncation_eps
309325 if varipeps_global_state .ctmrg_effective_truncation_eps is None
@@ -326,15 +342,14 @@ def calc_ctmrg_env_structure_factor(
326342 working_unitcell = working_unitcell .change_chi (new_chi )
327343 initial_unitcell = initial_unitcell .change_chi (new_chi )
328344
329- if varipeps_config . ctmrg_print_steps :
330- debug_print (
331- "CTMRG: Increasing chi to {} since smallest SVD Norm was {} ." ,
345+ if logger . isEnabledFor ( logging . INFO ) :
346+ logger . info (
347+ "CTMRG (SF) : Increasing chi to %d since smallest SVD Norm was %.3e ." ,
332348 new_chi ,
333349 norm_smallest_S ,
334350 )
335351
336352 already_tried_chi .add (new_chi )
337-
338353 continue
339354 elif (
340355 varipeps_config .ctmrg_heuristic_decrease_chi
@@ -351,15 +366,14 @@ def calc_ctmrg_env_structure_factor(
351366 if not new_chi in already_tried_chi :
352367 working_unitcell = working_unitcell .change_chi (new_chi )
353368
354- if varipeps_config . ctmrg_print_steps :
355- debug_print (
356- "CTMRG: Decreasing chi to {} since smallest SVD Norm was {} ." ,
369+ if logger . isEnabledFor ( logging . INFO ) :
370+ logger . info (
371+ "CTMRG (SF) : Decreasing chi to %d since smallest SVD Norm was %.3e ." ,
357372 new_chi ,
358373 norm_smallest_S ,
359374 )
360375
361376 already_tried_chi .add (new_chi )
362-
363377 continue
364378
365379 if (
@@ -375,9 +389,9 @@ def calc_ctmrg_env_structure_factor(
375389 new_truncation_eps
376390 <= varipeps_config .ctmrg_increase_truncation_eps_max_value
377391 ):
378- if varipeps_config . ctmrg_print_steps :
379- debug_print (
380- "CTMRG: Increasing SVD truncation eps to {} ." ,
392+ if logger . isEnabledFor ( logging . INFO ) :
393+ logger . info (
394+ "CTMRG (SF) : Increasing SVD truncation eps to %g ." ,
381395 new_truncation_eps ,
382396 )
383397 varipeps_global_state .ctmrg_effective_truncation_eps = (
0 commit comments