66import jax .util
77from jax .lax import cond , while_loop
88import jax .debug as jdebug
9+ import logging
10+ import time
11+
12+ logger = logging .getLogger ("varipeps.ctmrg" )
913
1014from varipeps import varipeps_config , varipeps_global_state
1115from varipeps .peps import PEPS_Tensor , PEPS_Tensor_Split_Transfer , PEPS_Unit_Cell
@@ -516,9 +520,8 @@ def corner_svd_func(old, new, old_corner, conv_eps, config):
516520 eps ,
517521 config ,
518522 )
519-
520- if config .ctmrg_print_steps :
521- debug_print ("CTMRG: {}: {}" , count , measure )
523+ if logger .isEnabledFor (logging .DEBUG ):
524+ jax .debug .callback (lambda cnt , msr : logger .debug (f"CTMRG: Step { cnt } : { msr } " ), count , measure , ordered = True )
522525 if config .ctmrg_verbose_output :
523526 jax .debug .callback (print_verbose , verbose_data , ordered = True )
524527
@@ -621,7 +624,7 @@ def calc_ctmrg_env(
621624 best_norm_smallest_S = None
622625 best_truncation_eps = None
623626 have_been_increased = False
624-
627+ t0 = time . perf_counter ()
625628 while True :
626629 tmp_count = 0
627630 corner_singular_vals = None
@@ -721,6 +724,11 @@ def calc_ctmrg_env(
721724 else :
722725 converged = False
723726 end_count = tmp_count
727+ if logger .isEnabledFor (logging .INFO ):
728+ if logger .isEnabledFor (logging .WARNING ) and not converged :
729+ logger .warning ("CTMRG: ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)" , time .perf_counter () - t0 , end_count , norm_smallest_S )
730+ else :
731+ logger .info ("CTMRG: ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)" , time .perf_counter () - t0 , end_count , norm_smallest_S )
724732
725733 if converged and (
726734 working_unitcell [0 , 0 ][0 ][0 ].chi > best_chi or best_result is None
@@ -752,9 +760,9 @@ def calc_ctmrg_env(
752760 working_unitcell = working_unitcell .change_chi (new_chi )
753761 initial_unitcell = initial_unitcell .change_chi (new_chi )
754762
755- if varipeps_config . ctmrg_print_steps :
756- debug_print (
757- "CTMRG: Increasing chi to {} since smallest SVD Norm was {}." ,
763+ if logger . isEnabledFor ( logging . INFO ) :
764+ logger . info (
765+ "Increasing chi to {} since smallest SVD Norm was {}." ,
758766 new_chi ,
759767 norm_smallest_S ,
760768 )
@@ -786,9 +794,9 @@ def calc_ctmrg_env(
786794 if not new_chi in already_tried_chi :
787795 working_unitcell = working_unitcell .change_chi (new_chi )
788796
789- if varipeps_config . ctmrg_print_steps :
790- debug_print (
791- "CTMRG: Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge." ,
797+ if logger . isEnabledFor ( logging . INFO ) :
798+ logger . info (
799+ "Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge." ,
792800 new_chi ,
793801 norm_smallest_S ,
794802 )
@@ -810,9 +818,9 @@ def calc_ctmrg_env(
810818 new_truncation_eps
811819 <= varipeps_config .ctmrg_increase_truncation_eps_max_value
812820 ):
813- if varipeps_config . ctmrg_print_steps :
814- debug_print (
815- "CTMRG: Increasing SVD truncation eps to {}." ,
821+ if logger . isEnabledFor ( logging . INFO ) :
822+ logger . info (
823+ "Increasing SVD truncation eps to {}." ,
816824 new_truncation_eps ,
817825 )
818826 varipeps_global_state .ctmrg_effective_truncation_eps = (
@@ -938,8 +946,8 @@ def _ctmrg_rev_while_body(carry):
938946
939947 count += 1
940948
941- if config . ad_custom_print_steps :
942- debug_print ( "Custom VJP: {}: {}" , count , measure )
949+ if logger . isEnabledFor ( logging . DEBUG ) :
950+ jax . debug . callback ( lambda cnt , msr : logger . debug ( f "Custom VJP: Step { cnt } , Measure { msr } " ) , count , measure , ordered = True )
943951 if config .ad_custom_verbose_output :
944952 jax .debug .callback (print_verbose , verbose_data , ordered = True , ad = True )
945953
@@ -1015,12 +1023,14 @@ def calc_ctmrg_env_rev(
10151023
10161024 varipeps_global_state .ctmrg_effective_truncation_eps = last_truncation_eps
10171025
1026+ if logger .isEnabledFor (logging .INFO ):
1027+ t0 = time .perf_counter ()
10181028 t_bar , converged , end_count = _ctmrg_rev_workhorse (
10191029 peps_tensors , new_unitcell , unitcell_bar , varipeps_config , varipeps_global_state
10201030 )
10211031
10221032 varipeps_global_state .ctmrg_effective_truncation_eps = None
1023-
1033+ debug_print ( "Custom VJP: Converged: {}, Steps: {}" , converged , end_count )
10241034 if end_count == varipeps_config .ad_custom_max_steps and not converged :
10251035 raise CTMRGGradientNotConvergedError
10261036
0 commit comments