 
 """
 
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Callable, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -64,6 +64,7 @@ def __init__(
         lr: Union[float, Tensor] = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.9999),
         eps: float = 1e-6,
+        clip_exp: Optional[float] = 0.333,
         weight_decay: float = 0.0,
         decoupled: bool = False,
         *,
@@ -95,6 +96,7 @@ def __init__(
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            clip_exp=clip_exp,
             decoupled=decoupled,
             maximize=maximize,
             foreach=foreach,
@@ -111,6 +113,7 @@ def __setstate__(self, state):
             group.setdefault("foreach", None)
             group.setdefault("capturable", False)
             group.setdefault("differentiable", False)
+            group.setdefault("clip_exp", None)
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -141,9 +144,7 @@ def _init_group(
             has_complex |= torch.is_complex(p)
             params_with_grad.append(p)
             if p.grad.is_sparse:
-                raise RuntimeError(
-                    "ADOPT does not support sparse gradients"
-                )
+                raise RuntimeError("ADOPT does not support sparse gradients")
             grads.append(p.grad)
 
             state = self.state[p]
@@ -153,36 +154,24 @@ def _init_group(
                 # Deliberately host `step` on CPU if both capturable and fused are off.
                 # This is because kernel launches are costly on CUDA and XLA.
                 state["step"] = (
-                    torch.zeros(
-                        (),
-                        dtype=_get_scalar_dtype(),
-                        device=p.grad.device,
-                    )
+                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
                     if group["capturable"]
                     else torch.tensor(0.0, dtype=_get_scalar_dtype())
                 )
                 # Exponential moving average of gradient values
-                state["exp_avg"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
                 # Exponential moving average of squared gradient values
-                state["exp_avg_sq"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
 
             exp_avgs.append(state["exp_avg"])
             exp_avg_sqs.append(state["exp_avg_sq"])
 
             if group["differentiable"] and state["step"].requires_grad:
-                raise RuntimeError(
-                    "`requires_grad` is not supported for `step` in differentiable mode"
-                )
+                raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")
 
             # Foreach without capturable does not support a tensor lr
             if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
-                raise RuntimeError(
-                    "lr as a Tensor is not supported for capturable=False and foreach=True"
-                )
+                raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")
 
             state_steps.append(state["step"])
         return has_complex
@@ -231,6 +220,7 @@ def step(self, closure=None):
                 beta2=beta2,
                 lr=group["lr"],
                 weight_decay=group["weight_decay"],
+                clip_exp=group["clip_exp"],
                 decoupled=group["decoupled"],
                 eps=group["eps"],
                 maximize=group["maximize"],
@@ -258,6 +248,7 @@ def _single_tensor_adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -282,20 +273,12 @@ def _single_tensor_adopt(
         if capturable and not _is_compiling():
             from torch.optim.optimizer import _get_capturable_supported_devices
             capturable_supported_devices = _get_capturable_supported_devices()
-            assert (
-                param.device.type == step_t.device.type
-                and param.device.type in capturable_supported_devices
-            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
+            assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices, \
+                f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
 
         # update step
         step_t += 1
 
-        if weight_decay != 0:
-            if decoupled:
-                param.add_(param, alpha=-lr * weight_decay)
-            else:
-                grad = grad.add(param, alpha=weight_decay)
-
         if torch.is_complex(param):
             grad = torch.view_as_real(grad)
             if exp_avg is not None:
@@ -304,17 +287,25 @@ def _single_tensor_adopt(
                 exp_avg_sq = torch.view_as_real(exp_avg_sq)
             param = torch.view_as_real(param)
 
+        if weight_decay != 0 and not decoupled:
+            grad = grad.add(param, alpha=weight_decay)
+
         step = step_t if capturable or differentiable else _get_value(step_t)
         if step == 1:
             exp_avg_sq.addcmul_(grad, grad.conj())
             continue
 
+        if weight_decay != 0 and decoupled:
+            param.add_(param, alpha=-lr * weight_decay)
+
         denom = torch.clamp(exp_avg_sq.sqrt(), eps)
-        if step == 2:
-            exp_avg.addcdiv_(grad, denom)
-        else:
-            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
+        normed_grad = grad.div(denom)
+
+        if clip_exp is not None:
+            clip_val = (step - 1) ** clip_exp
+            normed_grad.clamp_(-clip_val, clip_val)
 
+        exp_avg.lerp_(normed_grad, 1 - beta1)
         param.add_(exp_avg, alpha=-lr)
 
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
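The hunks above restructure the single-tensor inner loop around ADOPT's normalize-then-average update: the gradient is divided by the second-moment estimate from previous steps (clamped at eps), optionally clipped to ±(step - 1)**clip_exp, folded into exp_avg via lerp_, and exp_avg_sq is refreshed with the raw gradient only afterwards. A minimal standalone sketch of that per-parameter update for step >= 2, with illustrative function and variable names and real-valued tensors assumed (with clip_exp=None the clamp is skipped), not the committed code:

    import torch

    def adopt_update_sketch(param, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps, clip_exp):
        # Normalize the current gradient by the second moment built from *previous* steps.
        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
        normed_grad = grad / denom
        if clip_exp is not None:
            # Clip bound grows with the step count: 1.0 at step 2, looser later.
            clip_val = (step - 1) ** clip_exp
            normed_grad = normed_grad.clamp(-clip_val, clip_val)
        # m_t = beta1 * m_{t-1} + (1 - beta1) * (clipped, normalized gradient)
        exp_avg.lerp_(normed_grad, 1 - beta1)
        param.add_(exp_avg, alpha=-lr)
        # Second moment is updated last, so step t's gradient never normalizes itself.
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)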
@@ -334,6 +325,7 @@ def _multi_tensor_adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -355,8 +347,7 @@ def _multi_tensor_adopt(
             supports_xla=False
         )
         assert all(
-            p.device.type == step.device.type
-            and p.device.type in capturable_supported_devices
+            p.device.type == step.device.type and p.device.type in capturable_supported_devices
             for p, step in zip(params, state_steps)
         ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
 
@@ -382,9 +373,7 @@ def _multi_tensor_adopt(
 
         # Handle complex parameters
         if has_complex:
-            _view_as_real(
-                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
-            )
+            _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
 
         if maximize:
             device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
@@ -394,44 +383,38 @@ def _multi_tensor_adopt(
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
         if not _is_compiling() and device_state_steps[0].is_cpu:
-            torch._foreach_add_(
-                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
-            )
+            torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
         else:
             torch._foreach_add_(device_state_steps, 1)
 
-        if weight_decay != 0:
-            if decoupled:
-                torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+        if weight_decay != 0 and not decoupled:
+            # Re-use the intermediate memory (device_grads) already allocated for maximize
+            if maximize:
+                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
             else:
-                # Re-use the intermediate memory (device_grads) already allocated for maximize
-                if maximize:
-                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
-                else:
-                    device_grads = torch._foreach_add(  # type: ignore[assignment]
-                        device_grads, device_params, alpha=weight_decay
-                    )
+                device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
 
         if device_state_steps[0] == 1:
             torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
             continue
 
+        if weight_decay != 0 and decoupled:
+            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+
         exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
-        exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
+        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
 
-        if device_state_steps[0] == 2:
-            torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
-        else:
-            torch._foreach_mul_(device_exp_avgs, beta1)
-            torch._foreach_addcdiv_(
-                device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
-            )
+        if clip_exp is not None:
+            clip_val = (device_state_steps[0] - 1) ** clip_exp
+            torch._foreach_maximum_(normed_grad, -clip_val)
+            torch._foreach_minimum_(normed_grad, clip_val)
 
+        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
         torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
+
         torch._foreach_mul_(device_exp_avg_sqs, beta2)
-        torch._foreach_addcmul_(
-            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
-        )
+        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)
 
 
 #@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use
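The multi-tensor path mirrors the same rule with foreach kernels; since there is no fused _foreach_clamp_, the clip is expressed as a maximum followed by a minimum, i.e. clamp(x, -c, c) == min(max(x, -c), c). A small self-contained check of that equivalence on a list of tensors (illustrative values, not part of the patch):

    import torch

    tensors = [torch.randn(4), torch.randn(3)]
    clip_val = 1.0

    # Reference: ordinary per-tensor clamp.
    reference = [t.clamp(-clip_val, clip_val) for t in tensors]

    # Foreach version: apply the lower bound, then the upper bound, in place.
    clipped = [t.clone() for t in tensors]
    torch._foreach_maximum_(clipped, -clip_val)
    torch._foreach_minimum_(clipped, clip_val)

    assert all(torch.equal(a, b) for a, b in zip(reference, clipped))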
@@ -454,6 +437,7 @@ def adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -490,6 +474,7 @@ def adopt(
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
+        clip_exp=clip_exp,
         decoupled=decoupled,
         eps=eps,
         maximize=maximize,
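With the dispatcher threading clip_exp through to both implementations, the option is controlled entirely from the constructor: new instances default to clip_exp=0.333, while __setstate__ fills in clip_exp=None for param groups restored from older checkpoints, leaving them unclipped. A usage sketch, assuming the class defined in this file is exposed as Adopt (import path assumed, as in timm.optim.adopt):

    import torch
    from timm.optim.adopt import Adopt  # assumed import path

    model = torch.nn.Linear(10, 2)
    optimizer = Adopt(model.parameters(), lr=1e-3, clip_exp=0.333)   # default clipping exponent
    # optimizer = Adopt(model.parameters(), lr=1e-3, clip_exp=None)  # disable clipping entirely

    loss = model(torch.randn(8, 10)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()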