 # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: Apache-2.0
 import math
+from typing import Optional, Tuple
 
 import torch
 from torch.optim.optimizer import Optimizer
 
-
-def mars_single_tensor(
-        p,
-        grad,
-        exp_avg,
-        exp_avg_sq,
-        lr,
-        weight_decay,
-        beta1,
-        beta2,
-        last_grad,
-        eps,
-        step,
-        gamma,
-        mars_type,
-        is_grad_2d,
-        optimize_1d,
-        lr_1d_factor,
-        betas_1d,
+from ._types import ParamsT
+
+
+def _mars_single_tensor_step(
+        p: torch.Tensor,
+        grad: torch.Tensor,
+        exp_avg: torch.Tensor,
+        exp_avg_sq: torch.Tensor,
+        lr: float,
+        weight_decay: float,
+        beta1: float,
+        beta2: float,
+        last_grad: torch.Tensor,
+        eps: float,
+        step: int,
+        gamma: float,
+        mars_type: str,
+        is_grad_2d: bool,
+        optimize_1d: bool,
+        lr_1d_factor: bool,
+        betas_1d: Tuple[float, float],
+        caution: bool,
 ):
-    # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
+    # optimize_1d ==> use MARS for 1d param, else use AdamW
     if optimize_1d or is_grad_2d:
         one_minus_beta1 = 1. - beta1
-        c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
-        c_t_norm = torch.norm(c_t)
-        if c_t_norm > 1.:
-            c_t = c_t / c_t_norm
+        if step == 1:
+            # this is a timm addition, making first step more consistent when no grad history, otherwise tests fail
+            c_t = grad
+        else:
+            c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
+            c_t_norm = torch.norm(c_t)
+            if c_t_norm > 1.:
+                c_t = c_t / c_t_norm
         exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
+        if caution:
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
         if mars_type == "adamw":
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
             bias_correction1 = 1.0 - beta1 ** step
@@ -64,6 +76,10 @@ def mars_single_tensor(
         bias_correction1 = 1.0 - beta1_1d ** step
         bias_correction2 = 1.0 - beta2_1d ** step
         denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+        if caution:
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
         update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
         p.add_(update, alpha=-(lr * lr_1d_factor))
     return exp_avg, exp_avg_sq
@@ -78,16 +94,17 @@ class Mars(Optimizer):
7894 """
7995 def __init__ (
8096 self ,
81- params ,
82- lr = 3e-3 ,
83- betas = (0.9 , 0.99 ),
84- eps = 1e-8 ,
85- weight_decay = 0. ,
86- gamma = 0.025 ,
87- mars_type = "adamw" ,
88- optimize_1d = False ,
89- lr_1d_factor = 1.0 ,
90- betas_1d = None ,
97+ params : ParamsT ,
98+ lr : float = 3e-3 ,
99+ betas : Tuple [float , float ] = (0.9 , 0.99 ),
100+ eps : float = 1e-8 ,
101+ weight_decay : float = 0. ,
102+ gamma : float = 0.025 ,
103+ mars_type : str = "adamw" ,
104+ optimize_1d : bool = False ,
105+ lr_1d_factor : float = 1.0 ,
106+ betas_1d : Optional [Tuple [float , float ]] = None ,
107+ caution : bool = False
91108 ):
92109 if not 0.0 <= lr :
93110 raise ValueError ("Invalid learning rate: {}" .format (lr ))
@@ -109,9 +126,15 @@ def __init__(
             optimize_1d=optimize_1d,
             lr_1d_factor=lr_1d_factor,
             betas_1d=betas_1d or betas,
+            caution=caution,
         )
         super(Mars, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super(Mars, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -134,7 +157,6 @@ def step(self, closure=None):
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
 
                 state = self.state[p]
-                # ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
                 # State initialization
                 if len(state) <= 1:
                     state['step'] = 0
@@ -155,7 +177,8 @@ def step(self, closure=None):
                 beta1, beta2 = group['betas']
                 is_grad_2d = grad.ndim >= 2
 
-                mars_single_tensor(
+                # FIXME add multi-tensor (if usage warrants), make more standard
+                _mars_single_tensor_step(
                     p,
                     grad,
                     exp_avg,
@@ -173,6 +196,7 @@ def step(self, closure=None):
                     optimize_1d=group['optimize_1d'],
                     lr_1d_factor=group['lr_1d_factor'],
                     betas_1d=group['betas_1d'],
+                    caution=group['caution'],
                 )
 
                 state['last_grad'] = grad
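
For reference, a minimal usage sketch (not part of the commit): it assumes this module lands at timm/optim/mars.py so Mars is importable as shown, and demonstrates the new caution flag, which masks out momentum entries whose sign disagrees with the current gradient before the parameter update.

# Usage sketch only; the import path is an assumption based on `from ._types import ParamsT`.
import torch
from timm.optim.mars import Mars

model = torch.nn.Linear(16, 4)
criterion = torch.nn.MSELoss()

# caution=True applies the sign-agreement mask added in this commit to both the
# MARS (2d) branch and the AdamW fallback (1d) branch of the per-tensor step.
optimizer = Mars(model.parameters(), lr=3e-3, betas=(0.9, 0.99), weight_decay=0.01, caution=True)

for _ in range(5):
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()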