@@ -129,26 +129,27 @@ def __init__(
         self.exclude_layers = exclude_layers or []

     def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
-        # it works well to just flatten their last 3 dimensions.
-        # any {0,1}-D parameters should all be optimized by adam
-        if not 1 < len(variable.shape) < 5:
+        """Determine whether AdamW should be used for a variable."""
+        # Use AdamW for any variable that is not 2D, 3D, or 4D.
+        if len(variable.shape) not in (2, 3, 4):
             return True

-        # Get variable identifier (use .name in Keras 3+)
+        # Get the variable identifier (.name, for logging and checks).
         var_identifier = variable.name

-        # Check if embedding layer should be excluded
+        # Exclude embedding layers.
         if self.exclude_embeddings and "embedding" in var_identifier.lower():
             return True

-        # Check if variable matches any excluded layer patterns
-        for keyword in self.exclude_layers:
+        # Exclude any user-specified layer patterns.
+        for pattern in self.exclude_layers:
             try:
-                if re.search(keyword, var_identifier):
+                if re.search(pattern, var_identifier):
                     return True
-            except re.error:
+            except (re.error, TypeError):
+                # Skip invalid regex patterns and non-string entries.
                 continue
+
         return False

     def build(self, var_list):
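As a minimal standalone sketch of the new routing rule (the function name and shapes below are illustrative, not part of the patch): only 2D, 3D, and 4D parameters take the Muon path, while scalars, vectors, and higher-rank tensors fall back to AdamW.

def uses_adamw(shape):
    # Mirror of the shape check above: Muon operates on matrix-like
    # parameters, so 2D-4D tensors stay on the Muon path.
    return len(shape) not in (2, 3, 4)

assert uses_adamw(())                  # scalar -> AdamW
assert uses_adamw((10,))               # bias vector -> AdamW
assert not uses_adamw((128, 256))      # dense kernel -> Muon
assert not uses_adamw((3, 3, 64, 64))  # 4D conv filter -> Muon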
@@ -166,18 +167,13 @@ def build(self, var_list):
         self.adam_momentums = {}
         self.adam_velocities = {}

-        self.muon_momentums = {}
-        self.muon_velocities = {}
-
         for var in var_list:
             if not self._overwrite_variable_with_gradient(var):
-                self.adam_momentums[var.name] = (
-                    self.add_variable_from_reference(
-                        reference_variable=var, name="momentum"
-                    )
+                self.adam_momentums[id(var)] = self.add_variable_from_reference(
+                    reference_variable=var, name="momentum"
                 )
             if self._should_use_adamw(var):
-                self.adam_velocities[var.name] = (
+                self.adam_velocities[id(var)] = (
                     self.add_variable_from_reference(
                         reference_variable=var, name="velocity"
                     )
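The keying change from `var.name` to `id(var)` appears aimed at name collisions: in Keras 3 a variable's `.name` is its short name (e.g. "kernel"), which repeats across layers, so a name-keyed slot dictionary can silently merge slots. A toy illustration, using a stand-in class rather than real Keras variables:

class Var:
    # Stand-in for a framework variable; short names repeat across layers.
    def __init__(self, name):
        self.name = name

kernels = [Var("kernel"), Var("kernel")]  # e.g. two Dense layers

by_name = {v.name: i for i, v in enumerate(kernels)}
by_id = {id(v): i for i, v in enumerate(kernels)}

assert len(by_name) == 1  # second kernel overwrote the first slot
assert len(by_id) == 2    # each live object keeps its own slot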
@@ -193,7 +189,7 @@ def update_step(self, gradient, variable, learning_rate):
             self._muon_update_step(gradient, variable, learning_rate)

     def _muon_update_step(self, gradient, variable, lr):
-        m = self.adam_momentums[variable.name]
+        m = self.adam_momentums[id(variable)]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
         shape = variable.shape
         if self.nesterov:
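The `assign_add` line is the classic heavy-ball update written in-place: m + (g + m * (momentum - 1)) simplifies to momentum * m + g. A quick numeric check of that identity:

beta, m, g = 0.95, 2.0, 0.5
m_inplace = m + (g + m * (beta - 1))  # what the assign_add computes
m_direct = beta * m + g               # textbook momentum update
assert abs(m_inplace - m_direct) < 1e-12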
@@ -220,8 +216,8 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
             ops.cast(self.adam_beta_2, variable.dtype), local_step
         )

-        m = self.adam_momentums[variable.name]
-        v = self.adam_velocities[variable.name]
+        m = self.adam_momentums[id(variable)]
+        v = self.adam_velocities[id(variable)]

         alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)

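`alpha` folds Adam's bias correction into the step size, so the update can use the raw `m` and `v` slots directly instead of materializing `m_hat` and `v_hat`. A one-step sanity check of the equivalence (scalar stand-ins; epsilon omitted, since its placement differs slightly between the two forms):

import math

lr, b1, b2, t = 1e-3, 0.9, 0.999, 1
m, v = 0.1, 0.01

alpha = lr * math.sqrt(1 - b2**t) / (1 - b1**t)
folded = alpha * m / math.sqrt(v)

m_hat, v_hat = m / (1 - b1**t), v / (1 - b2**t)
classic = lr * m_hat / math.sqrt(v_hat)

assert math.isclose(folded, classic, rel_tol=1e-9)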