 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from typing import List
+
 import torch
 from torch.optim.optimizer import Optimizer


 class Lion(Optimizer):
   r"""Implements Lion algorithm."""

-  def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
+  def __init__(
+      self,
+      params,
+      lr=1e-4,
+      betas=(0.9, 0.99),
+      weight_decay=0.0,
+      maximize=False,
+      foreach=None,
+  ):
     """Initialize the hyperparameters.

     Args:
@@ -41,9 +51,21 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
       raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
     if not 0.0 <= betas[1] < 1.0:
       raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
-    defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+    defaults = dict(
+        lr=lr,
+        betas=betas,
+        weight_decay=weight_decay,
+        foreach=foreach,
+        maximize=maximize,
+    )
     super().__init__(params, defaults)

+  def __setstate__(self, state):
+    super().__setstate__(state)
+    for group in self.param_groups:
+      group.setdefault('maximize', False)
+      group.setdefault('foreach', None)
+
   @torch.no_grad()
   def step(self, closure=None):
     """Performs a single optimization step.
@@ -61,27 +83,144 @@ def step(self, closure=None):
         loss = closure()

     for group in self.param_groups:
+      params_with_grad = []
+      grads = []
+      exp_avgs = []
+      beta1, beta2 = group['betas']
+
       for p in group['params']:
         if p.grad is None:
           continue
+        params_with_grad.append(p)
+        if p.grad.is_sparse:
+          raise RuntimeError('Lion does not support sparse gradients')
+        grads.append(p.grad)

-        # Perform stepweight decay
-        p.data.mul_(1 - group['lr'] * group['weight_decay'])
-
-        grad = p.grad
         state = self.state[p]
+
         # State initialization
         if len(state) == 0:
-          # Exponential moving average of gradient values
-          state['exp_avg'] = torch.zeros_like(p)
+          state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)

-        exp_avg = state['exp_avg']
-        beta1, beta2 = group['betas']
+        exp_avgs.append(state['exp_avg'])

-        # Weight update
-        update = exp_avg * beta1 + grad * (1 - beta1)
-        p.add_(torch.sign(update), alpha=-group['lr'])
-        # Decay the momentum running average coefficient
-        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
+      lion(
+          params_with_grad,
+          grads,
+          exp_avgs,
+          beta1=beta1,
+          beta2=beta2,
+          lr=group['lr'],
+          weight_decay=group['weight_decay'],
+          maximize=group['maximize'],
+          foreach=group['foreach'],
+      )

     return loss
+
+
+def lion(
+    params: List[torch.Tensor],
+    grads: List[torch.Tensor],
+    exp_avgs: List[torch.Tensor],
+    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
+    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
+    maximize: bool = False,
+    foreach: bool = None,
+    *,
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+):
+  r"""Functional API that performs Lion algorithm computation.
+  """
+  if foreach is None:
+    # Placeholder for more complex foreach logic to be added when value is not set
+    foreach = False
+
+  if foreach and torch.jit.is_scripting():
+    raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+
+  if foreach and not torch.jit.is_scripting():
+    func = _multi_tensor_lion
+  else:
+    func = _single_tensor_lion
+
+  func(
+      params,
+      grads,
+      exp_avgs,
+      beta1=beta1,
+      beta2=beta2,
+      lr=lr,
+      weight_decay=weight_decay,
+      maximize=maximize,
+  )
+
+
+def _single_tensor_lion(
+    params: List[torch.Tensor],
+    grads: List[torch.Tensor],
+    exp_avgs: List[torch.Tensor],
+    *,
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+    maximize: bool,
+):
+  for i, param in enumerate(params):
+    grad = grads[i] if not maximize else -grads[i]
+    exp_avg = exp_avgs[i]
+
+    if torch.is_complex(param):
+      grad = torch.view_as_real(grad)
+      exp_avg = torch.view_as_real(exp_avg)
+      param = torch.view_as_real(param)
+
+    # Perform stepweight decay
+    param.mul_(1 - lr * weight_decay)
+
+    # Weight update
+    update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
+    param.add_(torch.sign(update), alpha=-lr)
+
+    # Decay the momentum running average coefficient
+    exp_avg.lerp_(grad, 1 - beta2)
+
+
+def _multi_tensor_lion(
+    params: List[torch.Tensor],
+    grads: List[torch.Tensor],
+    exp_avgs: List[torch.Tensor],
+    *,
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+    maximize: bool,
+):
+  if len(params) == 0:
+    return
+
+  if maximize:
+    grads = torch._foreach_neg(tuple(grads))  # type: ignore[assignment]
+
+  grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
+  exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
+  params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
+
+  # Perform stepweight decay
+  torch._foreach_mul_(params, 1 - lr * weight_decay)
+
+  # Weight update
+  updates = torch._foreach_mul(exp_avgs, beta1)
+  torch._foreach_add_(updates, grads, alpha=1 - beta1)
+
+  updates = [u.sign() for u in updates]
+  torch._foreach_add_(params, updates, alpha=-lr)
+
+  # Decay the momentum running average coefficient
+  torch._foreach_mul_(exp_avgs, beta2)
+  torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2)
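
As a rough usage sketch (not part of this commit), the updated class is driven like any other torch.optim optimizer; the model, batch, and hyperparameter values below are hypothetical stand-ins, and Lion is assumed to be imported from this module:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2)    # hypothetical toy model
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99),
                 weight_decay=1.0, foreach=True)

inputs = torch.randn(8, 10)       # hypothetical batch
targets = torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = F.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()                  # applies the sign-based Lion update shown above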
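
A small sanity-check sketch, again assuming Lion is importable from this module: the single-tensor and foreach code paths should produce identical updates on a toy quadratic problem.

import torch

torch.manual_seed(0)
base = [torch.randn(4, 3) for _ in range(2)]
p_single = [t.clone().requires_grad_(True) for t in base]
p_foreach = [t.clone().requires_grad_(True) for t in base]

for params, use_foreach in ((p_single, False), (p_foreach, True)):
  opt = Lion(params, lr=1e-3, weight_decay=0.1, foreach=use_foreach)
  loss = sum((p ** 2).sum() for p in params)   # simple quadratic loss
  loss.backward()
  opt.step()

for a, b in zip(p_single, p_foreach):
  assert torch.allclose(a, b)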