Skip to content

Conversation

Copy link

Copilot AI commented Oct 4, 2025

What does this PR do?

Fixes a MisconfigurationException that occurs when using ReduceLROnPlateau scheduler with check_val_every_n_epoch > 1. The scheduler was attempting to access validation metrics on epochs where validation didn't run, causing an error.

Issue

When check_val_every_n_epoch is set to a value greater than 1, validation only runs on specific epochs (e.g., every 2nd epoch). However, the ReduceLROnPlateau scheduler was being updated at the end of every epoch, attempting to access the monitored metric (e.g., val/loss) even when it wasn't available.

Example error:

MisconfigurationException: ReduceLROnPlateau conditioned on metric val/loss which is not available. 
Available metrics are: ['lr-AdamW/pg1', 'lr-AdamW/pg2', 'train/a_pcc', 'train/loss']. 
Condition can be set using `monitor` key in lr scheduler dict

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal changes)

Solution

Modified the scheduler update logic in fit_loop.py to only update plateau schedulers when validation actually runs. This ensures the monitored metrics are available when the scheduler needs them.

Before:

self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)

After:

if (
    not self.restarting
    and self.epoch_loop._num_ready_batches_reached()
    and self.epoch_loop._should_check_val_epoch()
):
    self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)

Testing

Added test case test_reducelronplateau_with_check_val_every_n_epoch that verifies the fix works correctly when validation runs every N epochs.

trainer = Trainer(
    max_epochs=3,
    check_val_every_n_epoch=2,  # Validation only on epochs 0, 2
)
trainer.fit(model)  # No longer raises MisconfigurationException

Impact

  • Behavior change: Plateau schedulers are now only updated on epochs when validation runs
  • Non-plateau schedulers: Unchanged behavior - still updated every epoch
  • Backward compatibility: Maintained - default behavior (validation every epoch) works as before
  • User workaround: Users no longer need to set "strict": False to avoid the error

Fixes #<issue_number>

Original prompt

This section details on the original issue you should resolve

<issue_title>ReduceLROnPlateu within configure_optimizers behave abnormally</issue_title>
<issue_description>### Bug description

