@@ -32,6 +32,8 @@ def _scale_batch_size(
3232 init_val : int = 2 ,
3333 max_trials : int = 25 ,
3434 batch_arg_name : str = "batch_size" ,
35+ margin : float = 0.05 ,
36+ max_val : int = 8192 ,
3537) -> Optional [int ]:
3638 """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
3739 error.
@@ -58,7 +60,15 @@ def _scale_batch_size(
5860 - ``model.hparams``
5961 - ``trainer.datamodule`` (the datamodule passed to the tune method)
6062
63+ margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
64+ 'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
65+ max_val: Maximum batch size limit, defaults to 8192.
66+ Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
67+ when running on CPU or when automatic OOM detection is not available.
68+
6169 """
70+ assert 0.0 <= margin < 1.0 , f"`margin` should be between 0 and 1. Found { margin = } "
71+
6272 if trainer .fast_dev_run :
6373 rank_zero_warn ("Skipping batch size scaler since `fast_dev_run` is enabled." )
6474 return None
@@ -80,9 +90,9 @@ def _scale_batch_size(
8090 new_size , _ = _adjust_batch_size (trainer , batch_arg_name , value = init_val )
8191
8292 if mode == "power" :
83- new_size = _run_power_scaling (trainer , new_size , batch_arg_name , max_trials , params )
93+ new_size = _run_power_scaling (trainer , new_size , batch_arg_name , max_trials , params , max_val )
8494 elif mode == "binsearch" :
85- new_size = _run_binary_scaling (trainer , new_size , batch_arg_name , max_trials , params )
95+ new_size = _run_binsearch_scaling (trainer , new_size , batch_arg_name , max_trials , params , margin , max_val )
8696
8797 garbage_collection_cuda ()
8898
@@ -173,6 +183,7 @@ def _run_power_scaling(
173183 batch_arg_name : str ,
174184 max_trials : int ,
175185 params : dict [str , Any ],
186+ max_val : int = 8192 ,
176187) -> int :
177188 """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
178189 # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -185,6 +196,10 @@ def _run_power_scaling(
185196 # reset after each try
186197 _reset_progress (trainer )
187198
199+ if new_size >= max_val :
200+ rank_zero_info (f"Reached the maximum batch size limit of { max_val } . Stopping search." )
201+ break
202+
188203 try :
189204 _try_loop_run (trainer , params )
190205 last_successful_size = new_size # Store the current size before doubling
@@ -217,18 +232,22 @@ def _run_power_scaling(
217232 return new_size
218233
219234
220- def _run_binary_scaling (
235+ def _run_binsearch_scaling (
221236 trainer : "pl.Trainer" ,
222237 new_size : int ,
223238 batch_arg_name : str ,
224239 max_trials : int ,
225240 params : dict [str , Any ],
241+ margin : float ,
242+ max_val : int = 8192 ,
226243) -> int :
227244 """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered.
228245
229246 Hereafter, the batch size is further refined using a binary search
230247
231248 """
249+ assert 0.0 <= margin < 1.0 , f"`margin` should be between 0 and 1. Found { margin = } "
250+
232251 low = 1
233252 high = None
234253 count = 0
@@ -239,6 +258,10 @@ def _run_binary_scaling(
239258 # reset after each try
240259 _reset_progress (trainer )
241260
261+ if new_size >= max_val :
262+ rank_zero_info (f"Reached the maximum batch size limit of { max_val } . Stopping search." )
263+ break
264+
242265 try :
243266 # run loop
244267 _try_loop_run (trainer , params )
@@ -256,9 +279,13 @@ def _run_binary_scaling(
256279 if high - low <= 1 :
257280 break
258281 midval = (high + low ) // 2
259- new_size , changed = _adjust_batch_size (trainer , batch_arg_name , value = midval , desc = "succeeded" )
282+ new_size , changed = _adjust_batch_size (
283+ trainer , batch_arg_name , value = midval , desc = "succeeded" , max_val = max_val
284+ )
260285 else :
261- new_size , changed = _adjust_batch_size (trainer , batch_arg_name , factor = 2.0 , desc = "succeeded" )
286+ new_size , changed = _adjust_batch_size (
287+ trainer , batch_arg_name , factor = 2.0 , desc = "succeeded" , max_val = max_val
288+ )
262289
263290 if not changed :
264291 break
@@ -284,6 +311,17 @@ def _run_binary_scaling(
284311 else :
285312 raise # some other error not memory related
286313
314+ # Apply margin reduction for binsearch mode
315+ if margin > 0 :
316+ margin_reduced_size = max (1 , int (new_size * (1 - margin )))
317+ if margin_reduced_size != new_size :
318+ rank_zero_info (
319+ f"Applying margin of { margin :.1%} , reducing batch size from { new_size } to { margin_reduced_size } "
320+ )
321+ new_size = margin_reduced_size
322+ # propagate the reduced batch size to the model/datamodule attribute
323+ lightning_setattr (trainer .lightning_module , batch_arg_name , new_size )
324+
287325 return new_size
288326
289327
@@ -293,6 +331,7 @@ def _adjust_batch_size(
293331 factor : float = 1.0 ,
294332 value : Optional [int ] = None ,
295333 desc : Optional [str ] = None ,
334+ max_val : int = 8192 ,
296335) -> tuple [int , bool ]:
297336 """Helper function for adjusting the batch size.
298337
@@ -303,6 +342,9 @@ def _adjust_batch_size(
303342 value: if a value is given, will override the batch size with this value.
304343 Note that the value of `factor` will not have an effect in this case
305344 desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
345+ max_val: Maximum batch size limit, defaults to 8192.
346+ Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
347+ when running on CPU or when automatic OOM detection is not available.
306348
307349 Returns:
308350 The new batch size for the next trial and a bool that signals whether the
@@ -321,13 +363,22 @@ def _adjust_batch_size(
321363 try :
322364 combined_dataset_length = combined_loader ._dataset_length ()
323365 if batch_size >= combined_dataset_length :
324- rank_zero_info (f"The batch size { batch_size } is greater or equal than the length of your dataset." )
366+ rank_zero_info (
367+ f"The batch size { batch_size } is greater or equal than"
368+ f" the length of your dataset: { combined_dataset_length } ."
369+ )
325370 return batch_size , False
326371 except NotImplementedError :
327372 # all datasets are iterable style
328373 pass
329374
330375 new_size = value if value is not None else int (batch_size * factor )
376+
377+ # Apply max_val limit if provided
378+ if new_size > max_val :
379+ if desc :
380+ rank_zero_info (f"Batch size { new_size } exceeds max_val limit { max_val } , capping at { max_val } " )
381+ new_size = max_val
331382 if desc :
332383 rank_zero_info (f"Batch size { batch_size } { desc } , trying batch size { new_size } " )
333384 changed = new_size != batch_size
0 commit comments