            h_cls = h[:, 0]
            logits = self.W(h_cls)
            return logits, attn

----

***********************************
Automated Finetuning with Callbacks
***********************************

PyTorch Lightning provides the :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback to automate
the finetuning process. It gradually unfreezes your model's backbone during training. This is particularly
useful when working with large pretrained models, as it lets you start training with a frozen backbone and
then progressively unfreeze layers to finetune the whole model.

The :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback expects your model to have a specific structure:

.. code-block:: python

    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()

            # REQUIRED: your model must have a ``backbone`` attribute
            # holding the pretrained part you want to finetune
            self.backbone = some_pretrained_model

            # Your task-specific layers (head, classifier, etc.)
            self.head = nn.Linear(backbone_features, num_classes)

        def configure_optimizers(self):
            # Only optimize the head initially - the backbone is added automatically
            return torch.optim.Adam(self.head.parameters(), lr=1e-3)
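
To use the callback, pass it to the ``Trainer``. A minimal sketch (the epoch count here is an arbitrary choice):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BackboneFinetuning

    # keep the backbone frozen for the first 10 epochs, then unfreeze it gradually
    trainer = Trainer(callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=10)])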

************************************
Example: Computer Vision with ResNet
************************************

Here's a complete example showing how to use :class:`~lightning.pytorch.callbacks.BackboneFinetuning`
for computer vision:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torchvision.models as models
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks import BackboneFinetuning


    class ResNetClassifier(LightningModule):
        def __init__(self, num_classes=10, learning_rate=1e-3):
            super().__init__()
            self.save_hyperparameters()

            # Create backbone from a pretrained ResNet
            resnet = models.resnet50(weights="DEFAULT")
            # Remove the final classification layer
            self.backbone = nn.Sequential(*list(resnet.children())[:-1])

            # Add a custom classification head
            self.head = nn.Sequential(
                nn.Flatten(),
                nn.Linear(resnet.fc.in_features, 512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes),
            )

        def forward(self, x):
            # Extract features with the backbone
            features = self.backbone(x)
            # Classify with the head
            return self.head(features)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = nn.functional.cross_entropy(y_hat, y)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            # Initially only train the head - the backbone is added by the callback
            return torch.optim.Adam(self.head.parameters(), lr=self.hparams.learning_rate)


    # Set up the finetuning callback
    backbone_finetuning = BackboneFinetuning(
        unfreeze_backbone_at_epoch=10,  # Start unfreezing the backbone at epoch 10
        lambda_func=lambda epoch: 1.5,  # Multiply the backbone learning rate by 1.5 each epoch
        backbone_initial_ratio_lr=0.1,  # Backbone starts at 10% of the head learning rate
        should_align=True,  # Align rates when the backbone rate reaches the head rate
        verbose=True,  # Print learning rates during training
    )

    model = ResNetClassifier()
    trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)
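
Training then proceeds as usual; ``train_loader`` below is a placeholder for your own dataloader:

.. code-block:: python

    trainer.fit(model, train_dataloaders=train_loader)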

****************************
Custom Finetuning Strategies
****************************

For more control, you can create custom finetuning strategies by subclassing
:class:`~lightning.pytorch.callbacks.BaseFinetuning`:

.. testcode::

    from lightning.pytorch.callbacks.finetuning import BaseFinetuning


    class CustomFinetuning(BaseFinetuning):
        def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2):
            super().__init__()
            self.unfreeze_at_epoch = unfreeze_at_epoch
            self.layers_per_epoch = layers_per_epoch

        def freeze_before_training(self, pl_module):
            # Freeze the entire backbone initially
            self.freeze(pl_module.backbone)

        def finetune_function(self, pl_module, epoch, optimizer):
            # Gradually unfreeze the backbone from the top (last) layers down,
            # ``layers_per_epoch`` at a time
            if epoch >= self.unfreeze_at_epoch:
                children = list(pl_module.backbone.children())
                # Number of layers unfrozen in previous epochs
                already_unfrozen = (epoch - self.unfreeze_at_epoch) * self.layers_per_epoch
                if already_unfrozen >= len(children):
                    return
                # Select the next group of layers, counting from the top down
                stop = len(children) - already_unfrozen
                start = max(stop - self.layers_per_epoch, 0)
                for layer in children[start:stop]:
                    self.unfreeze_and_add_param_group(layer, optimizer, lr=1e-4)
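
A sketch of wiring the custom callback into training (the epoch counts are arbitrary):

.. code-block:: python

    # unfreeze two backbone layers per epoch, starting at epoch 5
    trainer = Trainer(callbacks=[CustomFinetuning(unfreeze_at_epoch=5, layers_per_epoch=2)])
    trainer.fit(model)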