66import jax .util
77from jax .lax import cond , while_loop
88import jax .debug as jdebug
9+ import logging
10+ import time
11+
12+ logger = logging .getLogger ("varipeps.ctmrg" )
913
1014from varipeps import varipeps_config , varipeps_global_state
1115from varipeps .peps import PEPS_Tensor , PEPS_Unit_Cell
@@ -126,8 +130,8 @@ def _ctmrg_body_func_structure_factor(carry):
126130 measure = jnp .linalg .norm (corner_svd - last_corner_svd )
127131 converged = measure < eps
128132
129- if config . ctmrg_print_steps :
130- debug_print ( "CTMRG: { }: {}" , count , measure )
133+ if logger . isEnabledFor ( logging . DEBUG ) :
134+ jax . debug . callback ( lambda cnt , msr : logger . debug ( f "CTMRG: Step { cnt } : { msr } " ) , count , measure , ordered = True )
131135 if config .ctmrg_verbose_output :
132136 for ti , ctm_enum_i , diff in verbose_data :
133137 debug_print (
@@ -245,6 +249,7 @@ def calc_ctmrg_env_structure_factor(
245249 norm_smallest_S = jnp .nan
246250 already_tried_chi = {working_unitcell [0 , 0 ][0 ][0 ].chi }
247251
252+ t0 = time .perf_counter ()
248253 while True :
249254 tmp_count = 0
250255 corner_singular_vals = None
@@ -305,6 +310,17 @@ def calc_ctmrg_env_structure_factor(
305310 )
306311 )
307312
313+ if not converged and logger .isEnabledFor (logging .WARNING ):
314+ logger .warning (
315+ "CTMRG (SF): ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)" ,
316+ time .perf_counter () - t0 , end_count , norm_smallest_S
317+ )
318+ elif logger .isEnabledFor (logging .INFO ):
319+ logger .info (
320+ "CTMRG (SF): ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)" ,
321+ time .perf_counter () - t0 , end_count , norm_smallest_S
322+ )
323+
308324 current_truncation_eps = (
309325 varipeps_config .ctmrg_truncation_eps
310326 if varipeps_global_state .ctmrg_effective_truncation_eps is None
@@ -327,15 +343,14 @@ def calc_ctmrg_env_structure_factor(
327343 working_unitcell = working_unitcell .change_chi (new_chi )
328344 initial_unitcell = initial_unitcell .change_chi (new_chi )
329345
330- if varipeps_config . ctmrg_print_steps :
331- debug_print (
332- "CTMRG: Increasing chi to {} since smallest SVD Norm was {} ." ,
346+ if logger . isEnabledFor ( logging . INFO ) :
347+ logger . info (
348+ "CTMRG (SF) : Increasing chi to %d since smallest SVD Norm was %.3e ." ,
333349 new_chi ,
334350 norm_smallest_S ,
335351 )
336352
337353 already_tried_chi .add (new_chi )
338-
339354 continue
340355 elif (
341356 varipeps_config .ctmrg_heuristic_decrease_chi
@@ -352,15 +367,14 @@ def calc_ctmrg_env_structure_factor(
352367 if not new_chi in already_tried_chi :
353368 working_unitcell = working_unitcell .change_chi (new_chi )
354369
355- if varipeps_config . ctmrg_print_steps :
356- debug_print (
357- "CTMRG: Decreasing chi to {} since smallest SVD Norm was {} ." ,
370+ if logger . isEnabledFor ( logging . INFO ) :
371+ logger . info (
372+ "CTMRG (SF) : Decreasing chi to %d since smallest SVD Norm was %.3e ." ,
358373 new_chi ,
359374 norm_smallest_S ,
360375 )
361376
362377 already_tried_chi .add (new_chi )
363-
364378 continue
365379
366380 if (
@@ -376,9 +390,9 @@ def calc_ctmrg_env_structure_factor(
376390 new_truncation_eps
377391 <= varipeps_config .ctmrg_increase_truncation_eps_max_value
378392 ):
379- if varipeps_config . ctmrg_print_steps :
380- debug_print (
381- "CTMRG: Increasing SVD truncation eps to {} ." ,
393+ if logger . isEnabledFor ( logging . INFO ) :
394+ logger . info (
395+ "CTMRG (SF) : Increasing SVD truncation eps to %g ." ,
382396 new_truncation_eps ,
383397 )
384398 varipeps_global_state .ctmrg_effective_truncation_eps = (
0 commit comments