|
20 | 20 | from copy import copy |
21 | 21 | from functools import wraps |
22 | 22 | from importlib import import_module |
| 23 | +from textwrap import indent |
23 | 24 | from typing import Any, Callable, cast, TypeVar |
24 | 25 |
|
25 | 26 | import numpy as np |
@@ -52,25 +53,37 @@ def strtobool(val: Any) -> bool: |
52 | 53 | LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") |
53 | 54 | logger = logging.getLogger("torchrl") |
54 | 55 | logger.setLevel(getattr(logging, LOGGING_LEVEL)) |
55 | | -# Disable propagation to the root logger |
56 | 56 | logger.propagate = False |
57 | | -# Remove all attached handlers |
| 57 | +# Clear existing handlers |
58 | 58 | while logger.hasHandlers(): |
59 | 59 | logger.removeHandler(logger.handlers[0]) |
60 | 60 | stream_handlers = { |
61 | 61 | "stdout": sys.stdout, |
62 | 62 | "stderr": sys.stderr, |
63 | 63 | } |
64 | 64 | TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM") |
65 | | -if TORCHRL_CONSOLE_STREAM: |
66 | | - stream_handler = stream_handlers[TORCHRL_CONSOLE_STREAM] |
67 | | -else: |
68 | | - stream_handler = None |
69 | | -console_handler = logging.StreamHandler(stream=stream_handler) |
70 | | - |
71 | | -console_handler.setLevel(logging.INFO) |
72 | | -formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s") |
73 | | -console_handler.setFormatter(formatter) |
| 65 | +stream_handler = stream_handlers.get(TORCHRL_CONSOLE_STREAM, sys.stdout) |
| 66 | + |
| 67 | + |
| 68 | +# Create colored handler |
| 69 | +class _CustomFormatter(logging.Formatter): |
| 70 | + def format(self, record): |
| 71 | + # Format the initial part in green |
| 72 | + green_format = "\033[92m%(asctime)s [%(name)s][%(levelname)s]\033[0m" |
| 73 | + # Format the message part |
| 74 | + message_format = "%(message)s" |
| 75 | + # End marker in green |
| 76 | + end_marker = "\033[92m [END]\033[0m" |
| 77 | + # Combine all parts |
| 78 | + formatted_message = logging.Formatter( |
| 79 | + green_format + indent(message_format, " " * 4) + end_marker |
| 80 | + ).format(record) |
| 81 | + |
| 82 | + return formatted_message |
| 83 | + |
| 84 | + |
| 85 | +console_handler = logging.StreamHandler(stream_handler) |
| 86 | +console_handler.setFormatter(_CustomFormatter()) |
74 | 87 | logger.addHandler(console_handler) |
75 | 88 |
|
76 | 89 | VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG)))) |
|
0 commit comments