Commit 77bfef4

Fix progress bar deadlock on DDP metrics computation.

1 parent 2f448e1 commit 77bfef4

4 files changed: +115 additions, -5 deletions
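
Why this deadlocked: the progress bar callbacks call get_metrics(trainer, pl_module), and for values logged with sync_dist=True that call performs a collective reduction which every rank must join. Progress bars are only enabled on rank 0, so returning early on disabled ranks left rank 0 waiting in the reduction alone until the process group timed out. A minimal, Lightning-independent sketch of the failure mode and its fix (function names are illustrative, not Lightning API; assumes an initialized process group):

    import torch
    import torch.distributed as dist

    def bad_epoch_end(metric: torch.Tensor) -> None:
        if dist.get_rank() != 0:
            return  # non-zero ranks bail out early...
        dist.all_reduce(metric)  # ...so rank 0 blocks here forever

    def good_epoch_end(metric: torch.Tensor) -> None:
        dist.all_reduce(metric)  # every rank joins the collective first
        if dist.get_rank() == 0:
            print(metric.item())  # rank-0-only display work happens after the sync

The changes below apply the second shape to both progress bars: compute (and thereby sync) the metrics on all ranks, then return early on ranks with nothing to display.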

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 2 additions & 3 deletions

@@ -646,15 +646,14 @@ def _update_metrics(
         current: Optional[int] = None,
         total_batches: bool = False,
     ) -> None:
-        if not self.is_enabled or self._metric_component is None:
-            return
-
         if current is not None and not total_batches:
             total = self.total_train_batches
             if not self._should_update(current, total):
                 return
 
         metrics = self.get_metrics(trainer, pl_module)
+        if not self.is_enabled or self._metric_component is None:
+            return
         if self._metric_component:
             self._metric_component.update(metrics)
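
The reorder moves the is_enabled/_metric_component early return below get_metrics(trainer, pl_module), so the metric computation, and any collective reduction it triggers, now runs on every rank; only the purely local display update remains guarded. The surviving "if self._metric_component:" check is now redundant with the early return above it, but harmless.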

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 2 additions & 1 deletion

@@ -282,8 +282,9 @@ def on_train_batch_end(
 
     @override
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        metrics = self.get_metrics(trainer, pl_module)
         if not self.train_progress_bar.disable:
-            self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+            self.train_progress_bar.set_postfix(metrics)
         if self._leave:
             self.train_progress_bar.close()
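
Same shape as the Rich fix: get_metrics is hoisted above the train_progress_bar.disable check (the bar is disabled on non-zero ranks), so all ranks execute the metric sync while only rank 0 renders the postfix.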

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 55 additions & 1 deletion

@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import pickle
 from collections import defaultdict
 from unittest import mock
-from unittest.mock import DEFAULT, Mock
+from unittest.mock import DEFAULT, Mock, patch
 
 import pytest
 from tests_pytorch.helpers.runif import RunIf
@@ -26,6 +27,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.loggers.logger import DummyLogger
+from lightning.pytorch.strategies import DDPStrategy
 
 
 @RunIf(rich=True)
@@ -605,3 +607,55 @@ def val_dataloader(self):
 
     # This should not raise an AssertionError
     trainer.fit(model)
+
+
+def test_rich_progress_bar_ddp_deadlock(tmp_path):
+    """Tests that RichProgressBar doesn't deadlock when using DDP on train epoch end.
+
+    We used to have a bug where metrics were synced only on the rank 0 process. See
+    https://github.com/Lightning-AI/pytorch-lightning/issues/21264 for more details.
+
+    """
+    RichProgressBar()
+
+    # We need a LightningModule that logs a metric with on_epoch=True, sync_dist=True
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)["loss"]
+            self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+            return {"loss": loss}
+
+    model = MyModel()
+
+    # We need to mock these logger connector hooks, since these also attempt to sync metrics
+    # and can "save" otherwise incorrect implementations of RichProgressBar.on_train_epoch_end.
+    def mock_on_epoch_end(self):
+        pass
+
+    def mock_update_train_epoch_metrics(self):
+        pass
+
+    with (
+        patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
+        patch(
+            "lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
+            mock_update_train_epoch_metrics,
+        ),
+    ):
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            num_sanity_val_steps=0,
+            max_epochs=1,
+            val_check_interval=1,
+            accelerator="cpu",
+            devices=2,
+            strategy=DDPStrategy(
+                process_group_backend="gloo",  # run on CPU
+                timeout=datetime.timedelta(seconds=5),  # timeout quickly for the test to fail
+            ),
+            enable_progress_bar=True,
+            enable_model_summary=False,
+            enable_checkpointing=False,
+        )
+        trainer.fit(model)
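
The test design: logging with on_epoch=True and sync_dist=True forces a collective reduction at epoch end, the logger-connector hooks that would also sync metrics are stubbed out so the progress bar callback is the only remaining code path that can perform the sync, and the gloo backend with a 5-second timeout turns a would-be hang into a prompt failure. Reduced to primitives, the sync the callback must not skip looks roughly like this (mean_across_ranks is a hypothetical helper, not Lightning API):

    import torch
    import torch.distributed as dist

    def mean_across_ranks(value: torch.Tensor) -> torch.Tensor:
        # Every rank must call this together; it is the collective the
        # progress bar used to skip on ranks whose bar is disabled.
        value = value.clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        return value / dist.get_world_size()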

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 56 additions & 0 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import math
 import os
 import pickle
@@ -32,6 +33,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.loggers.logger import DummyLogger
+from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 
 
@@ -859,3 +861,57 @@ def reset(self, total=None):
     assert 2 in val_bar.total_values, (
         f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
     )
+
+
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+def test_tqdm_progress_bar_ddp_deadlock(tmp_path):
+    """Tests that TQDMProgressBar doesn't deadlock when using DDP on train epoch end.
+
+    We used to have a bug where metrics were synced only on the rank 0 process. See
+    https://github.com/Lightning-AI/pytorch-lightning/issues/21264 for more details.
+
+    """
+    pbar = TQDMProgressBar()
+
+    # We need a LightningModule that logs a metric with on_epoch=True, sync_dist=True
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)["loss"]
+            self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+            return {"loss": loss}
+
+    model = MyModel()
+
+    # We need to mock these logger connector hooks, since these also attempt to sync metrics
+    # and can "save" otherwise incorrect implementations of TQDMProgressBar.on_train_epoch_end.
+    def mock_on_epoch_end(self):
+        pass
+
+    def mock_update_train_epoch_metrics(self):
+        pass
+
+    with (
+        patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
+        patch(
+            "lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
+            mock_update_train_epoch_metrics,
+        ),
+    ):
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            num_sanity_val_steps=0,
+            max_epochs=1,
+            val_check_interval=1,
+            accelerator="cpu",
+            devices=2,
+            strategy=DDPStrategy(
+                process_group_backend="gloo",  # run on CPU
+                timeout=datetime.timedelta(seconds=5),  # timeout quickly for the test to fail
+            ),
+            callbacks=[pbar],
+            enable_progress_bar=True,
+            enable_model_summary=False,
+            enable_checkpointing=False,
+        )
+        trainer.fit(model)
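
Unlike the Rich test, this one passes the bar explicitly via callbacks=[pbar], and it patches callback_connector._RICH_AVAILABLE to False, presumably so the TQDM code path is exercised even in environments where rich is installed and would otherwise be picked as the default progress bar.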
