55from jax import jit , custom_vjp , vjp , tree_util
66from jax .lax import cond , while_loop
77import jax .debug as jdebug
8+ import logging
9+ import time
10+
11+ logger = logging .getLogger ("varipeps.ctmrg" )
812
913from varipeps import varipeps_config , varipeps_global_state
1014from varipeps .peps import PEPS_Tensor , PEPS_Tensor_Split_Transfer , PEPS_Unit_Cell
@@ -515,9 +519,8 @@ def corner_svd_func(old, new, old_corner, conv_eps, config):
515519 eps ,
516520 config ,
517521 )
518-
519- if config .ctmrg_print_steps :
520- debug_print ("CTMRG: {}: {}" , count , measure )
522+ if logger .isEnabledFor (logging .DEBUG ):
523+ jax .debug .callback (lambda cnt , msr : logger .debug (f"CTMRG: Step { cnt } : { msr } " ), count , measure , ordered = True )
521524 if config .ctmrg_verbose_output :
522525 jax .debug .callback (print_verbose , verbose_data , ordered = True )
523526
@@ -620,7 +623,7 @@ def calc_ctmrg_env(
620623 best_norm_smallest_S = None
621624 best_truncation_eps = None
622625 have_been_increased = False
623-
626+ t0 = time . perf_counter ()
624627 while True :
625628 tmp_count = 0
626629 corner_singular_vals = None
@@ -720,6 +723,11 @@ def calc_ctmrg_env(
720723 else :
721724 converged = False
722725 end_count = tmp_count
726+ if logger .isEnabledFor (logging .INFO ):
727+ if logger .isEnabledFor (logging .WARNING ) and not converged :
728+ logger .warning ("CTMRG: ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)" , time .perf_counter () - t0 , end_count , norm_smallest_S )
729+ else :
730+ logger .info ("CTMRG: ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)" , time .perf_counter () - t0 , end_count , norm_smallest_S )
723731
724732 if converged and (
725733 working_unitcell [0 , 0 ][0 ][0 ].chi > best_chi or best_result is None
@@ -751,9 +759,9 @@ def calc_ctmrg_env(
751759 working_unitcell = working_unitcell .change_chi (new_chi )
752760 initial_unitcell = initial_unitcell .change_chi (new_chi )
753761
754- if varipeps_config . ctmrg_print_steps :
755- debug_print (
756- "CTMRG: Increasing chi to {} since smallest SVD Norm was {}." ,
762+ if logger . isEnabledFor ( logging . INFO ) :
763+ logger . info (
764+ "Increasing chi to {} since smallest SVD Norm was {}." ,
757765 new_chi ,
758766 norm_smallest_S ,
759767 )
@@ -785,9 +793,9 @@ def calc_ctmrg_env(
785793 if not new_chi in already_tried_chi :
786794 working_unitcell = working_unitcell .change_chi (new_chi )
787795
788- if varipeps_config . ctmrg_print_steps :
789- debug_print (
790- "CTMRG: Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge." ,
796+ if logger . isEnabledFor ( logging . INFO ) :
797+ logger . info (
798+ "Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge." ,
791799 new_chi ,
792800 norm_smallest_S ,
793801 )
@@ -809,9 +817,9 @@ def calc_ctmrg_env(
809817 new_truncation_eps
810818 <= varipeps_config .ctmrg_increase_truncation_eps_max_value
811819 ):
812- if varipeps_config . ctmrg_print_steps :
813- debug_print (
814- "CTMRG: Increasing SVD truncation eps to {}." ,
820+ if logger . isEnabledFor ( logging . INFO ) :
821+ logger . info (
822+ "Increasing SVD truncation eps to {}." ,
815823 new_truncation_eps ,
816824 )
817825 varipeps_global_state .ctmrg_effective_truncation_eps = (
@@ -937,8 +945,8 @@ def _ctmrg_rev_while_body(carry):
937945
938946 count += 1
939947
940- if config . ad_custom_print_steps :
941- debug_print ( "Custom VJP: {}: {}" , count , measure )
948+ if logger . isEnabledFor ( logging . DEBUG ) :
949+ jax . debug . callback ( lambda cnt , msr : logger . debug ( f "Custom VJP: Step { cnt } , Measure { msr } " ) , count , measure , ordered = True )
942950 if config .ad_custom_verbose_output :
943951 jax .debug .callback (print_verbose , verbose_data , ordered = True , ad = True )
944952
@@ -1014,12 +1022,14 @@ def calc_ctmrg_env_rev(
10141022
10151023 varipeps_global_state .ctmrg_effective_truncation_eps = last_truncation_eps
10161024
1025+ if logger .isEnabledFor (logging .INFO ):
1026+ t0 = time .perf_counter ()
10171027 t_bar , converged , end_count = _ctmrg_rev_workhorse (
10181028 peps_tensors , new_unitcell , unitcell_bar , varipeps_config , varipeps_global_state
10191029 )
10201030
10211031 varipeps_global_state .ctmrg_effective_truncation_eps = None
1022-
1032+ debug_print ( "Custom VJP: Converged: {}, Steps: {}" , converged , end_count )
10231033 if end_count == varipeps_config .ad_custom_max_steps and not converged :
10241034 raise CTMRGGradientNotConvergedError
10251035
0 commit comments