3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))


- Added `EMAWeightAveraging` callback that extends Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))


### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
3 changes: 2 additions & 1 deletion src/lightning/pytorch/callbacks/__init__.py
@@ -32,7 +32,7 @@
from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging, WeightAveraging

__all__ = [
"BackboneFinetuning",
@@ -59,5 +59,6 @@
"ThroughputMonitor",
"Timer",
"TQDMProgressBar",
"EMAWeightAveraging",
"WeightAveraging",
]
54 changes: 53 additions & 1 deletion src/lightning/pytorch/callbacks/weight_averaging.py
@@ -21,7 +21,7 @@
from typing import Any, Optional, Union

import torch
from torch.optim.swa_utils import AveragedModel
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
from typing_extensions import override

import lightning.pytorch as pl
@@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
for average_param, current_param in zip(average_params, current_params):
current_param.data.copy_(average_param.data)


class EMAWeightAveraging(WeightAveraging):
"""Exponential Moving Average (EMA) Weight Averaging callback."""

def __init__(
self,
device: Optional[Union[torch.device, str, int]] = None,
use_buffers: bool = True,
decay: float = 0.999,
update_every_n_steps: int = 1,
update_starting_at_step: Optional[int] = None,
update_starting_at_epoch: Optional[int] = None,
**kwargs: Any,
):
super().__init__(
device=device,
use_buffers=use_buffers,
**kwargs,
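# get_ema_avg_fn(decay) supplies the EMA update rule: averaged = decay * averaged + (1 - decay) * current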
avg_fn=get_ema_avg_fn(decay=decay),
)

self.update_every_n_steps = update_every_n_steps
self.update_starting_at_step = update_starting_at_step
self.update_starting_at_epoch = update_starting_at_epoch

def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
"""Decide when to update the model weights.

Args:
step_idx: The current step index.
epoch_idx: The current epoch index.
Returns:
bool: True if the model weights should be updated, False otherwise.

"""
if step_idx is not None:
# Check step-based conditions only if we have a valid step_idx
meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
if meets_step_requirement and meets_step_frequency:
return True

if epoch_idx is not None:
# Check epoch-based condition only if we specify one
meets_epoch_requirement = (
self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
)
if meets_epoch_requirement:
return True

return False
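
A minimal usage sketch of the new callback, assuming the standard `Trainer(callbacks=...)` API and using the `BoringModel`/`RandomDataset` demo classes purely for illustration (the hyperparameter values are arbitrary):

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EMAWeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

# Keep an EMA copy of the weights with decay=0.999 (an effective averaging
# horizon of roughly 1 / (1 - 0.999) = 1000 updates), refreshed on every
# training step.
ema_callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)

model = BoringModel()
train_loader = torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=4)
trainer = Trainer(max_epochs=2, callbacks=[ema_callback])
trainer.fit(model, train_loader)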
122 changes: 121 additions & 1 deletion tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -23,7 +23,7 @@
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import WeightAveraging
from lightning.pytorch.callbacks import EMAWeightAveraging, WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf

@@ -329,3 +329,123 @@ def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices
callback = EMATestCallback(devices=devices)
_train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
return model


@pytest.mark.parametrize(
("strategy", "accelerator", "devices"),
[
("auto", "cpu", 1),
pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
],
)
def test_ema_weight_averaging(tmp_path, strategy, accelerator, devices):
"""Test EMAWeightAveraging callback with various update configurations."""
model = TestModel()
dataset = RandomDataset(32, 32)

# Test with default settings (update every step)
callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)
_train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)

# Verify the average model was created and updated
assert callback._average_model is not None
assert callback._average_model.n_averaged > 0


def test_ema_weight_averaging_step_frequency(tmp_path):
"""Test EMAWeightAveraging with custom step update frequency."""
model = TestModel()
dataset = RandomDataset(32, 32)

# Update every 5 steps
callback = EMAWeightAveraging(decay=0.95, update_every_n_steps=5)
_train(model, dataset, tmp_path, callback)

assert callback._average_model is not None


def test_ema_weight_averaging_starting_step(tmp_path):
"""Test EMAWeightAveraging with delayed start based on steps."""
model = TestModel()
dataset = RandomDataset(32, 32)

# Start updating at step 10
callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_step=10)
_train(model, dataset, tmp_path, callback)

assert callback._average_model is not None


def test_ema_weight_averaging_starting_epoch(tmp_path):
"""Test EMAWeightAveraging with delayed start based on epochs."""
model = TestModel()
dataset = RandomDataset(32, 32)

# Start updating at epoch 3
callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_epoch=3)
_train(model, dataset, tmp_path, callback)

assert callback._average_model is not None


def test_ema_weight_averaging_should_update(tmp_path):
"""Test the should_update logic of EMAWeightAveraging."""
# Test with step-based updates
callback = EMAWeightAveraging(update_every_n_steps=5, update_starting_at_step=10)

# Before starting step
assert not callback.should_update(step_idx=5)
assert not callback.should_update(step_idx=9)

# At and after the starting step, updates occur only at the 5-step frequency
assert callback.should_update(step_idx=10) # First update
assert not callback.should_update(step_idx=11)
assert not callback.should_update(step_idx=14)
assert callback.should_update(step_idx=15) # Second update

# Test with epoch-based updates
callback = EMAWeightAveraging(update_starting_at_epoch=2)

assert not callback.should_update(epoch_idx=0)
assert not callback.should_update(epoch_idx=1)
assert callback.should_update(epoch_idx=2)
assert callback.should_update(epoch_idx=3)


def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
"""Test that EMAWeightAveraging correctly saves and loads checkpoints."""
model = TestModel()
model.crash_on_epoch = 2
dataset = RandomDataset(32, 32)

callback = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)

# Train and create checkpoint
_train(model, dataset, tmp_path, callback, will_crash=True)

# Resume from checkpoint
model2 = TestModel()
callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
import glob # should be at the top

_train(
model2,
dataset,
tmp_path,
callback2,
checkpoint_path=glob.glob((tmp_path / "checkpoints" / "*.ckpt").as_posix())[0],
)

assert callback2._average_model is not None


@pytest.mark.parametrize("decay", [0.9, 0.99, 0.999, 0.9999])
def test_ema_weight_averaging_decay_values(tmp_path, decay):
"""Test EMAWeightAveraging with different decay values."""
model = TestModel()
dataset = RandomDataset(32, 32)

callback = EMAWeightAveraging(decay=decay, update_every_n_steps=1)
_train(model, dataset, tmp_path, callback)

assert callback._average_model is not None
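
For intuition about the parametrized decay values, here is a small standalone sketch of the averaging rule that `get_ema_avg_fn` provides (plain tensors, no Trainer involved; the numbers are illustrative):

import torch
from torch.optim.swa_utils import get_ema_avg_fn

avg_fn = get_ema_avg_fn(decay=0.9)
ema = torch.tensor(0.0)
for step, value in enumerate([1.0, 1.0, 1.0, 1.0]):
    # each call returns decay * ema + (1 - decay) * value
    ema = avg_fn(ema, torch.tensor(value), step)
print(ema)  # ~0.3439 = 1 - 0.9 ** 4; larger decay values track the current weights more slowly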