
Commit fa4003b

SkafteNicki, Borda, and deependujha authored
Docs finetuning callback example (#21216)
* add to documentation
* improve docstring
* fix docs
* skip doctest if torchvision missing
* missing space
* changelog
* fix mistake
* Empty-Commit
* empty commit to rerun ci
* try fixing doctest

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
Co-authored-by: Deependu Jha <deependujha21@gmail.com>
1 parent 74b3fd5 commit fa4003b

File tree

2 files changed: +175 -2 lines changed


docs/source-pytorch/advanced/transfer_learning.rst

Lines changed: 132 additions & 0 deletions
@@ -126,3 +126,135 @@ Here's a model that uses `Huggingface transformers <https://github.com/huggingfa
h_cls = h[:, 0]
logits = self.W(h_cls)
return logits, attn


----


***********************************
Automated Finetuning with Callbacks
***********************************

PyTorch Lightning provides the :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback to automate
the finetuning process. The callback keeps your model's backbone frozen at the start of training, unfreezes it
at a configurable epoch, and then gradually ramps up its learning rate. This is particularly useful when
working with large pretrained models, as it lets you train the task-specific head first and bring the backbone
into training later without disrupting its pretrained weights.
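
In its simplest form, the callback is just added to the ``Trainer``; a minimal sketch (the epoch value here is
only illustrative):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BackboneFinetuning

    # Keep model.backbone frozen until epoch 5, then unfreeze it and ramp up its learning rate
    trainer = Trainer(callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=5)])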

The :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback expects your model to have a specific structure:

.. testcode::

    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()

            # REQUIRED: Your model must have a 'backbone' attribute
            # This should be the pretrained part you want to finetune
            self.backbone = some_pretrained_model

            # Your task-specific layers (head, classifier, etc.)
            self.head = nn.Linear(backbone_features, num_classes)

        def configure_optimizers(self):
            # Only optimize the head initially - backbone will be added automatically
            return torch.optim.Adam(self.head.parameters(), lr=1e-3)

************************************
Example: Computer Vision with ResNet
************************************

Here's a complete example showing how to use :class:`~lightning.pytorch.callbacks.BackboneFinetuning`
for computer vision:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torchvision.models as models
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks import BackboneFinetuning


    class ResNetClassifier(LightningModule):
        def __init__(self, num_classes=10, learning_rate=1e-3):
            super().__init__()
            self.save_hyperparameters()

            # Create backbone from pretrained ResNet
            resnet = models.resnet50(weights="DEFAULT")
            # Remove the final classification layer
            self.backbone = nn.Sequential(*list(resnet.children())[:-1])

            # Add custom classification head
            self.head = nn.Sequential(
                nn.Flatten(),
                nn.Linear(resnet.fc.in_features, 512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes)
            )

        def forward(self, x):
            # Extract features with backbone
            features = self.backbone(x)
            # Classify with head
            return self.head(features)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = nn.functional.cross_entropy(y_hat, y)
            self.log('train_loss', loss)
            return loss

        def configure_optimizers(self):
            # Initially only train the head - backbone will be added by callback
            return torch.optim.Adam(self.head.parameters(), lr=self.hparams.learning_rate)


    # Setup the finetuning callback
    backbone_finetuning = BackboneFinetuning(
        unfreeze_backbone_at_epoch=10,  # Start unfreezing backbone at epoch 10
        lambda_func=lambda epoch: 1.5,  # Gradually increase backbone learning rate
        backbone_initial_ratio_lr=0.1,  # Backbone starts at 10% of head learning rate
        should_align=True,  # Align rates when backbone rate reaches head rate
        verbose=True  # Print learning rates during training
    )

    model = ResNetClassifier()
    trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)
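
With the callback attached, training proceeds as usual. A minimal sketch of launching a run
(``train_dataloader`` is a placeholder for whatever dataloader or datamodule you use):

.. code-block:: python

    # `train_dataloader` is a placeholder for your own DataLoader or LightningDataModule
    trainer.fit(model, train_dataloader)

With the settings above, the backbone stays frozen for the first 10 epochs. From epoch 10 it joins training at
10% of the head's learning rate, and that rate is multiplied by 1.5 each epoch until it aligns with the head's
rate.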

****************************
Custom Finetuning Strategies
****************************

For more control, you can create custom finetuning strategies by subclassing
:class:`~lightning.pytorch.callbacks.BaseFinetuning`:

.. testcode::

    from lightning.pytorch.callbacks.finetuning import BaseFinetuning


    class CustomFinetuning(BaseFinetuning):
        def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2):
            super().__init__()
            self.unfreeze_at_epoch = unfreeze_at_epoch
            self.layers_per_epoch = layers_per_epoch

        def freeze_before_training(self, pl_module):
            # Freeze the entire backbone initially
            self.freeze(pl_module.backbone)

        def finetune_function(self, pl_module, epoch, optimizer):
            # Gradually unfreeze layers, starting from the top of the backbone
            if epoch >= self.unfreeze_at_epoch:
                backbone_children = list(pl_module.backbone.children())
                num_children = len(backbone_children)

                # Layers already unfrozen in previous epochs vs. the target after this epoch
                already_unfrozen = min((epoch - self.unfreeze_at_epoch) * self.layers_per_epoch, num_children)
                target_unfrozen = min(already_unfrozen + self.layers_per_epoch, num_children)

                # Unfreeze only the newly exposed layers and add them to the optimizer
                new_layers = backbone_children[num_children - target_unfrozen : num_children - already_unfrozen]
                if new_layers:
                    self.unfreeze_and_add_param_group(new_layers, optimizer, lr=1e-4)
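
The custom callback is then used like any other callback; a minimal sketch, where ``MyModel`` stands in for any
LightningModule with a ``backbone`` attribute such as the ones defined above:

.. code-block:: python

    model = MyModel()
    trainer = Trainer(
        callbacks=[CustomFinetuning(unfreeze_at_epoch=5, layers_per_epoch=2)],
        max_epochs=20,
    )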

src/lightning/pytorch/callbacks/finetuning.py

Lines changed: 43 additions & 2 deletions
@@ -31,8 +31,13 @@
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn

+if not _TORCHVISION_AVAILABLE:
+    __doctest_skip__ = ["BackboneFinetuning"]
+
+
 log = logging.getLogger(__name__)

@@ -356,10 +361,46 @@ class BackboneFinetuning(BaseFinetuning):

 Example::

->>> from lightning.pytorch import Trainer
+>>> import torch
+>>> import torch.nn as nn
+>>> from lightning.pytorch import LightningModule, Trainer
 >>> from lightning.pytorch.callbacks import BackboneFinetuning
+>>> import torchvision.models as models
+>>>
+>>> class TransferLearningModel(LightningModule):
+...     def __init__(self, num_classes=10):
+...         super().__init__()
+...         # REQUIRED: Your model must have a 'backbone' attribute
+...         self.backbone = models.resnet50(weights=None)
+...         # Remove the final classification layer from backbone
+...         self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
+...
+...         # Add your task-specific head
+...         self.head = nn.Sequential(
+...             nn.Flatten(),
+...             nn.Linear(2048, 512),
+...             nn.ReLU(),
+...             nn.Linear(512, num_classes)
+...         )
+...
+...     def forward(self, x):
+...         # Extract features with backbone
+...         features = self.backbone(x)
+...         # Classify with head
+...         return self.head(features)
+...
+...     def configure_optimizers(self):
+...         # Initially only optimize the head - backbone will be added by callback
+...         return torch.optim.Adam(self.head.parameters(), lr=1e-3)
+...
+>>> # Setup the callback
 >>> multiplicative = lambda epoch: 1.5
->>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
+>>> backbone_finetuning = BackboneFinetuning(
+...     unfreeze_backbone_at_epoch=10,  # Start unfreezing at epoch 10
+...     lambda_func=multiplicative,  # Gradually increase backbone LR
+...     backbone_initial_ratio_lr=0.1,  # Start backbone at 10% of head LR
+... )
+>>> model = TransferLearningModel()
 >>> trainer = Trainer(callbacks=[backbone_finetuning])

 """

0 commit comments
