Skip to content

Commit 709d5e0

Browse files
committed
Add Lion optimizer
1 parent 6242661 commit 709d5e0

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

timm/optim/lion.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
""" Lion Optimizer
2+
Paper: `Symbolic Discovery of Optimization Algorithms` - https://arxiv.org/abs/2302.06675
3+
Original Impl: https://github.com/google/automl/tree/master/lion
4+
"""
5+
# Copyright 2023 Google Research. All Rights Reserved.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
# ==============================================================================
19+
import torch
20+
from torch.optim.optimizer import Optimizer
21+
22+
23+
class Lion(Optimizer):
24+
r"""Implements Lion algorithm."""
25+
26+
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
27+
"""Initialize the hyperparameters.
28+
29+
Args:
30+
params (iterable): iterable of parameters to optimize or dicts defining
31+
parameter groups
32+
lr (float, optional): learning rate (default: 1e-4)
33+
betas (Tuple[float, float], optional): coefficients used for computing
34+
running averages of gradient and its square (default: (0.9, 0.99))
35+
weight_decay (float, optional): weight decay coefficient (default: 0)
36+
"""
37+
38+
if not 0.0 <= lr:
39+
raise ValueError('Invalid learning rate: {}'.format(lr))
40+
if not 0.0 <= betas[0] < 1.0:
41+
raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
42+
if not 0.0 <= betas[1] < 1.0:
43+
raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
44+
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
45+
super().__init__(params, defaults)
46+
47+
@torch.no_grad()
48+
def step(self, closure=None):
49+
"""Performs a single optimization step.
50+
51+
Args:
52+
closure (callable, optional): A closure that reevaluates the model
53+
and returns the loss.
54+
55+
Returns:
56+
the loss.
57+
"""
58+
loss = None
59+
if closure is not None:
60+
with torch.enable_grad():
61+
loss = closure()
62+
63+
for group in self.param_groups:
64+
for p in group['params']:
65+
if p.grad is None:
66+
continue
67+
68+
# Perform stepweight decay
69+
p.data.mul_(1 - group['lr'] * group['weight_decay'])
70+
71+
grad = p.grad
72+
state = self.state[p]
73+
# State initialization
74+
if len(state) == 0:
75+
# Exponential moving average of gradient values
76+
state['exp_avg'] = torch.zeros_like(p)
77+
78+
exp_avg = state['exp_avg']
79+
beta1, beta2 = group['betas']
80+
81+
# Weight update
82+
update = exp_avg * beta1 + grad * (1 - beta1)
83+
p.add_(torch.sign(update), alpha=-group['lr'])
84+
# Decay the momentum running average coefficient
85+
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
86+
87+
return loss

timm/optim/optim_factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .adan import Adan
1919
from .lamb import Lamb
2020
from .lars import Lars
21+
from .lion import Lion
2122
from .lookahead import Lookahead
2223
from .madgrad import MADGRAD
2324
from .nadam import Nadam
@@ -313,6 +314,8 @@ def create_optimizer_v2(
313314
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
314315
elif opt_lower == 'rmsproptf':
315316
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
317+
elif opt_lower == 'lion':
318+
optimizer = Lion(parameters, **opt_args)
316319

317320
# second order
318321
elif opt_lower == 'adahessian':

0 commit comments

Comments
 (0)