From 64155998019ae5cf2bbb37b68a8e45c9897c20c9 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Thu, 2 Oct 2025 15:26:06 -0700
Subject: [PATCH 01/13] Add `EMAWeightAveraging` callback to `weight_averaging.py`

---
 .../pytorch/callbacks/weight_averaging.py | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index f9b8d64eae6a5..c97f7c5b41f7b 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -361,3 +361,59 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
             current_param.data.copy_(average_param.data)
+
+
+class EMAWeightAveraging(WeightAveraging):
+    """Exponential Moving Average (EMA) Weight Averaging callback."""
+
+    def __init__(
+        self,
+        device: Optional[Union[torch.device, str, int]] = None,
+        use_buffers: bool = True,
+        decay: float = 0.999,
+        update_every_n_steps: int = 1,
+        update_starting_at_step: Optional[int] = None,
+        update_starting_at_epoch: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            device=device,
+            use_buffers=use_buffers,
+            **kwargs,
+            avg_fn=get_ema_avg_fn(decay=decay),
+        )
+
+        self.update_every_n_steps = update_every_n_steps
+        self.update_starting_at_step = update_starting_at_step
+        self.update_starting_at_epoch = update_starting_at_epoch
+
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None):
+        """Decide when to update the model weights.
+
+        Args:
+            step_idx: The current step index.
+            epoch_idx: The current epoch index.
+        Returns:
+            bool: True if the model weights should be updated, False otherwise.
+        """
+        if step_idx is not None:
+            # Check step-based conditions only if we have a valid step_idx
+            meets_step_requirement = (
+                self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            )
+            meets_step_frequency = (
+                self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
+            )
+            if meets_step_requirement and meets_step_frequency:
+                return True
+
+        if epoch_idx is not None:
+            # Check epoch-based condition only if we specify one
+            meets_epoch_requirement = (
+                self.update_starting_at_epoch is not None
+                and epoch_idx >= self.update_starting_at_epoch
+            )
+            if meets_epoch_requirement:
+                return True
+
+        return False
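
For anyone trying the series locally, a minimal usage sketch of the new callback
(hedged: it assumes the full series is applied, since PATCH 01/13 alone still lacks
the `get_ema_avg_fn` import that PATCH 04/13 adds; `BoringModel` and `RandomDataset`
come from Lightning's demo module):

    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

    model = BoringModel()
    train_loader = DataLoader(RandomDataset(32, 64), batch_size=8)

    # Keep an EMA of the weights, updated on every optimizer step from step 100 on.
    ema = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_step=100)
    trainer = Trainer(max_epochs=2, callbacks=[ema])
    trainer.fit(model, train_loader)
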
+ """ if step_idx is not None: # Check step-based conditions only if we have a valid step_idx - meets_step_requirement = ( - self.update_starting_at_step is None or step_idx >= self.update_starting_at_step - ) - meets_step_frequency = ( - self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0 - ) + meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step + meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0 if meets_step_requirement and meets_step_frequency: return True if epoch_idx is not None: # Check epoch-based condition only if we specify one meets_epoch_requirement = ( - self.update_starting_at_epoch is not None - and epoch_idx >= self.update_starting_at_epoch + self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch ) if meets_epoch_requirement: return True From b9501980d67dc5a184381104921f0173ed5d51a3 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Thu, 2 Oct 2025 15:30:33 -0700 Subject: [PATCH 03/13] Update CHANGELOG.md --- src/lightning/pytorch/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1bba5e4ca0da7..30072517a2e9a 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236)) +- Added `EMAWeightAveraging` callback that wraps Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260)) + + ### Changed - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. 
From 7c940ab358cd8c1e365261c656e0ab961df02641 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Fri, 3 Oct 2025 16:16:50 -0700
Subject: [PATCH 04/13] Update weight_averaging.py

---
 src/lightning/pytorch/callbacks/weight_averaging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index 673ed8fae2f0d..c6f95adaedc1a 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -21,7 +21,7 @@
 from typing import Any, Optional, Union
 
 import torch
-from torch.optim.swa_utils import AveragedModel
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override
 
 import lightning.pytorch as pl
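
PATCH 04/13 supplies the import that PATCH 01/13 already relied on. For reference,
`get_ema_avg_fn(decay)` from `torch.optim.swa_utils` returns the averaging function
that `AveragedModel` applies on each update; conceptually it is the standard
exponential moving average (a rough sketch, not the actual torch source):

    def ema_avg_fn_sketch(averaged_param, current_param, num_averaged, decay=0.999):
        # averaged <- decay * averaged + (1 - decay) * current
        return decay * averaged_param + (1.0 - decay) * current_param
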
From 6daf1feccdacb55369151df367a8a09afdb2bc32 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Fri, 17 Oct 2025 14:17:09 -0700
Subject: [PATCH 05/13] Update test_weight_averaging.py

---
 .../callbacks/test_weight_averaging.py | 139 ++++++++++++++++++
 1 file changed, 139 insertions(+)

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index ec230b2fd6c97..c58edebcbe88b 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -329,3 +329,142 @@ def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices
     callback = EMATestCallback(devices=devices)
     _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
     return model
+
+@pytest.mark.parametrize(
+    ("strategy", "accelerator", "devices"),
+    [
+        ("auto", "cpu", 1),
+        pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
+    ],
+)
+def test_ema_weight_averaging(tmp_path, strategy, accelerator, devices):
+    """Test EMAWeightAveraging callback with various update configurations."""
+    from weight_averaging import EMAWeightAveraging
+
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Test with default settings (update every step)
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)
+    _train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)
+
+    # Verify the average model was created and updated
+    assert callback._average_model is not None
+    assert callback._average_model.n_averaged > 0
+
+
+def test_ema_weight_averaging_step_frequency(tmp_path):
+    """Test EMAWeightAveraging with custom step update frequency."""
+    from weight_averaging import EMAWeightAveraging
+
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Update every 5 steps
+    callback = EMAWeightAveraging(decay=0.95, update_every_n_steps=5)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_starting_step(tmp_path):
+    """Test EMAWeightAveraging with delayed start based on steps."""
+    from weight_averaging import EMAWeightAveraging
+
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Start updating after step 10
+    callback = EMAWeightAveraging(
+        decay=0.999,
+        update_every_n_steps=1,
+        update_starting_at_step=10
+    )
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_starting_epoch(tmp_path):
+    """Test EMAWeightAveraging with delayed start based on epochs."""
+    from weight_averaging import EMAWeightAveraging
+
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Start updating after epoch 3
+    callback = EMAWeightAveraging(
+        decay=0.999,
+        update_every_n_steps=1,
+        update_starting_at_epoch=3
+    )
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_should_update(tmp_path):
+    """Test the should_update logic of EMAWeightAveraging."""
+    from weight_averaging import EMAWeightAveraging
+
+    # Test with step-based updates
+    callback = EMAWeightAveraging(
+        update_every_n_steps=5,
+        update_starting_at_step=10
+    )
+
+    # Before starting step
+    assert not callback.should_update(step_idx=5)
+    assert not callback.should_update(step_idx=9)
+
+    # At and after starting step, but not on update frequency
+    assert callback.should_update(step_idx=10)  # First update
+    assert not callback.should_update(step_idx=11)
+    assert not callback.should_update(step_idx=14)
+    assert callback.should_update(step_idx=15)  # Second update
+
+    # Test with epoch-based updates
+    callback = EMAWeightAveraging(
+        update_starting_at_epoch=2
+    )
+
+    assert not callback.should_update(epoch_idx=0)
+    assert not callback.should_update(epoch_idx=1)
+    assert callback.should_update(epoch_idx=2)
+    assert callback.should_update(epoch_idx=3)
+
+
+def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
+    """Test that EMAWeightAveraging correctly saves and loads checkpoints."""
+    from weight_averaging import EMAWeightAveraging
+
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    callback = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
+
+    # Train and create checkpoint
+    _train(model, dataset, tmp_path, callback, will_crash=True)
+
+    # Resume from checkpoint
+    model2 = TestModel()
+    callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
+    checkpoint_path = str(tmp_path / "lightning_logs" / "version_0" / "checkpoints" / "*.ckpt")
+
+    _train(model2, dataset, tmp_path, callback2, checkpoint_path=checkpoint_path)
+
+    assert callback2._average_model is not None
+
+
+@pytest.mark.parametrize("decay", [0.9, 0.99, 0.999, 0.9999])
+def test_ema_weight_averaging_decay_values(tmp_path, decay):
+    """Test EMAWeightAveraging with different decay values."""
+    from weight_averaging import EMAWeightAveraging
+
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    callback = EMAWeightAveraging(decay=decay, update_every_n_steps=1)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
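
To exercise just these tests locally, an invocation along these lines should work from
the repository root (exact flags may vary with the project's pytest configuration):

    python -m pytest tests/tests_pytorch/callbacks/test_weight_averaging.py -k ema -v

Note that at this point in the series the tests still import from a local
`weight_averaging` module; PATCH 06/13 switches them to the package import, so they
only pass once that fix is in.
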
From 39ea0bbd72865451a9fc44683fdc49abcadc5bdc Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Fri, 17 Oct 2025 14:18:33 -0700
Subject: [PATCH 06/13] Update test_weight_averaging.py

---
 .../callbacks/test_weight_averaging.py | 31 ++++++++----------
 1 file changed, 9 insertions(+), 22 deletions(-)

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index c58edebcbe88b..a35fe62adc766 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader, Dataset
 
 from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import WeightAveraging
+from lightning.pytorch.callbacks import EMAWeightAveraging, WeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from tests_pytorch.helpers.runif import RunIf
 
@@ -330,6 +330,7 @@ def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices
     _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
     return model
 
+
 @pytest.mark.parametrize(
     ("strategy", "accelerator", "devices"),
     [
@@ -338,9 +339,7 @@ def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices
     ],
 )
 def test_ema_weight_averaging(tmp_path, strategy, accelerator, devices):
-    """Test EMAWeightAveraging callback with various update configurations."""
-    from weight_averaging import EMAWeightAveraging
-
+    """Test EMAWeightAveraging callback with various update configurations."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
 
@@ -354,9 +353,7 @@ def test_ema_weight_averaging(tmp_path, strategy, accelerator, devices):
 
 
 def test_ema_weight_averaging_step_frequency(tmp_path):
-    """Test EMAWeightAveraging with custom step update frequency."""
-    from weight_averaging import EMAWeightAveraging
-
+    """Test EMAWeightAveraging with custom step update frequency."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
 
@@ -368,9 +365,7 @@ def test_ema_weight_averaging_step_frequency(tmp_path):
 
 
 def test_ema_weight_averaging_starting_step(tmp_path):
-    """Test EMAWeightAveraging with delayed start based on steps."""
-    from weight_averaging import EMAWeightAveraging
-
+    """Test EMAWeightAveraging with delayed start based on steps."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
 
@@ -386,9 +381,7 @@ def test_ema_weight_averaging_starting_step(tmp_path):
 
 
 def test_ema_weight_averaging_starting_epoch(tmp_path):
-    """Test EMAWeightAveraging with delayed start based on epochs."""
-    from weight_averaging import EMAWeightAveraging
-
+    """Test EMAWeightAveraging with delayed start based on epochs."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
 
@@ -404,9 +397,7 @@ def test_ema_weight_averaging_starting_epoch(tmp_path):
 
 
 def test_ema_weight_averaging_should_update(tmp_path):
-    """Test the should_update logic of EMAWeightAveraging."""
-    from weight_averaging import EMAWeightAveraging
-
+    """Test the should_update logic of EMAWeightAveraging."""
     # Test with step-based updates
     callback = EMAWeightAveraging(
         update_every_n_steps=5,
         update_starting_at_step=10
     )
@@ -435,9 +426,7 @@ def test_ema_weight_averaging_should_update(tmp_path):
 
 
 def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
-    """Test that EMAWeightAveraging correctly saves and loads checkpoints."""
-    from weight_averaging import EMAWeightAveraging
-
+    """Test that EMAWeightAveraging correctly saves and loads checkpoints."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
 
@@ -458,9 +447,7 @@ def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
 
 
 @pytest.mark.parametrize("decay", [0.9, 0.99, 0.999, 0.9999])
 def test_ema_weight_averaging_decay_values(tmp_path, decay):
-    """Test EMAWeightAveraging with different decay values."""
-    from weight_averaging import EMAWeightAveraging
-
+    """Test EMAWeightAveraging with different decay values."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
From 0a062e875a98184bacd25048b38add7c9336a11c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 17 Oct 2025 21:20:09 +0000
Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../callbacks/test_weight_averaging.py | 73 ++++++++-----------
 1 file changed, 30 insertions(+), 43 deletions(-)

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index a35fe62adc766..dba630b5350de 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -339,86 +339,73 @@ def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices
     ],
 )
 def test_ema_weight_averaging(tmp_path, strategy, accelerator, devices):
-    """Test EMAWeightAveraging callback with various update configurations."""
+    """Test EMAWeightAveraging callback with various update configurations."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
-
+
     # Test with default settings (update every step)
     callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)
     _train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)
-
+
     # Verify the average model was created and updated
     assert callback._average_model is not None
     assert callback._average_model.n_averaged > 0
 
 
 def test_ema_weight_averaging_step_frequency(tmp_path):
-    """Test EMAWeightAveraging with custom step update frequency."""
+    """Test EMAWeightAveraging with custom step update frequency."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
-
+
     # Update every 5 steps
     callback = EMAWeightAveraging(decay=0.95, update_every_n_steps=5)
     _train(model, dataset, tmp_path, callback)
-
+
     assert callback._average_model is not None
 
 
 def test_ema_weight_averaging_starting_step(tmp_path):
-    """Test EMAWeightAveraging with delayed start based on steps."""
+    """Test EMAWeightAveraging with delayed start based on steps."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
-
+
     # Start updating after step 10
-    callback = EMAWeightAveraging(
-        decay=0.999,
-        update_every_n_steps=1,
-        update_starting_at_step=10
-    )
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_step=10)
     _train(model, dataset, tmp_path, callback)
-
+
     assert callback._average_model is not None
 
 
 def test_ema_weight_averaging_starting_epoch(tmp_path):
-    """Test EMAWeightAveraging with delayed start based on epochs."""
+    """Test EMAWeightAveraging with delayed start based on epochs."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
-
+
     # Start updating after epoch 3
-    callback = EMAWeightAveraging(
-        decay=0.999,
-        update_every_n_steps=1,
-        update_starting_at_epoch=3
-    )
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_epoch=3)
     _train(model, dataset, tmp_path, callback)
-
+
     assert callback._average_model is not None
 
 
 def test_ema_weight_averaging_should_update(tmp_path):
-    """Test the should_update logic of EMAWeightAveraging."""
+    """Test the should_update logic of EMAWeightAveraging."""
     # Test with step-based updates
-    callback = EMAWeightAveraging(
-        update_every_n_steps=5,
-        update_starting_at_step=10
-    )
-
+    callback = EMAWeightAveraging(update_every_n_steps=5, update_starting_at_step=10)
+
     # Before starting step
     assert not callback.should_update(step_idx=5)
     assert not callback.should_update(step_idx=9)
-
+
     # At and after starting step, but not on update frequency
     assert callback.should_update(step_idx=10)  # First update
     assert not callback.should_update(step_idx=11)
     assert not callback.should_update(step_idx=14)
    assert callback.should_update(step_idx=15)  # Second update
-
+
     # Test with epoch-based updates
-    callback = EMAWeightAveraging(
-        update_starting_at_epoch=2
-    )
-
+    callback = EMAWeightAveraging(update_starting_at_epoch=2)
+
     assert not callback.should_update(epoch_idx=0)
     assert not callback.should_update(epoch_idx=1)
     assert callback.should_update(epoch_idx=2)
@@ -426,32 +413,32 @@ def test_ema_weight_averaging_should_update(tmp_path):
 
 
 def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
-    """Test that EMAWeightAveraging correctly saves and loads checkpoints."""
+    """Test that EMAWeightAveraging correctly saves and loads checkpoints."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
-
+
     callback = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
-
+
     # Train and create checkpoint
     _train(model, dataset, tmp_path, callback, will_crash=True)
-
+
     # Resume from checkpoint
     model2 = TestModel()
     callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
     checkpoint_path = str(tmp_path / "lightning_logs" / "version_0" / "checkpoints" / "*.ckpt")
-
+
     _train(model2, dataset, tmp_path, callback2, checkpoint_path=checkpoint_path)
-
+
     assert callback2._average_model is not None
 
 
 @pytest.mark.parametrize("decay", [0.9, 0.99, 0.999, 0.9999])
 def test_ema_weight_averaging_decay_values(tmp_path, decay):
-    """Test EMAWeightAveraging with different decay values."""
+    """Test EMAWeightAveraging with different decay values."""
     model = TestModel()
     dataset = RandomDataset(32, 32)
-
+
     callback = EMAWeightAveraging(decay=decay, update_every_n_steps=1)
     _train(model, dataset, tmp_path, callback)
-
+
     assert callback._average_model is not None
From 9ab1eb4732ea84451f03842152b4fd006c78b5c1 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Fri, 17 Oct 2025 16:58:48 -0700
Subject: [PATCH 08/13] Update weight_averaging.py

---
 src/lightning/pytorch/callbacks/weight_averaging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index c6f95adaedc1a..2e626bb331d04 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -387,7 +387,7 @@ def __init__(
         self.update_starting_at_step = update_starting_at_step
         self.update_starting_at_epoch = update_starting_at_epoch
 
-    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None):
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> False:
         """Decide when to update the model weights.
 
         Args:

From 3d3323da802a4f60bd7835d74e1e246a607c7523 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Sat, 18 Oct 2025 13:29:21 -0700
Subject: [PATCH 09/13] Update weight_averaging.py

---
 src/lightning/pytorch/callbacks/weight_averaging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index 2e626bb331d04..0640efed3d87b 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -387,7 +387,7 @@ def __init__(
         self.update_starting_at_step = update_starting_at_step
         self.update_starting_at_epoch = update_starting_at_epoch
 
-    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> False:
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
         """Decide when to update the model weights.
 
         Args:
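
On the annotation churn in PATCHES 08-09: `False` is a value, not a type, so
`-> False` is rejected by type checkers; `-> bool` is correct here because
`should_update` can return either value. For completeness, `typing.Literal[False]`
exists for functions that genuinely always return `False`, which is not the case here:

    from typing import Literal

    def always_false() -> Literal[False]:
        # Literal[False] is only valid when the function can never return True.
        return False
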
From d0ad1453becd53d1cd96103b1ec1f1f884b19a23 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Sun, 19 Oct 2025 18:38:55 -0700
Subject: [PATCH 10/13] Update __init__.py

---
 src/lightning/pytorch/callbacks/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/callbacks/__init__.py b/src/lightning/pytorch/callbacks/__init__.py
index d0ffb7b6a990c..dd96c045d8366 100644
--- a/src/lightning/pytorch/callbacks/__init__.py
+++ b/src/lightning/pytorch/callbacks/__init__.py
@@ -32,7 +32,7 @@
 from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
 from lightning.pytorch.callbacks.timer import Timer
-from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
+from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging, WeightAveraging
 
 __all__ = [
     "BackboneFinetuning",
@@ -59,5 +59,6 @@
     "ThroughputMonitor",
     "Timer",
     "TQDMProgressBar",
+    "EMAWeightAveraging",
     "WeightAveraging",
 ]

From da5276b2d735e21346d7d72237480f3b40863f02 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Thu, 23 Oct 2025 09:51:33 -0700
Subject: [PATCH 11/13] Update tests/tests_pytorch/callbacks/test_weight_averaging.py

Co-authored-by: GdoongMathew
---
 tests/tests_pytorch/callbacks/test_weight_averaging.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index dba630b5350de..7448fb0f5cfb0 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -415,6 +415,7 @@ def test_ema_weight_averaging_should_update(tmp_path):
 def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
     """Test that EMAWeightAveraging correctly saves and loads checkpoints."""
     model = TestModel()
+    model.crash_on_epoch = 2
     dataset = RandomDataset(32, 32)
 
     callback = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)

From 78d403701b2504623e36d0ede4addce419cc2660 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Fri, 24 Oct 2025 09:24:35 -0700
Subject: [PATCH 12/13] Update tests/tests_pytorch/callbacks/test_weight_averaging.py

Thanks @GdoongMathew!

Co-authored-by: GdoongMathew
---
 tests/tests_pytorch/callbacks/test_weight_averaging.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index 7448fb0f5cfb0..b76825dd7e59b 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -426,9 +426,9 @@ def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
 
     # Resume from checkpoint
     model2 = TestModel()
     callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
-    checkpoint_path = str(tmp_path / "lightning_logs" / "version_0" / "checkpoints" / "*.ckpt")
-
-    _train(model2, dataset, tmp_path, callback2, checkpoint_path=checkpoint_path)
+    import glob  # should be at the top
+    _train(model2, dataset, tmp_path, callback2,
+           checkpoint_path=glob.glob((tmp_path / "checkpoints" / "*.ckpt").as_posix())[0])
 
     assert callback2._average_model is not None
From d4bb25f368f6356b784f0f3201a9c25e69579e6b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 24 Oct 2025 16:24:55 +0000
Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../tests_pytorch/callbacks/test_weight_averaging.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index b76825dd7e59b..cfb066f023af0 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -426,9 +426,15 @@ def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
 
     # Resume from checkpoint
     model2 = TestModel()
     callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
-    import glob  # should be at the top
-    _train(model2, dataset, tmp_path, callback2,
-           checkpoint_path=glob.glob((tmp_path / "checkpoints" / "*.ckpt").as_posix())[0])
+    import glob  # should be at the top
+
+    _train(
+        model2,
+        dataset,
+        tmp_path,
+        callback2,
+        checkpoint_path=glob.glob((tmp_path / "checkpoints" / "*.ckpt").as_posix())[0],
+    )
 
     assert callback2._average_model is not None
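
With the series complete, `EMAWeightAveraging` is exported from
`lightning.pytorch.callbacks` and covered by crash-and-resume tests. In user code the
same flow looks roughly like this (the directory layout is illustrative, not
prescribed by the patches):

    import glob

    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EMAWeightAveraging
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

    model = BoringModel()
    loader = DataLoader(RandomDataset(32, 64), batch_size=8)

    ema = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
    trainer = Trainer(max_epochs=5, callbacks=[ema], default_root_dir="runs/ema")

    # Resume from the newest checkpoint if a previous run was interrupted.
    ckpts = sorted(glob.glob("runs/ema/**/*.ckpt", recursive=True))
    trainer.fit(model, loader, ckpt_path=ckpts[-1] if ckpts else None)
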