
Commit b3e0df0

SkafteNicki, Borda, and deependujha authored
Bugfix for BackboneFinetuning + LearningRateFinder (#21224)
* check for attribute before index
* add testing
* changelog
* Empty-Commit

---------

Co-authored-by: jirka <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent e561c2c commit b3e0df0

File tree

src/lightning/pytorch/CHANGELOG.md
src/lightning/pytorch/callbacks/finetuning.py
tests/tests_pytorch/tuner/test_lr_finder.py

3 files changed: +56 −6 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed preventing recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
 
+- Fixed bug that prevented `BackboneFinetuning` from being used together with `LearningRateFinder` ([#21224](https://github.com/Lightning-AI/pytorch-lightning/pull/21224))
+
+
 - Fixed `ModelPruning` sparsity logging bug that caused incorrect sparsity percentages ([#21223](https://github.com/Lightning-AI/pytorch-lightning/pull/21223))
 
 
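For context, here is a minimal, self-contained sketch of the combination this changelog entry refers to, condensed from the test added later in this commit. The module and variable names (`BackboneHeadModel`, and the `__main__` driver) are illustrative only; the Lightning APIs used (`Trainer`, `Tuner.lr_find`, `BackboneFinetuning`, `BoringModel`) are the same ones exercised by the new test.

```python
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner import Tuner


class BackboneHeadModel(BoringModel):
    """Backbone/head split so that BackboneFinetuning has a `backbone` attribute to manage."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))
        self.head = torch.nn.Linear(8, 2)
        self.learning_rate = 1e-3  # attribute that lr_find can read and update

    def forward(self, x):
        return self.head(self.backbone(x))

    def configure_optimizers(self):
        # the backbone starts frozen, so only optimize parameters that currently require grad
        params = (p for p in self.parameters() if p.requires_grad)
        return torch.optim.Adam(params, lr=self.learning_rate)


if __name__ == "__main__":
    model = BackboneHeadModel()
    trainer = Trainer(
        max_epochs=3,
        logger=False,
        enable_checkpointing=False,
        callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=1)],
    )

    # Running the LR finder with the finetuning callback attached is the sequence
    # this commit fixes; afterwards, fit() proceeds and unfreezes the backbone at epoch 1.
    Tuner(trainer).lr_find(model, num_training=5)
    trainer.fit(model)
```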
src/lightning/pytorch/callbacks/finetuning.py

Lines changed: 8 additions & 6 deletions
@@ -108,12 +108,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # restore the param_groups created during the previous training.
         if self._restarting:
-            named_parameters = dict(pl_module.named_parameters())
-            for opt_idx, optimizer in enumerate(trainer.optimizers):
-                param_groups = self._apply_mapping_to_param_groups(
-                    self._internal_optimizer_metadata[opt_idx], named_parameters
-                )
-                optimizer.param_groups = param_groups
+            if self._internal_optimizer_metadata:
+                named_parameters = dict(pl_module.named_parameters())
+                for opt_idx, optimizer in enumerate(trainer.optimizers):
+                    if opt_idx in self._internal_optimizer_metadata:
+                        param_groups = self._apply_mapping_to_param_groups(
+                            self._internal_optimizer_metadata[opt_idx], named_parameters
+                        )
+                        optimizer.param_groups = param_groups
         self._restarting = False
 
     @staticmethod
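In words: when a restart is detected, the callback now restores param_groups only if it actually recorded optimizer metadata, and only for the optimizer indices present in that mapping. Presumably the learning-rate finder can trigger this restore path before `BaseFinetuning` has populated `_internal_optimizer_metadata`, in which case the old unconditional `self._internal_optimizer_metadata[opt_idx]` lookup raised a `KeyError`; the added guards tolerate that empty state while leaving ordinary checkpoint restarts unchanged.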

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 45 additions & 0 deletions
@@ -25,6 +25,7 @@
 
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import EarlyStopping
+from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.tuner.lr_finder import _LRFinder
@@ -799,3 +800,47 @@ def configure_optimizers(self):
     assert len(lr_find_checkpoints) == 0, (
         f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
     )
+
+
+def test_lr_finder_with_backbone_finetuning_callback(tmp_path):
+    """Test that lr_find works correctly with BackboneFinetuning callback."""
+
+    class ModelWithBackbone(BoringModel):
+        def __init__(self):
+            super().__init__()
+            # Create a simple backbone-head architecture
+            self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))
+            self.head = torch.nn.Linear(8, 2)
+            self.learning_rate = 1e-3
+
+        def forward(self, x):
+            backbone_features = self.backbone(x)
+            return self.head(backbone_features)
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
+
+    model = ModelWithBackbone()
+    backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=1)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[backbone_finetuning],
+    )
+
+    tuner = Tuner(trainer)
+    lr_finder = tuner.lr_find(model, num_training=5)
+
+    assert lr_finder is not None
+    assert hasattr(lr_finder, "results")
+    assert len(lr_finder.results) > 0
+    trainer.fit(model)
+
+    # Check that backbone was unfrozen at the correct epoch
+    for param in model.backbone.parameters():
+        assert param.requires_grad, "Backbone parameters should be unfrozen after epoch 1"
