
Commit b82db78

Authored by SkafteNicki, pre-commit-ci[bot], justusschock, Borda, and deependujha
Fix last.ckpt only being saved when another checkpoint has been created (#21244)

* fix implementation
* add testing
* smaller fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* changelog
* empty ci commit
* fix tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <justus.schock@posteo.de>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent fa4003b commit b82db78
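
Context for the fix, as a minimal sketch (not part of the commit; `BoringModel` is Lightning's demo model, and the directory contents described in the comments are illustrative):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

# every_n_epochs=2 with a single epoch means no top-k checkpoint is ever written
ckpt = ModelCheckpoint(dirpath="ckpts", save_last=True, every_n_epochs=2)
trainer = Trainer(max_epochs=1, callbacks=[ckpt], logger=False)
trainer.fit(BoringModel())

# Before this commit: "ckpts" could end up containing only last.ckpt, even
# though no other checkpoint was ever produced in the run.
# After this commit: last.ckpt is written only when another checkpoint was
# saved at the same global step (or when save_last="link" has a link target).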

File tree

3 files changed: +70 -3 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed preventing recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
 
+- Fixed `last.ckpt` being created and not linked to another checkpoint ([#21244](https://github.com/Lightning-AI/pytorch-lightning/pull/21244))
+
+
 - Fixed bug that prevented `BackboneFinetuning` from being used together with `LearningRateFinder` ([#21224](https://github.com/Lightning-AI/pytorch-lightning/pull/21224))
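
The new entry concerns `save_last="link"`, where `last.ckpt` is created as a symlink to the most recent real checkpoint instead of as a separate copy. A minimal configuration sketch (argument values are illustrative):

from lightning.pytorch.callbacks import ModelCheckpoint

# last.ckpt becomes a symlink to the newest saved checkpoint, so it can only
# be (re)created once at least one real checkpoint exists to point at.
ckpt = ModelCheckpoint(dirpath="ckpts", monitor="val_loss", save_top_k=2, save_last="link")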

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 10 additions & 2 deletions

@@ -380,7 +380,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             monitor_candidates = self._monitor_candidates(trainer)
             if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                 self._save_topk_checkpoint(trainer, monitor_candidates)
-            self._save_last_checkpoint(trainer, monitor_candidates)
+            # Only save last checkpoint if a checkpoint was actually saved in this step or if save_last="link"
+            if self._last_global_step_saved == trainer.global_step or (
+                self.save_last == "link" and self._last_checkpoint_saved
+            ):
+                self._save_last_checkpoint(trainer, monitor_candidates)
 
     @override
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -397,7 +401,11 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
 
         if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
             self._save_topk_checkpoint(trainer, monitor_candidates)
-        self._save_last_checkpoint(trainer, monitor_candidates)
+        # Only save last checkpoint if a checkpoint was actually saved in this step or if save_last="link"
+        if self._last_global_step_saved == trainer.global_step or (
+            self.save_last == "link" and self._last_checkpoint_saved
+        ):
+            self._save_last_checkpoint(trainer, monitor_candidates)
 
     @override
     def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
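
A distilled restatement of the new guard, to make the condition easier to read in isolation (this standalone function is illustrative, not library code; the real check lives inline in the two hooks above and reads the callback's private attributes):

def should_save_last(last_global_step_saved: int, global_step: int, save_last, last_checkpoint_saved: str) -> bool:
    # True when a checkpoint was just written at this global step, or when we
    # are in "link" mode and a previously saved checkpoint exists to link to.
    return last_global_step_saved == global_step or (save_last == "link" and bool(last_checkpoint_saved))

assert should_save_last(10, 10, True, "")                    # top-k just saved at this step
assert not should_save_last(5, 10, True, "")                 # nothing saved this step -> skip last.ckpt
assert should_save_last(5, 10, "link", "best-epoch=1.ckpt")  # link mode with an existing target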

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 57 additions & 1 deletion

@@ -1948,7 +1948,7 @@ def test_save_last_every_n_epochs_interaction(tmp_path, every_n_epochs):
     with patch.object(trainer, "save_checkpoint") as save_mock:
         trainer.fit(model)
     assert mc.last_model_path  # a "last" ckpt was saved
-    assert save_mock.call_count == trainer.max_epochs
+    assert save_mock.call_count == trainer.max_epochs - 1
 
 
 def test_train_epoch_end_ckpt_with_no_validation():
@@ -2124,3 +2124,59 @@ def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
 
     # save_last=True should always save last.ckpt
     assert (tmp_path / "last.ckpt").exists()
+
+
+def test_save_last_only_when_checkpoint_saved(tmp_path):
+    """Test that save_last only creates last.ckpt when another checkpoint is actually saved."""
+
+    class SelectiveModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.validation_step_outputs = []
+
+        def validation_step(self, batch, batch_idx):
+            outputs = super().validation_step(batch, batch_idx)
+            epoch = self.trainer.current_epoch
+            loss = torch.tensor(1.0 - epoch * 0.1) if epoch % 2 == 0 else torch.tensor(1.0 + epoch * 0.1)
+            outputs["val_loss"] = loss
+            self.validation_step_outputs.append(outputs)
+            return outputs
+
+        def on_validation_epoch_end(self):
+            if self.validation_step_outputs:
+                avg_loss = torch.stack([x["val_loss"] for x in self.validation_step_outputs]).mean()
+                self.log("val_loss", avg_loss)
+                self.validation_step_outputs.clear()
+
+    model = SelectiveModel()
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        filename="best-{epoch}-{val_loss:.2f}",
+        monitor="val_loss",
+        save_last=True,
+        save_top_k=1,
+        mode="min",
+        every_n_epochs=1,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_progress_bar=False,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        enable_checkpointing=True,
+    )
+
+    trainer.fit(model)
+
+    checkpoint_files = list(tmp_path.glob("*.ckpt"))
+    checkpoint_names = [f.name for f in checkpoint_files]
+    assert "last.ckpt" in checkpoint_names, "last.ckpt should exist since checkpoints were saved"
+    expected_files = 2  # best checkpoint + last.ckpt
+    assert len(checkpoint_files) == expected_files, (
+        f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}"
+    )
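
To run just the new test locally, a standard pytest keyword filter should work from the repository root (assuming a development install of the repo):

pytest tests/tests_pytorch/checkpointing/test_model_checkpoint.py -k test_save_last_only_when_checkpoint_saved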
