@@ -94,7 +94,8 @@ class Kron(torch.optim.Optimizer):
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
-        flatten_dim: Flatten dim >= 2 instead of relying on expressions
+        flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
+        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
 
@@ -114,7 +115,8 @@ def __init__(
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
-            flatten_dim: bool = False,
+            flatten: bool = False,
+            flatten_start_end: Tuple[int, int] = (2, -1),
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -141,7 +143,8 @@ def __init__(
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
-            flatten_dim=flatten_dim,
+            flatten=flatten,
+            flatten_start_end=flatten_start_end,
         )
         super(Kron, self).__init__(params, defaults)
 
@@ -229,8 +232,11 @@ def step(self, closure=None):
 
                 grad = p.grad
                 state = self.state[p]
-                if group['flatten_dim']:
-                    grad = grad.view(grad.size(0), -1)
+
+                flattened = False
+                if group['flatten']:
+                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    flattened = True
 
                 if len(state) == 0:
                     state["step"] = 0
@@ -341,7 +347,7 @@ def step(self, closure=None):
 
                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
                 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
-                if group['flatten_dim']:
+                if flattened:
                     pre_grad = pre_grad.view(p.shape)
 
                 # Apply weight decay
@@ -361,6 +367,20 @@ def step(self, closure=None):
         return loss
 
 
+def safe_flatten(tensor, start_dim=0, end_dim=-1):
+    ndim = tensor.ndim
+
+    # Convert negative end_dim to positive and clip to end
+    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
+
+    # If tensor has fewer dims than start_dim or start > end, return tensor as is
+    if ndim <= start_dim or start_dim > end_dim:
+        return tensor
+
+    # Now safe to flatten
+    return tensor.flatten(start_dim, end_dim)
+
+
 def _init_Q_exprs(
         t,
         scale,
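
For readers skimming the diff, a small standalone sketch (not part of the PR) of what the new flatten / flatten_start_end behaviour amounts to; the example tensor shapes are illustrative only, and safe_flatten is reproduced locally from the hunk above.

import torch


def safe_flatten(tensor, start_dim=0, end_dim=-1):
    # Local copy of the helper added in this PR, for illustration.
    ndim = tensor.ndim
    # Convert a negative end_dim to a positive index and clip it to the last dim.
    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
    # Tensors with too few dims (or an empty dim range) are returned unchanged.
    if ndim <= start_dim or start_dim > end_dim:
        return tensor
    return tensor.flatten(start_dim, end_dim)


# Rank-4 conv weight: with the default (2, -1) range, dims 2..3 collapse.
g = torch.randn(64, 32, 3, 3)
print(safe_flatten(g, 2, -1).shape)  # torch.Size([64, 32, 9])

# Rank-2 linear weight: start_dim exceeds the rank, so it is returned as-is.
g = torch.randn(128, 256)
print(safe_flatten(g, 2, -1).shape)  # torch.Size([128, 256])

In the optimizer itself this path is enabled via Kron(..., flatten=True, flatten_start_end=(2, -1)), and the per-parameter flattened flag makes sure the preconditioned gradient is viewed back to p.shape before the update is applied.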