Skip to content

Commit f35d6ea

Browse files
committed
Add multi-tensor (foreach) version of Lion in style of upcoming PyTorch 2.0 optimizers
1 parent 709d5e0 commit f35d6ea

File tree

2 files changed

+175
-17
lines changed

2 files changed

+175
-17
lines changed

timm/optim/lion.py

Lines changed: 154 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,24 @@
1616
# See the License for the specific language governing permissions and
1717
# limitations under the License.
1818
# ==============================================================================
19+
from typing import List
20+
1921
import torch
2022
from torch.optim.optimizer import Optimizer
2123

2224

2325
class Lion(Optimizer):
2426
r"""Implements Lion algorithm."""
2527

26-
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
28+
def __init__(
29+
self,
30+
params,
31+
lr=1e-4,
32+
betas=(0.9, 0.99),
33+
weight_decay=0.0,
34+
maximize=False,
35+
foreach=None,
36+
):
2737
"""Initialize the hyperparameters.
2838
2939
Args:
@@ -41,9 +51,21 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
4151
raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
4252
if not 0.0 <= betas[1] < 1.0:
4353
raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
44-
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
54+
defaults = dict(
55+
lr=lr,
56+
betas=betas,
57+
weight_decay=weight_decay,
58+
foreach=foreach,
59+
maximize=maximize,
60+
)
4561
super().__init__(params, defaults)
4662

63+
def __setstate__(self, state):
64+
super().__setstate__(state)
65+
for group in self.param_groups:
66+
group.setdefault('maximize', False)
67+
group.setdefault('foreach', None)
68+
4769
@torch.no_grad()
4870
def step(self, closure=None):
4971
"""Performs a single optimization step.
@@ -61,27 +83,144 @@ def step(self, closure=None):
6183
loss = closure()
6284

6385
for group in self.param_groups:
86+
params_with_grad = []
87+
grads = []
88+
exp_avgs = []
89+
beta1, beta2 = group['betas']
90+
6491
for p in group['params']:
6592
if p.grad is None:
6693
continue
94+
params_with_grad.append(p)
95+
if p.grad.is_sparse:
96+
raise RuntimeError('Lion does not support sparse gradients')
97+
grads.append(p.grad)
6798

68-
# Perform stepweight decay
69-
p.data.mul_(1 - group['lr'] * group['weight_decay'])
70-
71-
grad = p.grad
7299
state = self.state[p]
100+
73101
# State initialization
74102
if len(state) == 0:
75-
# Exponential moving average of gradient values
76-
state['exp_avg'] = torch.zeros_like(p)
103+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
77104

78-
exp_avg = state['exp_avg']
79-
beta1, beta2 = group['betas']
105+
exp_avgs.append(state['exp_avg'])
80106

81-
# Weight update
82-
update = exp_avg * beta1 + grad * (1 - beta1)
83-
p.add_(torch.sign(update), alpha=-group['lr'])
84-
# Decay the momentum running average coefficient
85-
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
107+
lion(
108+
params_with_grad,
109+
grads,
110+
exp_avgs,
111+
beta1=beta1,
112+
beta2=beta2,
113+
lr=group['lr'],
114+
weight_decay=group['weight_decay'],
115+
maximize=group['maximize'],
116+
foreach=group['foreach'],
117+
)
86118

