 import numpy as np
 import torch

-try:
-    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-    import opt_einsum
-    opt_einsum.enabled = True
-    opt_einsum.strategy = "auto-hq"
-    import torch.backends.opt_einsum
-    has_opt_einsum = True
-except ImportError:
-    has_opt_einsum = False

 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
@@ -67,19 +58,20 @@ class Kron(torch.optim.Optimizer):
         params: Iterable of parameters to optimize or dicts defining parameter groups.
         lr: Learning rate.
         momentum: Momentum parameter.
-        weight_decay: Weight decay (L2 penalty).
+        weight_decay: Weight decay.
         preconditioner_update_probability: Probability of updating the preconditioner.
             If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
         max_size_triangular: Max size for dim's preconditioner to be triangular.
         min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
-        memory_save_mode: 'one_diag', or 'all_diag', None is default
+        memory_save_mode: 'one_diag', 'smart_one_diag', or 'all_diag', None is default
             to set all preconditioners to be triangular, 'one_diag' sets the largest
             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
         momentum_into_precond_update: whether to send momentum into preconditioner
             update instead of raw gradients.
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
-        decoupled_decay: AdamW style decoupled-decay.
+        decoupled_decay: AdamW style decoupled weight decay.
+        flatten_dim: Flatten dims beyond the first so grads are viewed as 2D, instead of relying on higher-rank einsum expressions.
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """

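For reference, a minimal usage sketch of the constructor options documented above, assuming Kron is imported from this module; the toy model, data, and hyperparameter values are placeholders and not part of the change.

# Usage sketch (not part of the diff): exercising the documented options.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

optimizer = Kron(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-4,
    decoupled_decay=True,                # AdamW-style decoupled weight decay
    memory_save_mode='smart_one_diag',   # largest dim diagonal only when strictly largest
    flatten_dim=True,                    # grads viewed as (dim0, -1) before preconditioning
)

x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()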
@@ -97,10 +89,18 @@ def __init__(
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
+            flatten_dim: bool = False,
             deterministic: bool = False,
     ):
-        if not has_opt_einsum:
+        try:
+            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+            import opt_einsum
+            opt_einsum.enabled = True
+            opt_einsum.strategy = "auto-hq"
+            import torch.backends.opt_einsum
+        except ImportError:
             warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
+
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= momentum < 1.0:
@@ -122,10 +122,11 @@ def __init__(
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
+            flatten_dim=flatten_dim,
         )
         super(Kron, self).__init__(params, defaults)

-        self._param_exprs = {}
+        self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
         if deterministic:
@@ -165,20 +166,21 @@ def state_dict(self) -> Dict[str, Any]:

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Extract and remove the RNG state from the state dict
-        rng_state = state_dict.pop('rng_state', None)
-        torch_rng_state = state_dict.pop('torch_rng_state', None)
+        rng_states = {}
+        if 'rng_state' in state_dict:
+            rng_states['rng_state'] = state_dict.pop('rng_state')
+        if 'torch_rng_state' in state_dict:
+            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')

         # Load the optimizer state
         super().load_state_dict(state_dict)
+        state_dict.update(rng_states)  # add back

         # Restore the RNG state if it exists
-        if rng_state is not None:
-            self.rng.setstate(rng_state)
-            state_dict['rng_state'] = rng_state  # put it back if caller still using state_dict
-        if torch_rng_state is not None:
-            if self.torch_rng is not None:
-                self.torch_rng.set_state(torch_rng_state)
-            state_dict['torch_rng_state'] = torch_rng_state  # put it back if caller still using state_dict
+        if 'rng_state' in rng_states:
+            self.rng.setstate(rng_states['rng_state'])
+        if 'torch_rng_state' in rng_states:
+            self.torch_rng.set_state(rng_states['torch_rng_state'])

     def __setstate__(self, state):
         super().__setstate__(state)
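A small save/resume round trip sketches the behaviour the revised load_state_dict above targets: the RNG entries are popped before delegating to the base Optimizer, restored into the optimizer, and put back so a caller-held state_dict stays intact. The toy model is a placeholder, and the presence of 'torch_rng_state' assumes deterministic=True.

# Round-trip sketch (not part of the diff).
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
opt = Kron(model.parameters(), lr=1e-3, deterministic=True)
model(torch.randn(2, 16)).sum().backward()
opt.step()

sd = opt.state_dict()                       # expected to carry 'rng_state' / 'torch_rng_state'

opt_resumed = Kron(model.parameters(), lr=1e-3, deterministic=True)
opt_resumed.load_state_dict(sd)             # RNG state restored for the probabilistic update schedule

assert 'rng_state' in sd                    # caller's dict still holds the RNG entries after loading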
@@ -208,13 +210,16 @@ def step(self, closure=None):

                 grad = p.grad
                 state = self.state[p]
+                if group['flatten_dim']:
+                    grad = grad.view(grad.size(0), -1)

                 if len(state) == 0:
                     state["step"] = 0
                     state["update_counter"] = 0
-                    state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
+                    state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
+                    # init Q and einsum expressions on first step
                     state["Q"], exprs = _init_Q_exprs(
-                        p,
+                        grad,
                         group["precond_init_scale"],
                         group["max_size_triangular"],
                         group["min_ndim_triangular"],
@@ -234,8 +239,9 @@ def step(self, closure=None):
                     total_precond_size += precond_size
                     total_precond_mb += precond_mb
                 elif p not in self._param_exprs:
+                    # init only the einsum expressions, called after state load, Q are loaded from state_dict
                     exprs = _init_Q_exprs(
-                        p,
+                        grad,
                         group["precond_init_scale"],
                         group["max_size_triangular"],
                         group["min_ndim_triangular"],
@@ -245,6 +251,7 @@ def step(self, closure=None):
                     )
                     self._param_exprs[p] = exprs
                 else:
+                    # retrieve cached expressions
                     exprs = self._param_exprs[p]

                 # update preconditioners all together deterministically
@@ -315,6 +322,8 @@ def step(self, closure=None):

                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
                 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
+                if group['flatten_dim']:
+                    pre_grad = pre_grad.view(p.shape)

                 # Apply weight decay
                 if group["weight_decay"] != 0:
@@ -369,9 +378,10 @@ def _init_Q_exprs(
         dim_diag = [False for _ in shape]
         dim_diag[rev_sorted_dims[0]] = True
     elif memory_save_mode == "smart_one_diag":
-        dim_diag = [False for _ in shape]
+        # addition proposed by Lucas Nestler
         rev_sorted_dims = np.argsort(shape)[::-1]
         sorted_shape = sorted(shape)
+        dim_diag = [False for _ in shape]
         if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
             dim_diag[rev_sorted_dims[0]] = True
     elif memory_save_mode == "all_diag":
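To make the reordered 'smart_one_diag' branch above concrete, a standalone sketch of the dim_diag decision it produces for a few shapes; the helper name is made up for illustration and simply copies the branch shown.

# Illustration (not part of the diff) of the 'smart_one_diag' decision:
# the largest dim is made diagonal only when strictly larger than the runner-up.
import numpy as np

def smart_one_diag(shape):
    rev_sorted_dims = np.argsort(shape)[::-1]
    sorted_shape = sorted(shape)
    dim_diag = [False for _ in shape]
    if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
        dim_diag[rev_sorted_dims[0]] = True
    return dim_diag

print(smart_one_diag((4096, 768)))  # [True, False]  largest dim goes diagonal
print(smart_one_diag((768, 768)))   # [False, False] tie: both stay triangular
print(smart_one_diag((256,)))       # [False]        1D: no diagonal forced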