Skip to content

Commit 795db9e

Browse files
Added Loggers for optimizer, routine and line search removed tqdm
1 parent 06735ff commit 795db9e

File tree

8 files changed

+157
-62
lines changed

8 files changed

+157
-62
lines changed

varipeps/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,4 @@
2020

2121
jax_config.update("jax_enable_x64", True)
2222

23-
from tqdm_loggable.tqdm_logging import tqdm_logging
24-
import datetime
25-
26-
tqdm_logging.set_log_rate(datetime.timedelta(seconds=60))
27-
28-
del datetime
29-
del tqdm_logging
3023
del jax_config

varipeps/config.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from dataclasses import dataclass
22
from enum import Enum, IntEnum, auto, unique
3+
from typing import TypeVar, Tuple, Any, Type, NoReturn
4+
import logging
35

46
import numpy as np
57

68
from jax.tree_util import register_pytree_node_class
79

8-
from typing import TypeVar, Tuple, Any, Type, NoReturn
910

1011
T_VariPEPS_Config = TypeVar("T_VariPEPS_Config", bound="VariPEPS_Config")
1112

@@ -54,6 +55,15 @@ class Slurm_Restart_Mode(IntEnum):
5455
AUTOMATIC_RESTART = auto() #: Write restart script and start new slurm job with it
5556

5657

58+
@unique
59+
class LogLevel(IntEnum):
60+
OFF = 0
61+
ERROR = logging.ERROR
62+
WARNING = logging.WARNING
63+
INFO = logging.INFO
64+
DEBUG = logging.DEBUG
65+
66+
5767
@dataclass
5868
@register_pytree_node_class
5969
class VariPEPS_Config:
@@ -310,6 +320,16 @@ class VariPEPS_Config:
310320
# Slurm
311321
slurm_restart_mode: Slurm_Restart_Mode = Slurm_Restart_Mode.WRITE_NEED_RESTART_FILE
312322

323+
# Logging configuration
324+
log_level_global: LogLevel = LogLevel.INFO
325+
log_level_optimizer: LogLevel = LogLevel.INFO
326+
log_level_ctmrg: LogLevel = LogLevel.INFO
327+
log_level_line_search: LogLevel = LogLevel.INFO
328+
log_to_console: bool = True
329+
log_to_file: bool = False
330+
log_file: str = "varipeps.log"
331+
log_step_summary_every_n: int = 1
332+
313333
def update(self, name: str, value: Any) -> NoReturn:
314334
self.__setattr__(name, value)
315335

@@ -395,6 +415,7 @@ class ConfigModuleWrapper:
395415
"Projector_Method",
396416
"Wavevector_Type",
397417
"Slurm_Restart_Mode",
418+
"LogLevel",
398419
"VariPEPS_Config",
399420
"config",
400421
}

varipeps/ctmrg/routine.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from jax import jit, custom_vjp, vjp, tree_util
66
from jax.lax import cond, while_loop
77
import jax.debug as jdebug
8+
import logging
9+
import time
10+
11+
logger = logging.getLogger("varipeps.ctmrg")
812

913
from varipeps import varipeps_config, varipeps_global_state
1014
from varipeps.peps import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Unit_Cell
@@ -515,9 +519,8 @@ def corner_svd_func(old, new, old_corner, conv_eps, config):
515519
eps,
516520
config,
517521
)
518-
519-
if config.ctmrg_print_steps:
520-
debug_print("CTMRG: {}: {}", count, measure)
522+
if logger.isEnabledFor(logging.DEBUG):
523+
jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True)
521524
if config.ctmrg_verbose_output:
522525
jax.debug.callback(print_verbose, verbose_data, ordered=True)
523526

