Skip to content

Commit 63d1956

Browse files
added tqdm logging
1 parent 829cde1 commit 63d1956

File tree

2 files changed

+105
-13
lines changed

2 files changed

+105
-13
lines changed

varipeps/config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,24 @@ class VariPEPS_Config:
235235
Type of wavevector to be used (only positive/symmetric interval/...).
236236
slurm_restart_mode (:obj:`Slurm_Restart_Mode`):
237237
Mode of operation to restart slurm job if maximal runtime is reached.
238+
log_level_global (:obj:`LogLevel`):
239+
Global logging level.
240+
log_level_optimizer (:obj:`LogLevel`):
241+
Logging level for optimizer module.
242+
log_level_ctmrg (:obj:`LogLevel`):
243+
Logging level for CTMRG module.
244+
log_level_line_search (:obj:`LogLevel`):
245+
Logging level for line search module.
246+
log_level_expectation (:obj:`LogLevel`):
247+
Logging level for expectation value calculations.
248+
log_to_console (:obj:`bool`):
249+
Enable logging to console.
250+
log_to_file (:obj:`bool`):
251+
Enable logging to file.
252+
log_file (:obj:`str`):
253+
Filename for logging to file.
254+
log_tqdm (:obj:`bool`):
255+
Enable tqdm-based console logging.
238256
"""
239257

240258
# AD config
@@ -330,6 +348,7 @@ class VariPEPS_Config:
330348
log_to_file: bool = False
331349
log_file: str = "varipeps.log"
332350
log_step_summary_every_n: int = 1
351+
log_tqdm: bool = False #: Enable tqdm-based console logging
333352

334353
def update(self, name: str, value: Any) -> NoReturn:
335354
self.__setattr__(name, value)

varipeps/utils/logging_config.py

Lines changed: 86 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,43 @@
55

66
from varipeps import config as _cfg_mod # uses the global config instance
77

8+
# --- Custom tqdm-based handlers ---
9+
10+
class TqdmUpdateHandler(logging.Handler):
11+
"""Updates a tqdm progress bar's postfix string instead of printing."""
12+
def __init__(self, pbar: Any):
13+
super().__init__()
14+
self.pbar = pbar
15+
16+
def emit(self, record: logging.LogRecord) -> None:
17+
try:
18+
msg = self.format(record)
19+
# Truncate to keep the bar compact
20+
self.pbar.set_postfix_str(str(msg), refresh=True)
21+
except Exception: # nosec - logging must never raise
22+
self.handleError(record)
23+
24+
25+
class TqdmWriteHandler(logging.Handler):
26+
"""Writes messages via tqdm.write (thread-safe)."""
27+
def emit(self, record: logging.LogRecord) -> None:
28+
try:
29+
msg = self.format(record)
30+
from tqdm import tqdm
31+
tqdm.write(str(msg))
32+
except Exception:
33+
self.handleError(record)
34+
35+
36+
class ExcludeLoggerFilter(logging.Filter):
37+
"""Exclude records whose logger name starts with a given prefix."""
38+
def __init__(self, prefix: str):
39+
super().__init__()
40+
self.prefix = prefix
41+
42+
def filter(self, record: logging.LogRecord) -> bool:
43+
return not record.name.startswith(self.prefix)
44+
845
_LOGGING_INITIALIZED = False
946

1047
def _to_py_log_level(level: Any) -> int:
@@ -34,25 +71,61 @@ def init_logging(cfg: Any | None = None) -> None:
3471
root.setLevel(_to_py_log_level(getattr(cfg, "log_level_global", logging.INFO)))
3572
root.propagate = False
3673

37-
# fmt = logging.Formatter(
38-
# fmt="%(asctime)s %(levelname)s %(name)s: %(message)s",
39-
# datefmt="%H:%M:%S",
40-
# )
41-
4274
fmt = logging.Formatter(
4375
fmt="%(asctime)s %(levelname)s %(message)s",
4476
datefmt="%Y-%m-%d %H:%M:%S",
4577
)
4678

47-
if getattr(cfg, "log_to_console", True):
48-
sh = logging.StreamHandler()
49-
sh.setFormatter(fmt)
50-
root.addHandler(sh)
79+
use_tqdm = bool(getattr(cfg, "log_tqdm", False))
80+
81+
if use_tqdm:
82+
fmt = logging.Formatter(fmt="%(message)s")
83+
# Console via tqdm.write for all varipeps loggers except optimizer
84+
tw = TqdmWriteHandler()
85+
tw.setFormatter(fmt)
86+
tw.addFilter(ExcludeLoggerFilter("varipeps.optimizer"))
87+
root.addHandler(tw)
88+
89+
# Preserve file logging if enabled
90+
if getattr(cfg, "log_to_file", False):
91+
fh = logging.FileHandler(getattr(cfg, "log_file", "varipeps.log"))
92+
fh.setFormatter(fmt)
93+
root.addHandler(fh)
94+
95+
# Optimizer uses a tqdm progress bar update handler
96+
opt_logger = logging.getLogger("varipeps.optimizer")
97+
for h in list(opt_logger.handlers):
98+
opt_logger.removeHandler(h)
99+
100+
from tqdm import tqdm
101+
# Create a lightweight bar that we only update the postfix for
102+
pbar = tqdm(total=0, position=0, leave=True, dynamic_ncols=True)
103+
104+
if pbar is not None:
105+
th = TqdmUpdateHandler(pbar)
106+
th.setFormatter(fmt)
107+
opt_logger.addHandler(th)
108+
109+
# Keep propagation so optimizer still logs to file handler if present,
110+
# while console is suppressed by the ExcludeLoggerFilter on root.
111+
opt_logger.propagate = True
112+
else:
113+
# Standard console/file logging
114+
if getattr(cfg, "log_to_console", True):
115+
sh = logging.StreamHandler()
116+
sh.setFormatter(fmt)
117+
root.addHandler(sh)
118+
119+
if getattr(cfg, "log_to_file", False):
120+
fh = logging.FileHandler(getattr(cfg, "log_file", "varipeps.log"))
121+
fh.setFormatter(fmt)
122+
root.addHandler(fh)
51123

52-
if getattr(cfg, "log_to_file", False):
53-
fh = logging.FileHandler(getattr(cfg, "log_file", "varipeps.log"))
54-
fh.setFormatter(fmt)
55-
root.addHandler(fh)
124+
# Ensure optimizer has no leftover tqdm handler from a previous init
125+
opt_logger = logging.getLogger("varipeps.optimizer")
126+
for h in list(opt_logger.handlers):
127+
opt_logger.removeHandler(h)
128+
opt_logger.propagate = True
56129

57130
# Per-module levels
58131
logging.getLogger("varipeps.optimizer").setLevel(

0 commit comments

Comments
 (0)