Skip to content
52 changes: 49 additions & 3 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
from collections.abc import Generator
from dataclasses import dataclass
from datetime import timedelta
Expand All @@ -22,13 +23,15 @@
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.types import STEP_OUTPUT

if _RICH_AVAILABLE:
from rich import get_console, reconfigure
from rich.console import Console, RenderableType
from rich.live import _RefreshThread as _RichRefreshThread
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
from rich.progress_bar import ProgressBar as _RichProgressBar
from rich.style import Style
Expand Down Expand Up @@ -66,9 +69,49 @@ class CustomInfiniteTask(Task):
def time_remaining(self) -> Optional[float]:
return None

class _RefreshThread(_RichRefreshThread):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.refresh_cond = False
super().__init__(*args, **kwargs)

def run(self) -> None:
while not self.done.is_set():
if self.refresh_cond:
with self.live._lock:
self.live.refresh()
self.refresh_cond = False
time.sleep(0.005)

class CustomProgress(Progress):
"""Overrides ``Progress`` to support adding tasks that have an infinite total size."""

def start(self) -> None:
"""Starts the progress display.

Notes
-----
This override is needed to support the custom refresh thread.

"""
if self.live.auto_refresh:
self.live._refresh_thread = _RefreshThread(self.live, self.live.refresh_per_second)
self.live.auto_refresh = False
super().start()
if self.live._refresh_thread:
self.live.auto_refresh = True
self.live._refresh_thread.start()

def stop(self) -> None:
refresh_thread = self.live._refresh_thread
super().stop()
if refresh_thread:
refresh_thread.stop()
refresh_thread.join()

def soft_refresh(self) -> None:
if self.live.auto_refresh and isinstance(self.live._refresh_thread, _RefreshThread):
self.live._refresh_thread.refresh_cond = True

def add_task(
self,
description: str,
Expand Down Expand Up @@ -356,17 +399,20 @@ def _init_progress(self, trainer: "pl.Trainer") -> None:
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,
auto_refresh=False,
auto_refresh=True,
disable=self.is_disabled,
console=self._console,
)
self.progress.start()
# progress has started
self._progress_stopped = False

def refresh(self) -> None:
def refresh(self, hard: bool = False) -> None:
if self.progress:
self.progress.refresh()
if hard or _IS_INTERACTIVE:
self.progress.refresh()
else:
self.progress.soft_refresh()

@override
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def test_rich_progress_bar_custom_theme():
_, kwargs = mocks["ProcessingSpeedColumn"].call_args
assert kwargs["style"] == theme.processing_speed

progress_bar.progress.live._refresh_thread.stop()
progress_bar.progress.live._refresh_thread.join()


@RunIf(rich=True)
def test_rich_progress_bar_keyboard_interrupt(tmp_path):
Expand Down Expand Up @@ -176,6 +179,8 @@ def configure_columns(self, trainer):
assert progress_bar.progress.columns[0] == custom_column
assert len(progress_bar.progress.columns) == 2

progress_bar.progress.stop()


@RunIf(rich=True)
@pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 3)]))
Expand Down Expand Up @@ -345,7 +350,8 @@ def training_step(self, *args, **kwargs):

for key in ("loss", "v_num", "train_loss"):
assert key in rendered[train_progress_bar_id][1]
assert key not in rendered[val_progress_bar_id][1]
if val_progress_bar_id in rendered:
assert key not in rendered[val_progress_bar_id][1]


def test_rich_progress_bar_metrics_fast_dev_run(tmp_path):
Expand All @@ -359,7 +365,8 @@ def test_rich_progress_bar_metrics_fast_dev_run(tmp_path):
val_progress_bar_id = progress_bar.val_progress_bar_id
rendered = progress_bar.progress.columns[-1]._renderable_cache
assert "v_num" not in rendered[train_progress_bar_id][1]
assert "v_num" not in rendered[val_progress_bar_id][1]
if val_progress_bar_id in rendered:
assert "v_num" not in rendered[val_progress_bar_id][1]


@RunIf(rich=True)
Expand Down
Loading