
Commit af68031

Add AdamW to CPUOffloadOptimizer default (#742)
add default
1 parent eb47c93 commit af68031

File tree

1 file changed: +16 -3 lines changed


torchao/prototype/low_bit_optim/cpu_offload.py

Lines changed: 16 additions & 3 deletions
@@ -1,20 +1,33 @@
 from typing import Type
 
 import torch
-from torch.optim.optimizer import Optimizer
+from torch.optim.optimizer import Optimizer, ParamsT
+
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
 
 class CPUOffloadOptimizer:
-    def __init__(self, params, optimizer_class: Type[Optimizer], *, offload_gradients: bool = False, **kwargs) -> None:
+    def __init__(
+        self,
+        params: ParamsT,
+        optimizer_class: Type[Optimizer] = torch.optim.AdamW,
+        *,
+        offload_gradients: bool = False,
+        **kwargs,
+    ) -> None:
         """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state.
         Optimizer step will be done on CPU.
 
         Args
             params: a list of parameters or parameter groups.
-            optimizer_class: constructor of the base optimizer.
+            optimizer_class: constructor of the base optimizer. Defaults to :class:`torch.optim.AdamW`.
             offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation.
             kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
         """
+        # default to fused CPU AdamW
+        if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and "fused" not in kwargs:
+            kwargs.update(fused=True)
+
         param_groups = list(params)
         if len(param_groups) == 0:
             raise ValueError("optimizer got an empty parameter list")
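For context, a minimal usage sketch (not part of the commit): assuming CPUOffloadOptimizer is importable from torchao.prototype.low_bit_optim and a CUDA device is available, the new default means the base optimizer class no longer has to be passed explicitly, and on PyTorch >= 2.4 the fused CPU AdamW implementation is picked automatically unless `fused` is set in kwargs.

# Minimal usage sketch. Assumptions: CPUOffloadOptimizer is exported from
# torchao.prototype.low_bit_optim, and the model/lr below are illustrative only.
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(1024, 1024).cuda()

# Before this commit, the base optimizer had to be given explicitly:
#   optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-3)
# After this commit, AdamW is the default, and fused=True is added
# automatically on PyTorch >= 2.4:
optim = CPUOffloadOptimizer(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optim.step()       # optimizer step runs on CPU
optim.zero_grad()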
