@@ -1948,7 +1948,7 @@ def test_save_last_every_n_epochs_interaction(tmp_path, every_n_epochs):
19481948 with patch .object (trainer , "save_checkpoint" ) as save_mock :
19491949 trainer .fit (model )
19501950 assert mc .last_model_path # a "last" ckpt was saved
1951- assert save_mock .call_count == trainer .max_epochs
1951+ assert save_mock .call_count == trainer .max_epochs - 1
19521952
19531953
19541954def test_train_epoch_end_ckpt_with_no_validation ():
@@ -2124,3 +2124,59 @@ def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
21242124
21252125 # save_last=True should always save last.ckpt
21262126 assert (tmp_path / "last.ckpt" ).exists ()
2127+
2128+
def test_save_last_only_when_checkpoint_saved(tmp_path):
    """Fit with a monitored ``ModelCheckpoint`` and ``save_last=True``, then inspect ``tmp_path``.

    The run must leave exactly two ``*.ckpt`` files behind: the single retained
    ``save_top_k=1`` checkpoint and ``last.ckpt``.

    """

    class SelectiveModel(BoringModel):
        """``BoringModel`` whose logged ``val_loss`` is a function of the current epoch only."""

        def __init__(self):
            super().__init__()
            # Per-batch validation outputs; averaged and emptied at the end of each validation epoch.
            self.validation_step_outputs = []

        def validation_step(self, batch, batch_idx):
            step_out = super().validation_step(batch, batch_idx)
            current = self.trainer.current_epoch
            # Even epochs yield 1.0 - 0.1 * epoch, odd epochs yield 1.0 + 0.1 * epoch.
            if current % 2 == 0:
                step_out["val_loss"] = torch.tensor(1.0 - current * 0.1)
            else:
                step_out["val_loss"] = torch.tensor(1.0 + current * 0.1)
            self.validation_step_outputs.append(step_out)
            return step_out

        def on_validation_epoch_end(self):
            # Nothing collected means nothing to log.
            if not self.validation_step_outputs:
                return
            per_batch = [out["val_loss"] for out in self.validation_step_outputs]
            self.log("val_loss", torch.stack(per_batch).mean())
            self.validation_step_outputs.clear()

    model = SelectiveModel()

    # Keep only the best checkpoint by ``val_loss`` and additionally request a ``last`` checkpoint.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="best-{epoch}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=False,
    )

    trainer = Trainer(
        max_epochs=4,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint_callback],
        enable_checkpointing=True,
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(model)

    # Collect every checkpoint file written directly into the checkpoint directory.
    saved_paths = list(tmp_path.glob("*.ckpt"))
    saved_names = [path.name for path in saved_paths]

    assert "last.ckpt" in saved_names, "last.ckpt should exist since checkpoints were saved"

    expected_count = 2  # best checkpoint + last.ckpt
    assert len(saved_paths) == expected_count, (
        f"Expected {expected_count} files, got {len(saved_paths)} : {saved_names} "
    )
0 commit comments