87119
return loss
120+
121+
122+
def lion(
123+
params: List[torch.Tensor],
124+
grads: List[torch.Tensor],
125+
exp_avgs: List[torch.Tensor],
126+
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
127+
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
128+
maximize: bool = False,
129+
foreach: bool = None,
130+
*,
131+
beta1: float,
132+
beta2: float,
133+
lr: float,
134+
weight_decay: float,
135+
):
136+
r"""Functional API that performs Lion algorithm computation.
137+
"""
138+
if foreach is None:
139+
# Placeholder for more complex foreach logic to be added when value is not set
140+
foreach = False
141+
142+
if foreach and torch.jit.is_scripting():
143+
raise RuntimeError('torch.jit.script not supported with foreach optimizers')
144+
145+
if foreach and not torch.jit.is_scripting():
146+
func = _multi_tensor_lion
147+
else:
148+
func = _single_tensor_lion
149+
150+
func(
151+
params,
152+
grads,
153+
exp_avgs,
154+
beta1=beta1,
155+
beta2=beta2,
156+
lr=lr,
157+
weight_decay=weight_decay,
158+
maximize=maximize,
159+
)
160+
161+
162+
def _single_tensor_lion(
163+
params: List[torch.Tensor],
164+
grads: List[torch.Tensor],
165+
exp_avgs: List[torch.Tensor],
166+
*,
167+
beta1: float,
168+
beta2: float,
169+
lr: float,
170+
weight_decay: float,
171+
maximize: bool,
172+
):
173+
for i, param in enumerate(params):
174+
grad = grads[i] if not maximize else -grads[i]
175+
exp_avg = exp_avgs[i]
176+
177+
if torch.is_complex(param):
178+
grad = torch.view_as_real(grad)
179+
exp_avg = torch.view_as_real(exp_avg)
180+
param = torch.view_as_real(param)
181+
182+
# Perform stepweight decay
183+
param.mul_(1 - lr * weight_decay)
184+
185+
# Weight update
186+
update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
187+
param.add_(torch.sign(update), alpha=-lr)
188+
189+
# Decay the momentum running average coefficient
190+
exp_avg.lerp_(grad, 1 - beta2)
191+
192+
193+
def _multi_tensor_lion(
194+
params: List[torch.Tensor],
195+
grads: List[torch.Tensor],
196+
exp_avgs: List[torch.Tensor],
197+
*,
198+
beta1: float,
199+
beta2: float,
200+
lr: float,
201+
weight_decay: float,
202+
maximize: bool,
203+
):
204+
if len(params) == 0:
205+
return
206+
207+
if maximize:
208+
grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment]
209+
210+
grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
211+
exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
212+
params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
213+
214+
# Perform stepweight decay
215+
torch._foreach_mul_(params, 1 - lr * weight_decay)
216+
217+
# Weight update
218+
updates = torch._foreach_mul(exp_avgs, beta1)
219+
torch._foreach_add_(updates, grads, alpha=1 - beta1)
220+
221+
updates = [u.sign() for u in updates]
222+
torch._foreach_add_(params, updates, alpha=-lr)
223+
224+
# Decay the momentum running average coefficient
225+
torch._foreach_mul_(exp_avgs, beta2)
226+
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2)

timm/optim/optim_factory.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636
_logger = logging.getLogger(__name__)
3737

3838

39+
# optimizers to default to multi-tensor
40+
_DEFAULT_FOREACH = {
41+
'lion',
42+
}
43+
44+
3945
def param_groups_weight_decay(
4046
model: nn.Module,
4147
weight_decay=1e-5,
@@ -162,7 +168,8 @@ def optimizer_kwargs(cfg):
162168
opt=cfg.opt,
163169
lr=cfg.lr,
164170
weight_decay=cfg.weight_decay,
165-
momentum=cfg.momentum)
171+
momentum=cfg.momentum,
172+
)
166173
if getattr(cfg, 'opt_eps', None) is not None:
167174
kwargs['eps'] = cfg.opt_eps
168175
if getattr(cfg, 'opt_betas', None) is not None:
@@ -171,6 +178,8 @@ def optimizer_kwargs(cfg):
171178
kwargs['layer_decay'] = cfg.layer_decay
172179
if getattr(cfg, 'opt_args', None) is not None:
173180
kwargs.update(cfg.opt_args)
181+
if getattr(cfg, 'opt_foreach', None) is not None:
182+
kwargs['foreach'] = cfg.opt_foreach
174183
return kwargs
175184

176185

@@ -191,6 +200,7 @@ def create_optimizer_v2(
191200
lr: Optional[float] = None,
192201
weight_decay: float = 0.,
193202
momentum: float = 0.9,
203+
foreach: Optional[bool] = None,
194204
filter_bias_and_bn: bool = True,
195205
layer_decay: Optional[float] = None,
196206
param_group_fn: Optional[Callable] = None,
@@ -209,6 +219,7 @@ def create_optimizer_v2(
209219
lr: initial learning rate
210220
weight_decay: weight decay to apply in optimizer
211221
momentum: momentum for momentum based optimizers (others may use betas via kwargs)
222+
foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
212223
filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
213224
**kwargs: extra optimizer specific kwargs to pass through
214225
@@ -228,7 +239,8 @@ def create_optimizer_v2(
228239
model_or_params,
229240
weight_decay=weight_decay,
230241
layer_decay=layer_decay,
231-
no_weight_decay_list=no_weight_decay)
242+
no_weight_decay_list=no_weight_decay,
243+
)
232244
weight_decay = 0.
233245
elif weight_decay and filter_bias_and_bn:
234246
parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
@@ -246,9 +258,16 @@ def create_optimizer_v2(
246258
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
247259

248260
opt_args = dict(weight_decay=weight_decay, **kwargs)
261+
249262
if lr is not None:
250263
opt_args.setdefault('lr', lr)
251264

265+
if foreach is None:
266+
if opt in _DEFAULT_FOREACH:
267+
opt_args.setdefault('foreach', True)
268+
else:
269+
opt_args['foreach'] = foreach
270+
252271
# basic SGD & related
253272
if opt_lower == 'sgd' or opt_lower == 'nesterov':
254273
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons

0 commit comments

Comments
 (0)