Got error

  File "c:\Users\sean\miniconda3\envs\keras+torch+pl\Lib\site-packages\lightning\pytorch\loops\training_epoch_loop.py", line 459, in _update_learning_rates
    raise MisconfigurationException(
lightning.fabric.utilities.exceptions.MisconfigurationException: ReduceLROnPlateau conditioned on metric val/loss which is not available. Available metrics are: ['lr-AdamW/pg1', 'lr-AdamW/pg2', 'train/a_pcc', 'train/loss']. Condition can be set using `monitor` key in lr scheduler dict

Here is the configure_optimizers function:

    @final
    def configure_optimizers(self):

        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if "bias" in name or "Norm" in name:
                no_decay.append(param)
            else:
                decay.append(param)

        grouped_params = [
            {"params": decay, "weight_decay": self.weight_decay, "lr": self.lr * 0.3},
            {
                "params": no_decay,
                "weight_decay": self.weight_decay,
                "lr": self.lr * 1.7,
            },
        ]

        optimizer = self.optmizer_class(
            grouped_params, lr=self.lr, weight_decay=self.weight_decay
        )

        scheduler = self.lr_scheduler_class(
            optimizer, **self.lr_scheduler_args if self.lr_scheduler_args else {}
        )
        scheduler = {
            "scheduler": self.lr_scheduler_class(
                optimizer, **self.lr_scheduler_args if self.lr_scheduler_args else {}
            ),
            "monitor": "val/loss",
            "interval": "epoch",
            "frequency": 1,
            # "strict": False,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

The lr_scheduler_class is passed in as

  lr_scheduler_class: torch.optim.lr_scheduler.ReduceLROnPlateau
  lr_scheduler_args:
    mode: min
    factor: 0.5
    patience: 10
    threshold: 0.0001
    threshold_mode: rel
    cooldown: 5
    min_lr: 1.e-9
    eps: 1.e-08

(using yaml and CLI, which, I think, is not the case here)

It seems that I got the error at the end of the training epoch, as I just see the progress bar reports train/loss. The validation epoch is not finished, but the scheduler is called.

I am quite sure that val/loss is available after validation epoch is finished, because progress bar can correctly display it.

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Error messages and logs

# Error messages and logs here please

Environment

StatusCode : 200
StatusDescription : OK
Content : # Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the...
RawContent : HTTP/1.1 200 OK
Connection: keep-alive
Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox
Strict-Transport-Security: max-age=31536000
X-Content-Type-Options: nosniff
...
Forms : {}
Headers : {[Connection, keep-alive], [Content-Security-Policy, default-src 'none'; style-src 'unsafe-inline'; sandbox], [Strict-Transport-Security, max-age=31536000],
[X-Content-Type-Options, nosniff]...}
Images : {}
InputFields : {}
Links : {}
ParsedHtml : mshtml.HTMLDocumentClass
RawContentLength : 2775

More info

No response</issue_description>

Comments on the Issue (you are @copilot in this section)

Fixes #20829

✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.


📚 Documentation preview 📚: https://pytorch-lightning--21266.org.readthedocs.build/en/21266/

Copilot AI and others added 2 commits October 4, 2025 21:22
- Only update plateau schedulers on epochs when validation runs
- This prevents errors when monitored metrics are not available
- Added test case for this scenario

Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
Copilot AI changed the title [WIP] ReduceLROnPlateu within configure_optimizers behave abnormally Fix ReduceLROnPlateau scheduler error when validation doesn't run every epoch Oct 4, 2025
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Oct 4, 2025
Copilot AI requested a review from Borda October 4, 2025 21:29
@Borda Borda marked this pull request as ready for review November 11, 2025 09:46
@github-actions
Copy link
Contributor

github-actions bot commented Nov 11, 2025

⛈️ Required checks status: Has failure 🔴

Warning
This job will need to be re-run to merge your PR. If you do not have write access to the repository, you can ask Lightning-AI/lai-frameworks to re-run it. If you push a new commit, all of CI will re-trigger.

Groups summary

🔴 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu-guardian failure

These checks are required after the changes to src/lightning/pytorch/loops/fit_loop.py, tests/tests_pytorch/trainer/optimization/test_optimizers.py.

🟢 pytorch_lightning: lit GPU
Check ID Status
pytorch.yml / Lit Job (nvidia/cuda:12.1.1-devel-ubuntu22.04, pytorch, 3.10) success
pytorch.yml / Lit Job (lightning, 3.12) success
pytorch.yml / Lit Job (pytorch, 3.12) success

These checks are required after the changes to src/lightning/pytorch/loops/fit_loop.py, tests/tests_pytorch/trainer/optimization/test_optimizers.py.

🟢 Benchmarks
Check ID Status
benchmark.yml / Lit Job (fabric) success
benchmark.yml / Lit Job (pytorch) success

These checks are required after the changes to src/lightning/pytorch/loops/fit_loop.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to src/lightning/pytorch/loops/fit_loop.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/pytorch/loops/fit_loop.py.

🟢 install
Check ID Status
install-pkg-guardian success

These checks are required after the changes to src/lightning/pytorch/loops/fit_loop.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 70 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@codecov
Copy link

codecov bot commented Nov 11, 2025

❌ 1 Tests Failed:

Tests completed Failed Passed Skipped
3244 1 3243 520
View the full list of 1 ❄️ flaky test(s)
tests/tests_pytorch/checkpointing/test_model_checkpoint.py::test_model_checkpoint_score_and_ckpt[True-True-False-train_log_epoch]

Flake rate in main: 4.00% (Passed 24 times, Failed 1 times)

Stack Traces | 0.077s run time
tmp_path = PosixPath('.../pytest-of-runner/pytest-0/test_model_checkpoint_score_an5')
validation_step_none = True, val_dataloaders_none = False
monitor = 'train_log_epoch', reduce_lr_on_plateau = True

    @pytest.mark.parametrize(
        ("validation_step_none", "val_dataloaders_none", "monitor"),
        [(False, False, "val_log"), (True, False, "train_log_epoch"), (False, True, "val_log")],
    )
    @pytest.mark.parametrize("reduce_lr_on_plateau", [False, True])
    def test_model_checkpoint_score_and_ckpt(
        tmp_path, validation_step_none: bool, val_dataloaders_none: bool, monitor: str, reduce_lr_on_plateau: bool
    ):
        """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint
        data."""
        max_epochs = 3
        limit_train_batches = 5
        limit_val_batches = 7
        lr, gamma = 1e-1, 2
    
        class CustomBoringModel(BoringModel):
            def __init__(self):
                super().__init__()
                self.train_log_epochs = torch.randn(max_epochs, limit_train_batches)
                self.val_logs = torch.randn(max_epochs, limit_val_batches)
                self.scores = []
    
            def training_step(self, batch, batch_idx):
                log_value = self.train_log_epochs[self.current_epoch, batch_idx]
                self.log("train_log", log_value, on_epoch=True)
                return super().training_step(batch, batch_idx)
    
            def validation_step(self, batch, batch_idx):
                log_value = self.val_logs[self.current_epoch, batch_idx]
                self.log("val_log", log_value)
                return super().validation_step(batch, batch_idx)
    
            def configure_optimizers(self):
                optimizer = optim.SGD(self.parameters(), lr=lr)
    
                if reduce_lr_on_plateau:
                    lr_scheduler = {
                        "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                        "monitor": monitor,
                        "strict": True,
                    }
                else:
                    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
    
                return [optimizer], [lr_scheduler]
    
            def on_train_epoch_end(self):
                if "train" in monitor:
                    self.scores.append(self.trainer.logged_metrics[monitor])
    
            def on_validation_epoch_end(self):
                if not self.trainer.sanity_checking and "val" in monitor:
                    self.scores.append(self.trainer.logged_metrics[monitor])
    
        filename = "{" + f"{monitor}" + ":.4f}-{epoch}"
        checkpoint = ModelCheckpoint(dirpath=tmp_path, filename=filename, monitor=monitor, save_top_k=-1)
    
        model = CustomBoringModel()
    
        if validation_step_none:
            model.validation_step = None
        if val_dataloaders_none:
            model.val_dataloaders = None
    
        trainer = Trainer(
            default_root_dir=tmp_path,
            callbacks=[checkpoint],
            limit_train_batches=limit_train_batches,
            limit_val_batches=limit_val_batches,
            max_epochs=max_epochs,
            enable_progress_bar=False,
        )
        calls = mock_training_epoch_loop(trainer)
        trainer.fit(model)
    
        ckpt_files = list(tmp_path.glob("*.ckpt"))
        assert len(ckpt_files) == len(model.scores) == max_epochs
    
        for epoch in range(max_epochs):
            score = model.scores[epoch]
            expected_score = getattr(model, f"{monitor}s")[epoch].mean().item()
            assert math.isclose(score, expected_score, abs_tol=1e-5)
    
            expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
            assert chk["epoch"] == epoch
            assert chk["global_step"] == limit_train_batches * (epoch + 1)
    
            mc_specific_data = chk["callbacks"][
                f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
                " 'train_time_interval': None}"
            ]
            assert mc_specific_data["dirpath"] == checkpoint.dirpath
            assert mc_specific_data["monitor"] == monitor
            assert mc_specific_data["current_score"] == score
    
            if not reduce_lr_on_plateau:
                actual_step_count = chk["lr_schedulers"][0]["_step_count"]
                actual_lr = chk["lr_schedulers"][0]["_last_lr"][0]
                # checkpoint is saved after updating lr_scheduler states
                assert actual_step_count == epoch + 2  # step_count starts at 1
                assert actual_lr == lr * gamma ** (epoch + 1)
            else:
>               assert calls[epoch] == {monitor: score}
                       ^^^^^^^^^^^^
E               KeyError: 0

checkpointing/test_model_checkpoint.py:185: KeyError

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ReduceLROnPlateu within configure_optimizers behave abnormally

1 participant