@@ -620,7 +623,7 @@ def calc_ctmrg_env(
620623
best_norm_smallest_S = None
621624
best_truncation_eps = None
622625
have_been_increased = False
623-
626+
t0 = time.perf_counter()
624627
while True:
625628
tmp_count = 0
626629
corner_singular_vals = None
@@ -720,6 +723,11 @@ def calc_ctmrg_env(
720723
else:
721724
converged = False
722725
end_count = tmp_count
726+
if logger.isEnabledFor(logging.INFO):
727+
if logger.isEnabledFor(logging.WARNING) and not converged:
728+
logger.warning("CTMRG: ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)", time.perf_counter() - t0, end_count, norm_smallest_S)
729+
else:
730+
logger.info("CTMRG: ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)", time.perf_counter() - t0, end_count, norm_smallest_S)
723731

724732
if converged and (
725733
working_unitcell[0, 0][0][0].chi > best_chi or best_result is None
@@ -751,9 +759,9 @@ def calc_ctmrg_env(
751759
working_unitcell = working_unitcell.change_chi(new_chi)
752760
initial_unitcell = initial_unitcell.change_chi(new_chi)
753761

754-
if varipeps_config.ctmrg_print_steps:
755-
debug_print(
756-
"CTMRG: Increasing chi to {} since smallest SVD Norm was {}.",
762+
if logger.isEnabledFor(logging.INFO):
763+
logger.info(
764+
"Increasing chi to {} since smallest SVD Norm was {}.",
757765
new_chi,
758766
norm_smallest_S,
759767
)
@@ -785,9 +793,9 @@ def calc_ctmrg_env(
785793
if not new_chi in already_tried_chi:
786794
working_unitcell = working_unitcell.change_chi(new_chi)
787795

788-
if varipeps_config.ctmrg_print_steps:
789-
debug_print(
790-
"CTMRG: Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge.",
796+
if logger.isEnabledFor(logging.INFO):
797+
logger.info(
798+
"Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge.",
791799
new_chi,
792800
norm_smallest_S,
793801
)
@@ -809,9 +817,9 @@ def calc_ctmrg_env(
809817
new_truncation_eps
810818
<= varipeps_config.ctmrg_increase_truncation_eps_max_value
811819
):
812-
if varipeps_config.ctmrg_print_steps:
813-
debug_print(
814-
"CTMRG: Increasing SVD truncation eps to {}.",
820+
if logger.isEnabledFor(logging.INFO):
821+
logger.info(
822+
"Increasing SVD truncation eps to {}.",
815823
new_truncation_eps,
816824
)
817825
varipeps_global_state.ctmrg_effective_truncation_eps = (
@@ -937,8 +945,8 @@ def _ctmrg_rev_while_body(carry):
937945

938946
count += 1
939947

940-
if config.ad_custom_print_steps:
941-
debug_print("Custom VJP: {}: {}", count, measure)
948+
if logger.isEnabledFor(logging.DEBUG):
949+
jax.debug.callback(lambda cnt, msr: logger.debug(f"Custom VJP: Step {cnt}, Measure {msr}"), count, measure, ordered=True)
942950
if config.ad_custom_verbose_output:
943951
jax.debug.callback(print_verbose, verbose_data, ordered=True, ad=True)
944952

@@ -1014,12 +1022,14 @@ def calc_ctmrg_env_rev(
10141022

10151023
varipeps_global_state.ctmrg_effective_truncation_eps = last_truncation_eps
10161024

1025+
if logger.isEnabledFor(logging.INFO):
1026+
t0 = time.perf_counter()
10171027
t_bar, converged, end_count = _ctmrg_rev_workhorse(
10181028
peps_tensors, new_unitcell, unitcell_bar, varipeps_config, varipeps_global_state
10191029
)
10201030

10211031
varipeps_global_state.ctmrg_effective_truncation_eps = None
1022-
1032+
debug_print("Custom VJP: Converged: {}, Steps: {}", converged, end_count)
10231033
if end_count == varipeps_config.ad_custom_max_steps and not converged:
10241034
raise CTMRGGradientNotConvergedError
10251035

varipeps/optimization/line_search.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import enum
22

3-
from tqdm_loggable.auto import tqdm
43

54
import jax
65
import jax.numpy as jnp
@@ -14,6 +13,9 @@
1413
from varipeps.expectation import Expectation_Model
1514
from varipeps.mapping import Map_To_PEPS_Model
1615
from varipeps.utils.debug_print import debug_print
16+
import logging
17+
18+
logger = logging.getLogger("varipeps.line_search")
1719

1820
from .inner_function import (
1921
calc_ctmrg_expectation,
@@ -443,6 +445,7 @@ def line_search(
443445
additional_input,
444446
enforce_elementwise_convergence=enforce_elementwise_convergence,
445447
)
448+
logger.info("🔎 Line search step %d, E=%.6f, alpha=%.3e", count + 1, new_value, alpha)
446449

447450
if new_unitcell[0, 0][0][0].chi > unitcell[0, 0][0][0].chi:
448451
tmp_value = current_value
@@ -463,10 +466,11 @@ def line_search(
463466
else:
464467
unitcell = unitcell.change_chi(new_unitcell[0, 0][0][0].chi)
465468

466-
debug_print(
467-
"Line search: Recalculate original unitcell with higher chi {}.",
468-
new_unitcell[0, 0][0][0].chi,
469-
)
469+
if logger.isEnabledFor(logging.DEBUG):
470+
logger.debug(
471+
"Line search: Recalculate original unitcell with higher chi %s.",
472+
new_unitcell[0, 0][0][0].chi,
473+
)
470474

471475
if varipeps_config.ad_use_custom_vjp:
472476
(
@@ -534,6 +538,7 @@ def line_search(
534538
additional_input,
535539
calc_preconverged=True,
536540
)
541+
tqdm.write(f"Line search step {count+1}, E={new_value:<.6f}, alpha={alpha:.4f}")
537542
new_gradient = [elem.conj() for elem in new_gradient_seq]
538543

539544
if new_unitcell[0, 0][0][0].chi > unitcell[0, 0][0][0].chi:
@@ -554,11 +559,11 @@ def line_search(
554559
) = cache_original_unitcell[new_unitcell[0, 0][0][0].chi]
555560
else:
556561
unitcell = unitcell.change_chi(new_unitcell[0, 0][0][0].chi)
557-
558-
debug_print(
559-
"Line search: Recalculate original unitcell with higher chi {}.",
560-
new_unitcell[0, 0][0][0].chi,
561-
)
562+
if logger.isEnabledFor(logging.DEBUG):
563+
logger.debug(
564+
"Line search: Recalculate original unitcell with higher chi %s.",
565+
new_unitcell[0, 0][0][0].chi,
566+
)
562567

563568
if varipeps_config.ad_use_custom_vjp:
564569
(
@@ -1002,7 +1007,7 @@ def line_search(
10021007
)
10031008

10041009
if alpha <= 0:
1005-
tqdm.write("Found negative alpha in secant operation!")
1010+
logger.warning("Found negative alpha in secant operation!")
10061011

10071012
hz_secant_alpha = alpha
10081013

varipeps/optimization/optimizer.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
from scipy.optimize import OptimizeResult
1212

13-
from tqdm_loggable.auto import tqdm
14-
1513
import h5py
1614

1715
import numpy as np
@@ -22,6 +20,10 @@
2220
from jax.lax import scan
2321
from jax.flatten_util import ravel_pytree
2422

23+
import logging
24+
25+
logger = logging.getLogger("varipeps.optimizer")
26+
2527
from varipeps import varipeps_config, varipeps_global_state
2628
from varipeps.config import Optimizing_Methods, Slurm_Restart_Mode
2729
from varipeps.peps import PEPS_Unit_Cell
@@ -603,7 +605,7 @@ def random_noise(a):
603605
slurm_restart_written = False
604606
slurm_new_job_id = None
605607

606-
with tqdm(desc="Optimizing PEPS state", initial=count) as pbar:
608+
if True:
607609
while count < varipeps_config.optimizer_max_steps:
608610
runtime_start = time.perf_counter()
609611

@@ -697,9 +699,6 @@ def random_noise(a):
697699
max_trunc_error_list[random_noise_retries] = []
698700
step_runtime[random_noise_retries] = []
699701

700-
pbar.reset()
701-
pbar.refresh()
702-
703702
continue
704703

705704
if working_unitcell[0, 0][0][0].chi != chi_before_ctmrg:
@@ -762,7 +761,7 @@ def random_noise(a):
762761
signal_reset_descent_dir = False
763762

764763
if _scalar_descent_grad(descent_dir, working_gradient) > 0:
765-
tqdm.write("Found bad descent dir. Reset to negative gradient!")
764+
logger.warning("Found bad descent dir. Reset to negative gradient!")
766765
descent_dir = [-elem for elem in working_gradient]
767766

768767
conv = jnp.linalg.norm(ravel_pytree(working_gradient)[0])
@@ -811,7 +810,7 @@ def random_noise(a):
811810
is Projector_Method.HALF
812811
)
813812
):
814-
tqdm.write(
813+
logger.warning(
815814
"Convergence is not sufficient. Retry with some random noise on best result."
816815
)
817816

@@ -1017,21 +1016,15 @@ def random_noise(a):
10171016

10181017
count += 1
10191018

1020-
pbar.update()
1021-
pbar.set_postfix(
1022-
{
1023-
"Energy": f"{working_value:0.10f}",
1024-
"Retries": random_noise_retries,
1025-
"Convergence": f"{conv:0.8f}",
1026-
"Line search step": (
1027-
f"{linesearch_step:0.8f}"
1028-
if linesearch_step is not None
1029-
else "0"
1030-
),
1031-
"Max. trunc. err.": f"{max_trunc_error:0.8g}",
1032-
}
1019+
logger.info(
1020+
"📉 Step %d | Energy: %0.10f | Retries: %d | Conv: %0.8f | Line search step: %s | Max. trunc. err.: %0.8g",
1021+
count,
1022+
working_value,
1023+
random_noise_retries,
1024+
conv,
1025+
f"{float(linesearch_step):0.8f}" if linesearch_step is not None else "None",
1026+
max_trunc_error,
10331027
)
1034-
pbar.refresh()
10351028

10361029
if count % varipeps_config.optimizer_autosave_step_count == 0:
10371030
_autosave_wrapper(
@@ -1187,7 +1180,7 @@ def random_noise(a):
11871180
slurm_data["WorkDir"],
11881181
)
11891182
if slurm_new_job_id is None:
1190-
tqdm.write(
1183+
logger.error(
11911184
"Failed to start new Slurm job or parse its job id."
11921185
)
11931186
break

varipeps/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from . import projector_dict
55
from . import slurm
66
from . import svd
7+
from . import logging_config

varipeps/utils/debug_print.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
import functools
2+
import logging
23

3-
from tqdm_loggable.auto import tqdm
44

55
import jax.debug as jdebug
66
from jax._src.debugging import formatter
77

8-
# Adapting function from jax.debug to work with tqdm
8+
logger = logging.getLogger("varipeps.ctmrg")
99

1010

1111
def _format_print_callback(fmt: str, *args, **kwargs):
12-
tqdm.write(fmt.format(*args, **kwargs))
12+
# Send to logger (respects per-module levels/handlers)
13+
logger.debug(fmt.format(*args, **kwargs))
1314

1415

1516
def debug_print(fmt: str, *args, ordered: bool = True, **kwargs) -> None:
1617
"""
1718
Prints values and works in staged out JAX functions.
1819
19-
Function adapted from :obj:`jax.debug.print` to work with tqdm. See there
20+
Function adapted from :obj:`jax.debug.print` to work with logger. See there
2021
for original authors and function.
2122
2223
Args:

0 commit comments

Comments
 (0)