From 574dcb603a6e5b0e50a41e728aac888ce37ccff7 Mon Sep 17 00:00:00 2001
From: Iruos8805
Date: Sun, 2 Nov 2025 22:59:24 +0530
Subject: [PATCH 1/3] Add on_checkpoint_write_end callback hook and tests

---
 docs/source-pytorch/conf.py                    |  1 +
 docs/source-pytorch/extensions/callbacks.rst   |  6 ++
 src/lightning/pytorch/CHANGELOG.md             |  2 +
 src/lightning/pytorch/callbacks/callback.py    | 12 +++
 src/lightning/pytorch/trainer/call.py          | 23 +++++
 src/lightning/pytorch/trainer/trainer.py       |  2 +
 .../test_checkpoint_write_callback.py          | 99 +++++++++++++++++++
 .../trainer/logging_/test_logger_connector.py  |  1 +
 8 files changed, 146 insertions(+)
 create mode 100644 tests/tests_pytorch/callbacks/test_checkpoint_write_callback.py

diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index b2a1749bd202f..93501d4d2af50 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -430,6 +430,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
     ("py:func", "lightning.pytorch.callbacks.RichProgressBar.configure_columns"),
     ("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_load_checkpoint"),
     ("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_save_checkpoint"),
+    ("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_checkpoint_write_end"),
     ("py:class", "lightning.pytorch.callbacks.checkpoint.Checkpoint"),
     ("py:meth", "lightning.pytorch.callbacks.progress.progress_bar.ProgressBar.get_metrics"),
     ("py:class", "lightning.pytorch.callbacks.progress.rich_progress.RichProgressBarTheme"),
diff --git a/docs/source-pytorch/extensions/callbacks.rst b/docs/source-pytorch/extensions/callbacks.rst
index 7ed285591c4dc..729111a7578cb 100644
--- a/docs/source-pytorch/extensions/callbacks.rst
+++ b/docs/source-pytorch/extensions/callbacks.rst
@@ -344,6 +344,12 @@ on_load_checkpoint
 .. automethod:: lightning.pytorch.callbacks.Callback.on_load_checkpoint
     :noindex:
 
+on_checkpoint_write_end
+^^^^^^^^^^^^^^^^^^^^^^^
+
+.. automethod:: lightning.pytorch.callbacks.Callback.on_checkpoint_write_end
+    :noindex:
+
 on_before_backward
 ^^^^^^^^^^^^^^^^^^
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 03c353508e390..9b0686ef51b02 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
 
+- Added `Callback.on_checkpoint_write_end` hook that triggers after checkpoint files are fully written to disk ([#XXXXX](https://github.com/Lightning-AI/pytorch-lightning/pull/XXXXX))
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py
index 3bfb609465a83..f1d25b8db0636 100644
--- a/src/lightning/pytorch/callbacks/callback.py
+++ b/src/lightning/pytorch/callbacks/callback.py
@@ -271,6 +271,18 @@ def on_load_checkpoint(
 
         """
 
+    def on_checkpoint_write_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", filepath: str) -> None:
+        r"""Called after a checkpoint file has been fully written to disk.
+
+        Use this hook to perform any post-save actions such as logging, uploading, or cleanup.
+
+        Args:
+            trainer: the current :class:`~pytorch_lightning.trainer.trainer.Trainer` instance.
+            pl_module: the current :class:`~pytorch_lightning.core.LightningModule` instance.
+            filepath: The path to the checkpoint file that was written.
+
+        """
+
     def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
         """Called before ``loss.backward()``."""
 
diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py
index 77536cdc16b33..81ddc3549a80b 100644
--- a/src/lightning/pytorch/trainer/call.py
+++ b/src/lightning/pytorch/trainer/call.py
@@ -295,6 +295,29 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[s
         pl_module._current_fx_name = prev_fx_name
 
 
+def _call_callbacks_on_checkpoint_write_end(trainer: "pl.Trainer", filepath: str) -> None:
+    """Called after a checkpoint file is written, calls every callback's `on_checkpoint_write_end` hook."""
+
+    pl_module = trainer.lightning_module
+    if pl_module:
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = "on_checkpoint_write_end"
+
+    for callback in trainer.callbacks:
+        try:
+            with trainer.profiler.profile(f"[Callback]{callback.state_key}.on_checkpoint_write_end"):
+                callback.on_checkpoint_write_end(trainer, trainer.lightning_module, filepath)
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except Exception as e:
+            rank_zero_warn(
+                f"Exception raised in `on_checkpoint_write_end` of callback `{callback.__class__.__name__}`: {e}",
+            )
+
+    if pl_module:
+        pl_module._current_fx_name = prev_fx_name
+
+
 def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None:
     """Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
     callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks")
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 5768c507e2e3f..01a305ba897c0 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -1413,6 +1413,8 @@ def save_checkpoint(
         self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
         self.strategy.barrier("Trainer.save_checkpoint")
 
+        call._call_callbacks_on_checkpoint_write_end(self, filepath)
+
     """
     State properties
     """
diff --git a/tests/tests_pytorch/callbacks/test_checkpoint_write_callback.py b/tests/tests_pytorch/callbacks/test_checkpoint_write_callback.py
new file mode 100644
index 0000000000000..a0d840e9de95c
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_checkpoint_write_callback.py
@@ -0,0 +1,99 @@
+import os
+
+import torch
+
+from lightning.pytorch import Callback, Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class CheckpointWriteEndCallback(Callback):
+    def __init__(self):
+        self.called = False
+        self.filepath = None
+        self.file_existed = False
+        self.checkpoint_valid = False
+
+    def on_checkpoint_write_end(self, trainer, pl_module, filepath):
+        """Test hook should trigger after checkpoint is written."""
+        self.called = True
+        self.filepath = str(filepath)
+        self.file_existed = os.path.exists(filepath)
+
+        try:
+            checkpoint = torch.load(filepath, map_location="cpu")
+            self.checkpoint_valid = "state_dict" in checkpoint
+        except Exception:
+            self.checkpoint_valid = False
+
+
+def test_on_checkpoint_write_end_called(tmp_path):
+    """Test that on_checkpoint_write_end is triggered after checkpoint saving."""
+    model = BoringModel()
+    callback = CheckpointWriteEndCallback()
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[callback], logger=False)
+
+    trainer.fit(model)
+    checkpoint_path = tmp_path / "test_checkpoint.ckpt"
+    trainer.save_checkpoint(checkpoint_path)
+
+    assert checkpoint_path.exists()
+    assert callback.called
+    assert callback.file_existed
+    assert callback.checkpoint_valid
+    assert callback.filepath == str(checkpoint_path)
+
+
+def test_on_checkpoint_write_end_exception_safe(tmp_path):
+    """Test that callback exceptions don’t block others."""
+    model = BoringModel()
+
+    class FailingCallback(Callback):
+        def on_checkpoint_write_end(self, trainer, pl_module, filepath):
+            raise RuntimeError("Intentional error")
+
+    class SuccessCallback(Callback):
+        def __init__(self):
+            self.called = False
+
+        def on_checkpoint_write_end(self, trainer, pl_module, filepath):
+            self.called = True
+
+    fail_cb = FailingCallback()
+    success_cb = SuccessCallback()
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[fail_cb, success_cb], logger=False)
+
+    trainer.fit(model)
+    checkpoint_path = tmp_path / "test_checkpoint.ckpt"
+    trainer.save_checkpoint(checkpoint_path)
+
+    assert checkpoint_path.exists()
+    assert success_cb.called
+
+
+def test_checkpoint_file_accessibility(tmp_path):
+    """Test that checkpoint is readable during callback execution."""
+    model = BoringModel()
+
+    class FileAccessCallback(Callback):
+        def __init__(self):
+            self.can_read = False
+            self.valid = False
+
+        def on_checkpoint_write_end(self, trainer, pl_module, filepath):
+            try:
+                ckpt = torch.load(filepath, map_location="cpu")
+                self.can_read = True
+                self.valid = "state_dict" in ckpt
+            except (OSError, RuntimeError):
+                pass
+
+    callback = FileAccessCallback()
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[callback], logger=False)
+
+    trainer.fit(model)
+    checkpoint_path = tmp_path / "test_checkpoint.ckpt"
+    trainer.save_checkpoint(checkpoint_path)
+
+    assert checkpoint_path.exists()
+    assert callback.can_read
+    assert callback.valid
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index d3d355edb003b..6d893ba3b9d6b 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -54,6 +54,7 @@ def test_fx_validator():
         "on_sanity_check_start",
         "state_dict",
         "on_save_checkpoint",
+        "on_checkpoint_write_end",
         "on_test_batch_end",
         "on_test_batch_start",
         "on_test_end",

From 5e26eab3ff35767c2a650297e31de55f9b470b64 Mon Sep 17 00:00:00 2001
From: Iruos8805
Date: Tue, 4 Nov 2025 11:15:37 +0530
Subject: [PATCH 2/3] Add on_checkpoint_write_end to LambdaCallback

---
 src/lightning/pytorch/callbacks/lambda_function.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py
index f04b2d777deb3..15f0e925a9fdb 100644
--- a/src/lightning/pytorch/callbacks/lambda_function.py
+++ b/src/lightning/pytorch/callbacks/lambda_function.py
@@ -67,6 +67,7 @@ def __init__(
         on_exception: Optional[Callable] = None,
         on_save_checkpoint: Optional[Callable] = None,
         on_load_checkpoint: Optional[Callable] = None,
+        on_checkpoint_write_end: Optional[Callable] = None,
         on_before_backward: Optional[Callable] = None,
         on_after_backward: Optional[Callable] = None,
         on_before_optimizer_step: Optional[Callable] = None,
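With the hook exposed on `LambdaCallback` (patch 2 above), a post-write action can be attached without subclassing `Callback`. The sketch below is illustrative only and assumes the patches in this series are applied; the backup directory and the `shutil.copy2` mirroring are stand-ins for whatever post-save action is actually needed:

    import shutil
    from pathlib import Path

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import LambdaCallback
    from lightning.pytorch.demos.boring_classes import BoringModel

    # Illustrative destination; any post-save action could go here instead.
    backup_dir = Path("checkpoint_backups")
    backup_dir.mkdir(exist_ok=True)

    # The lambda receives the same (trainer, pl_module, filepath) arguments as
    # Callback.on_checkpoint_write_end and runs only once the file is on disk.
    mirror_checkpoints = LambdaCallback(
        on_checkpoint_write_end=lambda trainer, pl_module, filepath: shutil.copy2(filepath, backup_dir)
    )

    trainer = Trainer(max_epochs=1, callbacks=[mirror_checkpoints], logger=False)
    trainer.fit(BoringModel())
    trainer.save_checkpoint("model.ckpt")  # the mirror runs after this write completes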
From 343c1c3df197c7b826269a46135a4144df7127a3 Mon Sep 17 00:00:00 2001
From: Iruos8805
Date: Tue, 4 Nov 2025 18:33:08 +0530
Subject: [PATCH 3/3] update: refine on_checkpoint_write_end hook and tests

---
 src/lightning/pytorch/core/hooks.py           | 24 +++++
 src/lightning/pytorch/trainer/call.py         | 23 -----
 src/lightning/pytorch/trainer/trainer.py      |  3 +-
 .../test_checkpoint_write_callback.py         | 99 -------------------
 .../test_checkpoint_write_end_callback.py     | 41 ++++++++
 .../trainer/logging_/test_logger_connector.py |  1 +
 6 files changed, 68 insertions(+), 123 deletions(-)
 delete mode 100644 tests/tests_pytorch/callbacks/test_checkpoint_write_callback.py
 create mode 100644 tests/tests_pytorch/callbacks/test_checkpoint_write_end_callback.py

diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 0b0ab14244e38..24e4cc17b0423 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -709,3 +709,27 @@ def on_save_checkpoint(self, checkpoint):
         There is no need for you to store anything about training.
 
         """
+
+    def on_checkpoint_write_end(self, filepath: str) -> None:
+        r"""Called after a checkpoint file has been fully written to disk.
+
+        This hook is triggered after the checkpoint saving process completes,
+        ensuring the file exists and is readable. Unlike :meth:`on_save_checkpoint`,
+        which is called before the checkpoint is written, this hook guarantees
+        the file is available on disk.
+
+        Args:
+            filepath: Path to the checkpoint file that was written.
+
+        Example::
+
+            class MyModel(LightningModule):
+                def on_checkpoint_write_end(self, filepath):
+                    print(f"Checkpoint saved at: {filepath}")
+                    upload_to_s3(filepath)
+
+        Note:
+            In distributed training, this hook is called on all ranks after
+            the barrier synchronization completes.
+
+        """
diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py
index 81ddc3549a80b..77536cdc16b33 100644
--- a/src/lightning/pytorch/trainer/call.py
+++ b/src/lightning/pytorch/trainer/call.py
@@ -295,29 +295,6 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[s
         pl_module._current_fx_name = prev_fx_name
 
 
-def _call_callbacks_on_checkpoint_write_end(trainer: "pl.Trainer", filepath: str) -> None:
-    """Called after a checkpoint file is written, calls every callback's `on_checkpoint_write_end` hook."""
-
-    pl_module = trainer.lightning_module
-    if pl_module:
-        prev_fx_name = pl_module._current_fx_name
-        pl_module._current_fx_name = "on_checkpoint_write_end"
-
-    for callback in trainer.callbacks:
-        try:
-            with trainer.profiler.profile(f"[Callback]{callback.state_key}.on_checkpoint_write_end"):
-                callback.on_checkpoint_write_end(trainer, trainer.lightning_module, filepath)
-        except (KeyboardInterrupt, SystemExit):
-            raise
-        except Exception as e:
-            rank_zero_warn(
-                f"Exception raised in `on_checkpoint_write_end` of callback `{callback.__class__.__name__}`: {e}",
-            )
-
-    if pl_module:
-        pl_module._current_fx_name = prev_fx_name
-
-
 def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None:
     """Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
     callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks")
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 01a305ba897c0..81d56d03f3342 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -1413,7 +1413,8 @@ def save_checkpoint(
         self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
         self.strategy.barrier("Trainer.save_checkpoint")
 
-        call._call_callbacks_on_checkpoint_write_end(self, filepath)
+        call._call_callback_hooks(self, "on_checkpoint_write_end", filepath)
+        call._call_lightning_module_hook(self, "on_checkpoint_write_end", filepath)
 
     """
     State properties
     """
diff --git a/tests/tests_pytorch/callbacks/test_checkpoint_write_callback.py b/tests/tests_pytorch/callbacks/test_checkpoint_write_callback.py
deleted file mode 100644
index a0d840e9de95c..0000000000000
--- a/tests/tests_pytorch/callbacks/test_checkpoint_write_callback.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-
-import torch
-
-from lightning.pytorch import Callback, Trainer
-from lightning.pytorch.demos.boring_classes import BoringModel
-
-
-class CheckpointWriteEndCallback(Callback):
-    def __init__(self):
-        self.called = False
-        self.filepath = None
-        self.file_existed = False
-        self.checkpoint_valid = False
-
-    def on_checkpoint_write_end(self, trainer, pl_module, filepath):
-        """Test hook should trigger after checkpoint is written."""
-        self.called = True
-        self.filepath = str(filepath)
-        self.file_existed = os.path.exists(filepath)
-
-        try:
-            checkpoint = torch.load(filepath, map_location="cpu")
-            self.checkpoint_valid = "state_dict" in checkpoint
-        except Exception:
-            self.checkpoint_valid = False
-
-
-def test_on_checkpoint_write_end_called(tmp_path):
-    """Test that on_checkpoint_write_end is triggered after checkpoint saving."""
-    model = BoringModel()
-    callback = CheckpointWriteEndCallback()
-    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[callback], logger=False)
-
-    trainer.fit(model)
-    checkpoint_path = tmp_path / "test_checkpoint.ckpt"
-    trainer.save_checkpoint(checkpoint_path)
-
-    assert checkpoint_path.exists()
-    assert callback.called
-    assert callback.file_existed
-    assert callback.checkpoint_valid
-    assert callback.filepath == str(checkpoint_path)
-
-
-def test_on_checkpoint_write_end_exception_safe(tmp_path):
-    """Test that callback exceptions don’t block others."""
-    model = BoringModel()
-
-    class FailingCallback(Callback):
-        def on_checkpoint_write_end(self, trainer, pl_module, filepath):
-            raise RuntimeError("Intentional error")
-
-    class SuccessCallback(Callback):
-        def __init__(self):
-            self.called = False
-
-        def on_checkpoint_write_end(self, trainer, pl_module, filepath):
-            self.called = True
-
-    fail_cb = FailingCallback()
-    success_cb = SuccessCallback()
-    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[fail_cb, success_cb], logger=False)
-
-    trainer.fit(model)
-    checkpoint_path = tmp_path / "test_checkpoint.ckpt"
-    trainer.save_checkpoint(checkpoint_path)
-
-    assert checkpoint_path.exists()
-    assert success_cb.called
-
-
-def test_checkpoint_file_accessibility(tmp_path):
-    """Test that checkpoint is readable during callback execution."""
-    model = BoringModel()
-
-    class FileAccessCallback(Callback):
-        def __init__(self):
-            self.can_read = False
-            self.valid = False
-
-        def on_checkpoint_write_end(self, trainer, pl_module, filepath):
-            try:
-                ckpt = torch.load(filepath, map_location="cpu")
-                self.can_read = True
-                self.valid = "state_dict" in ckpt
-            except (OSError, RuntimeError):
-                pass
-
-    callback = FileAccessCallback()
-    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[callback], logger=False)
-
-    trainer.fit(model)
-    checkpoint_path = tmp_path / "test_checkpoint.ckpt"
-    trainer.save_checkpoint(checkpoint_path)
-
-    assert checkpoint_path.exists()
-    assert callback.can_read
-    assert callback.valid
diff --git a/tests/tests_pytorch/callbacks/test_checkpoint_write_end_callback.py b/tests/tests_pytorch/callbacks/test_checkpoint_write_end_callback.py
new file mode 100644
index 0000000000000..18968553aa746
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_checkpoint_write_end_callback.py
@@ -0,0 +1,41 @@
+import os
+
+import torch
+
+from lightning.pytorch import Callback, Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class CheckpointWriteEndCallback(Callback):
+    def __init__(self):
+        self.called = False
+        self.filepath = None
+        self.file_existed = False
+        self.checkpoint_valid = False
+
+    def on_checkpoint_write_end(self, trainer, pl_module, filepath):
+        """Verify that the hook triggers after checkpoint is written."""
+        self.called = True
+        self.filepath = str(filepath)
+        self.file_existed = os.path.exists(filepath)
+
+        checkpoint = torch.load(filepath, map_location="cpu")
+        self.checkpoint_valid = "state_dict" in checkpoint
+
+
+def test_on_checkpoint_write_end_called(tmp_path):
+    """Test that on_checkpoint_write_end is called after saving a checkpoint."""
+    model = BoringModel()
+    callback = CheckpointWriteEndCallback()
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[callback], logger=False)
+
+    trainer.fit(model)
+
+    checkpoint_path = tmp_path / "test_checkpoint.ckpt"
+    trainer.save_checkpoint(checkpoint_path)
+
+    assert checkpoint_path.exists()
+    assert callback.called
+    assert callback.file_existed
+    assert callback.checkpoint_valid
+    assert callback.filepath == str(checkpoint_path)
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index 6d893ba3b9d6b..b6573a5799c1b 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -88,6 +88,7 @@ def test_fx_validator():
         "on_fit_start",
         "on_exception",
         "on_load_checkpoint",
+        "on_checkpoint_write_end",
         "load_state_dict",
         "on_sanity_check_end",
         "on_sanity_check_start",
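Taken together, the series leaves two entry points: `Callback.on_checkpoint_write_end(trainer, pl_module, filepath)` and `LightningModule.on_checkpoint_write_end(filepath)`, both invoked from `Trainer.save_checkpoint` after the strategy barrier. A minimal callback sketch follows, assuming these patches are applied; the size check and log message are illustrative and not part of the patch:

    import os

    from lightning.pytorch import Callback, Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel


    class CheckpointSizeLogger(Callback):
        """Report the size of each checkpoint once it has been fully written."""

        def on_checkpoint_write_end(self, trainer, pl_module, filepath):
            # Unlike on_save_checkpoint, the file is guaranteed to exist here.
            size_mb = os.path.getsize(filepath) / 1e6
            print(f"checkpoint written: {filepath} ({size_mb:.1f} MB)")


    trainer = Trainer(max_epochs=1, callbacks=[CheckpointSizeLogger()], logger=False)
    trainer.fit(BoringModel())
    trainer.save_checkpoint("model.ckpt")  # hook fires after the write completes

Per the Note in the hooks.py docstring above, the hook fires on every rank after the barrier, so actions that should run only once per save can be guarded with `trainer.is_global_zero`.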