Commit bd27dbf

SkafteNicki, Borda, and deependujha authored
Add new args to BatchSizeFinder (#21163)
* new args to batch size scaler
* add to tuner
* add to callback
* add testing
* update
* fix tests
* Apply suggestion from @deependujha
* Apply suggestion from @deependujha
* Apply suggestion from @deependujha
* Apply suggestion from @deependujha
* update
* safe default
* update
* add assertion
* fix doc issue

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Deependu <deependujha21@gmail.com>

1 parent a7cb33a · commit bd27dbf

File tree: 4 files changed (+137 −10 lines)

src/lightning/pytorch/callbacks/batch_size_finder.py

Lines changed: 14 additions & 0 deletions
@@ -63,6 +63,12 @@ class BatchSizeFinder(Callback):
             - ``model.hparams``
             - ``trainer.datamodule`` (the datamodule passed to the tune method)
 
+        margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
+            'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
+        max_val: Maximum batch size limit, defaults to 8192.
+            Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
+            when running on CPU or when automatic OOM detection is not available.
+
     Example::
 
         # 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
@@ -118,17 +124,23 @@ def __init__(
         init_val: int = 2,
         max_trials: int = 25,
         batch_arg_name: str = "batch_size",
+        margin: float = 0.05,
+        max_val: int = 8192,
     ) -> None:
         mode = mode.lower()
         if mode not in self.SUPPORTED_MODES:
             raise ValueError(f"`mode` should be either of {self.SUPPORTED_MODES}")
 
+        assert 0.0 <= margin < 1.0, f"`margin` should be between 0 and 1. Found {margin=}"
+
         self.optimal_batch_size: Optional[int] = init_val
         self._mode = mode
         self._steps_per_trial = steps_per_trial
         self._init_val = init_val
         self._max_trials = max_trials
         self._batch_arg_name = batch_arg_name
+        self._margin = margin
+        self._max_val = max_val
         self._early_exit = False
 
     @override
@@ -180,6 +192,8 @@ def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule
             self._init_val,
             self._max_trials,
             self._batch_arg_name,
+            self._margin,
+            self._max_val,
         )
 
         self.optimal_batch_size = new_size
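
Taken together, the callback now forwards both new arguments into the scaling routine. A minimal usage sketch (the commented `trainer.fit(model)` call and the `max_val=4096` choice are illustrative, not part of this commit):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BatchSizeFinder

    # The two new arguments added by this commit; everything else is a default.
    finder = BatchSizeFinder(
        mode="binsearch",
        margin=0.05,   # shrink the found size by 5% as a safety buffer (binsearch only)
        max_val=4096,  # never probe batch sizes above this cap
    )
    trainer = Trainer(callbacks=[finder])
    # trainer.fit(model)  # `model` must expose `batch_size` (or `hparams.batch_size`)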

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 57 additions & 6 deletions
@@ -32,6 +32,8 @@ def _scale_batch_size(
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = "batch_size",
+    margin: float = 0.05,
+    max_val: int = 8192,
 ) -> Optional[int]:
     """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
     error.
@@ -58,7 +60,15 @@ def _scale_batch_size(
             - ``model.hparams``
             - ``trainer.datamodule`` (the datamodule passed to the tune method)
 
+        margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
+            'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
+        max_val: Maximum batch size limit, defaults to 8192.
+            Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
+            when running on CPU or when automatic OOM detection is not available.
+
     """
+    assert 0.0 <= margin < 1.0, f"`margin` should be between 0 and 1. Found {margin=}"
+
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
         return None
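
The same assertion now guards the callback, the tuner, and this function. A standalone sketch of what a caller sees for an out-of-range value (the `margin = 1.0` input is illustrative):

    # Reproducing the validation in isolation; `{margin=}` is an f-string debug spec.
    margin = 1.0
    try:
        assert 0.0 <= margin < 1.0, f"`margin` should be between 0 and 1. Found {margin=}"
    except AssertionError as err:
        print(err)  # `margin` should be between 0 and 1. Found margin=1.0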
@@ -80,9 +90,9 @@ def _scale_batch_size(
     new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
     if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params, max_val)
     elif mode == "binsearch":
-        new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        new_size = _run_binsearch_scaling(trainer, new_size, batch_arg_name, max_trials, params, margin, max_val)
 
     garbage_collection_cuda()
 
@@ -173,6 +183,7 @@ def _run_power_scaling(
     batch_arg_name: str,
     max_trials: int,
     params: dict[str, Any],
+    max_val: int = 8192,
 ) -> int:
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -185,6 +196,10 @@ def _run_power_scaling(
         # reset after each try
         _reset_progress(trainer)
 
+        if new_size >= max_val:
+            rank_zero_info(f"Reached the maximum batch size limit of {max_val}. Stopping search.")
+            break
+
         try:
             _try_loop_run(trainer, params)
             last_successful_size = new_size  # Store the current size before doubling
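
In isolation, the new guard simply bounds the doubling search. A simplified sketch, assuming no OOM handling or trainer state, with a hypothetical `fits` predicate standing in for a successful trial:

    # Not the library's actual loop; only shows how max_val bounds the search.
    def doubling_search(init_val: int, max_trials: int, max_val: int, fits) -> int:
        size = init_val
        for _ in range(max_trials):
            if size >= max_val:
                break  # mirrors "Reached the maximum batch size limit"
            if not fits(size * 2):
                break
            size *= 2
        return min(size, max_val)

    # With every size fitting, the search stops exactly at the cap:
    print(doubling_search(2, max_trials=25, max_val=8192, fits=lambda s: True))  # 8192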
@@ -217,18 +232,22 @@ def _run_power_scaling(
     return new_size
 
 
-def _run_binary_scaling(
+def _run_binsearch_scaling(
     trainer: "pl.Trainer",
     new_size: int,
     batch_arg_name: str,
     max_trials: int,
     params: dict[str, Any],
+    margin: float,
+    max_val: int = 8192,
 ) -> int:
     """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered.
 
     Hereafter, the batch size is further refined using a binary search
 
     """
+    assert 0.0 <= margin < 1.0, f"`margin` should be between 0 and 1. Found {margin=}"
+
     low = 1
     high = None
     count = 0
@@ -239,6 +258,10 @@ def _run_binary_scaling(
         # reset after each try
         _reset_progress(trainer)
 
+        if new_size >= max_val:
+            rank_zero_info(f"Reached the maximum batch size limit of {max_val}. Stopping search.")
+            break
+
         try:
             # run loop
             _try_loop_run(trainer, params)
@@ -256,9 +279,13 @@ def _run_binary_scaling(
                 if high - low <= 1:
                     break
                 midval = (high + low) // 2
-                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded")
+                new_size, changed = _adjust_batch_size(
+                    trainer, batch_arg_name, value=midval, desc="succeeded", max_val=max_val
+                )
             else:
-                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+                new_size, changed = _adjust_batch_size(
+                    trainer, batch_arg_name, factor=2.0, desc="succeeded", max_val=max_val
+                )
 
             if not changed:
                 break
@@ -284,6 +311,17 @@ def _run_binary_scaling(
             else:
                 raise  # some other error not memory related
 
+    # Apply margin reduction for binsearch mode
+    if margin > 0:
+        margin_reduced_size = max(1, int(new_size * (1 - margin)))
+        if margin_reduced_size != new_size:
+            rank_zero_info(
+                f"Applying margin of {margin:.1%}, reducing batch size from {new_size} to {margin_reduced_size}"
+            )
+            new_size = margin_reduced_size
+            # propagate the reduced batch size to the model/datamodule attribute
+            lightning_setattr(trainer.lightning_module, batch_arg_name, new_size)
+
     return new_size
 
 
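The reduction itself is one line of arithmetic. A standalone sketch with illustrative values (the helper name `apply_margin` is not from the library):

    def apply_margin(found_size: int, margin: float) -> int:
        assert 0.0 <= margin < 1.0
        return max(1, int(found_size * (1 - margin)))

    print(apply_margin(256, 0.05))  # 243 (256 * 0.95 = 243.2, truncated)
    print(apply_margin(8, 0.05))    # 7, matching the updated test expectations below
    print(apply_margin(1, 0.9))     # 1, the floor keeps the batch size positive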

@@ -293,6 +331,7 @@ def _adjust_batch_size(
     factor: float = 1.0,
     value: Optional[int] = None,
     desc: Optional[str] = None,
+    max_val: int = 8192,
 ) -> tuple[int, bool]:
     """Helper function for adjusting the batch size.
 
@@ -303,6 +342,9 @@ def _adjust_batch_size(
         value: if a value is given, will override the batch size with this value.
             Note that the value of `factor` will not have an effect in this case
         desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
+        max_val: Maximum batch size limit, defaults to 8192.
+            Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
+            when running on CPU or when automatic OOM detection is not available.
 
     Returns:
         The new batch size for the next trial and a bool that signals whether the
@@ -321,13 +363,22 @@ def _adjust_batch_size(
     try:
         combined_dataset_length = combined_loader._dataset_length()
         if batch_size >= combined_dataset_length:
-            rank_zero_info(f"The batch size {batch_size} is greater or equal than the length of your dataset.")
+            rank_zero_info(
+                f"The batch size {batch_size} is greater or equal than"
+                f" the length of your dataset: {combined_dataset_length}."
+            )
             return batch_size, False
     except NotImplementedError:
         # all datasets are iterable style
         pass
 
     new_size = value if value is not None else int(batch_size * factor)
+
+    # Apply max_val limit if provided
+    if new_size > max_val:
+        if desc:
+            rank_zero_info(f"Batch size {new_size} exceeds max_val limit {max_val}, capping at {max_val}")
+        new_size = max_val
     if desc:
         rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
     changed = new_size != batch_size

src/lightning/pytorch/tuner/tuning.py

Lines changed: 11 additions & 0 deletions
@@ -41,6 +41,8 @@ def scale_batch_size(
         init_val: int = 2,
         max_trials: int = 25,
         batch_arg_name: str = "batch_size",
+        margin: float = 0.05,
+        max_val: int = 8192,
     ) -> Optional[int]:
         """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
         error.
@@ -75,9 +77,16 @@ def scale_batch_size(
                 - ``model.hparams``
                 - ``trainer.datamodule`` (the datamodule passed to the tune method)
 
+            margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
+                'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
+            max_val: Maximum batch size limit, defaults to 8192.
+                Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
+                when running on CPU or when automatic OOM detection is not available.
+
         """
         _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
         _check_scale_batch_size_configuration(self._trainer)
+        assert 0.0 <= margin < 1.0, f"`margin` should be between 0 and 1. Found {margin=}"
 
         # local import to avoid circular import
         from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
@@ -88,6 +97,8 @@ def scale_batch_size(
             init_val=init_val,
             max_trials=max_trials,
             batch_arg_name=batch_arg_name,
+            margin=margin,
+            max_val=max_val,
         )
         # do not continue with the loop in case Tuner is used
         batch_size_finder._early_exit = True
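
With this plumbing, both arguments are reachable from the public `Tuner` API. A minimal sketch (the commented call assumes a `model` exposing `batch_size` or `hparams.batch_size`; the values shown are the defaults):

    from lightning.pytorch import Trainer
    from lightning.pytorch.tuner import Tuner

    trainer = Trainer()
    tuner = Tuner(trainer)
    # new_size = tuner.scale_batch_size(
    #     model,
    #     mode="binsearch",
    #     margin=0.05,   # 5% safety buffer, binsearch mode only
    #     max_val=8192,  # upper bound on probed batch sizes
    # )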

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 55 additions & 4 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 import glob
 import logging
+import math
 import os
 from copy import deepcopy
 from unittest.mock import patch
@@ -69,7 +70,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmp_path, model_bs, dm
 
     tuner = Tuner(trainer)
     new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
-    assert new_batch_size == 8
+    assert new_batch_size == 7  # applied margin of 5% on 8 -> int(8 * 0.95) = 7
 
     if model_bs is not None:
         assert model.batch_size == new_batch_size
@@ -317,7 +318,12 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
     # With our fix, when max_trials is reached, we don't try the doubled batch size, so we get max_trials - 1 messages
     expected_tries = max_trials - 1 if init_batch_size < dataset_len and max_trials > 0 else 0
     assert caplog.text.count("trying batch size") == expected_tries
-    assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
+
+    # Determine the largest batch size that was actually tested.
+    # For "power" this is the final found size; for "binsearch" we applied a 5% margin
+    # when storing the final value, so the largest tested value is the one before applying margin.
+    largest_tested_batch_size = new_batch_size if scale_method == "power" else int(math.ceil(new_batch_size * 100 / 95))
+    assert caplog.text.count("greater or equal than the length") == int(largest_tested_batch_size >= dataset_len)
 
     assert trainer.train_dataloader.batch_size == new_batch_size
     assert trainer.val_dataloaders.batch_size == new_batch_size
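
The `* 100 / 95` factor in the test inverts the default 5% margin, recovering the largest size that was actually probed. For example:

    import math

    # A post-margin size of 7 maps back to a largest-tested size of 8,
    # since int(8 * 0.95) == 7 and int(math.ceil(7 * 100 / 95)) == 8.
    assert int(8 * 0.95) == 7
    assert int(math.ceil(7 * 100 / 95)) == 8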
@@ -453,7 +459,7 @@ def val_dataloader(self):
         tuner.scale_batch_size(model, method="validate")
 
 
-@pytest.mark.parametrize(("scale_method", "expected_batch_size"), [("power", 62), ("binsearch", 100)])
+@pytest.mark.parametrize(("scale_method", "expected_batch_size"), [("power", 62), ("binsearch", 95)])
 @patch("lightning.pytorch.tuner.batch_size_scaling.is_oom_error", return_value=True)
 def test_dataloader_batch_size_updated_on_failure(_, tmp_path, scale_method, expected_batch_size):
     class CustomBatchSizeModel(BatchSizeModel):
@@ -493,6 +499,51 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
     assert trainer.num_val_batches[0] != steps_per_trial
 
 
+@pytest.mark.parametrize("margin", [0.0, 0.1, 0.2])
+def test_scale_batch_size_margin_and_max_val(tmp_path, margin):
+    """Test margin feature for batch size scaling by comparing results with and without margin."""
+    # First, find the batch size without margin
+    model1 = BatchSizeModel(batch_size=2)
+    trainer1 = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=False, enable_checkpointing=False)
+    tuner1 = Tuner(trainer1)
+
+    result_without_margin = tuner1.scale_batch_size(
+        model1, mode="binsearch", max_trials=2, steps_per_trial=1, margin=0.0
+    )
+
+    model2 = BatchSizeModel(batch_size=2)
+    trainer2 = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=False, enable_checkpointing=False)
+    tuner2 = Tuner(trainer2)
+
+    result_with_margin = tuner2.scale_batch_size(
+        model2, mode="binsearch", max_trials=2, steps_per_trial=1, margin=margin
+    )
+
+    assert result_without_margin is not None
+    assert result_with_margin is not None
+
+    if margin == 0.0:
+        assert result_with_margin == result_without_margin
+    else:
+        expected_with_margin = max(1, int(result_without_margin * (1 - margin)))
+        assert result_with_margin == expected_with_margin
+        assert result_with_margin <= result_without_margin
+
+
+@pytest.mark.parametrize("mode", ["power", "binsearch"])
+def test_scale_batch_size_max_val_limit(tmp_path, mode):
+    """Test that max_val limits the batch size for both power and binsearch modes."""
+    model = BatchSizeModel(batch_size=2)
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    tuner = Tuner(trainer)
+
+    max_val = 8  # Set a low max value
+    result = tuner.scale_batch_size(model, mode=mode, max_trials=5, steps_per_trial=1, max_val=max_val)
+
+    assert result is not None
+    assert result <= max_val
+
+
 def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
     """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""

@@ -566,7 +617,7 @@ def training_step(self, batch, batch_idx):
         ("max_trials", "mode", "init_val", "expected"),
         [
             (3, "power", 2, 8),
-            (3, "binsearch", 2, 8),
+            (3, "binsearch", 2, 7),  # applied margin of 5% on 8 -> int(8 * 0.95) = 7
             (1, "power", 4, 4),
             (0, "power", 2, 2),
         ],
