Skip to content

Commit 7bca156

Browse files
committed
Fix Muon optimizer TF 2.16+ compatibility: replace .path with id(variable) for uniqueness
1 parent ad582fa commit 7bca156

File tree

2 files changed: +33 −21 lines changed

keras/src/optimizers/muon.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -129,26 +129,27 @@ def __init__(
129129
self.exclude_layers = exclude_layers or []
130130

131131
def _should_use_adamw(self, variable):
132-
# To use it with 4D convolutional filters,
133-
# it works well to just flatten their last 3 dimensions.
134-
# any {0,1}-D parameters should all be optimized by adam
135-
if not 1 < len(variable.shape) < 5:
132+
"""Determine if AdamW should be used for a variable."""
133+
# Use AdamW for variables not having 2D, 3D, or 4D shape
134+
if len(variable.shape) not in (2, 3, 4):
136135
return True
137136

138-
# Get variable identifier (use .name in Keras 3+)
137+
# Get variable identifier (.name for logging and checks)
139138
var_identifier = variable.name
140139

141-
# Check if embedding layer should be excluded
140+
# Exclude embedding layers
142141
if self.exclude_embeddings and "embedding" in var_identifier.lower():
143142
return True
144143

145-
# Check if variable matches any excluded layer patterns
146-
for keyword in self.exclude_layers:
144+
# Exclude any user-specified layer patterns
145+
for pattern in self.exclude_layers:
147146
try:
148-
if re.search(keyword, var_identifier):
147+
if re.search(pattern, var_identifier):
149148
return True
150-
except re.error:
149+
except (re.error, TypeError):
150+
# Skip invalid regex patterns or non-string entries
151151
continue
152+
152153
return False
153154

154155
def build(self, var_list):
@@ -166,18 +167,13 @@ def build(self, var_list):
166167
self.adam_momentums = {}
167168
self.adam_velocities = {}
168169

169-
self.muon_momentums = {}
170-
self.muon_velocities = {}
171-
172170
for var in var_list:
173171
if not self._overwrite_variable_with_gradient(var):
174-
self.adam_momentums[var.name] = (
175-
self.add_variable_from_reference(
176-
reference_variable=var, name="momentum"
177-
)
172+
self.adam_momentums[id(var)] = self.add_variable_from_reference(
173+
reference_variable=var, name="momentum"
178174
)
179175
if self._should_use_adamw(var):
180-
self.adam_velocities[var.name] = (
176+
self.adam_velocities[id(var)] = (
181177
self.add_variable_from_reference(
182178
reference_variable=var, name="velocity"
183179
)
@@ -193,7 +189,7 @@ def update_step(self, gradient, variable, learning_rate):
193189
self._muon_update_step(gradient, variable, learning_rate)
194190

195191
def _muon_update_step(self, gradient, variable, lr):
196-
m = self.adam_momentums[variable.name]
192+
m = self.adam_momentums[id(variable)]
197193
self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
198194
shape = variable.shape
199195
if self.nesterov:
@@ -220,8 +216,8 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
220216
ops.cast(self.adam_beta_2, variable.dtype), local_step
221217
)
222218

223-
m = self.adam_momentums[variable.name]
224-
v = self.adam_velocities[variable.name]
219+
m = self.adam_momentums[id(variable)]
220+
v = self.adam_velocities[id(variable)]
225221

226222
alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
227223

keras/src/optimizers/muon_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,19 @@ def test_no_path_attribute_error(self):
9696
self.assertIn(result, [True, False])
9797
except AttributeError as e:
9898
self.fail(f"Unexpected AttributeError: {e}")
99+
100+
def test_variable_name_uniqueness(self):
101+
"""Ensure variable names are unique and no key collisions occur."""
102+
optimizer = Muon()
103+
# Create variables with different names (simulating real layers)
104+
var1 = backend.Variable([[1.0, 2.0]], name="kernel1")
105+
var2 = backend.Variable([[3.0, 4.0]], name="kernel2")
106+
optimizer.build([var1, var2])
107+
108+
# Check that each has its own momentum (unique variable objects)
109+
self.assertIn(id(var1), optimizer.adam_momentums)
110+
self.assertIn(id(var2), optimizer.adam_momentums)
111+
self.assertIsNot(
112+
optimizer.adam_momentums[id(var1)],
113+
optimizer.adam_momentums[id(var2)],
114+
)

0 commit comments

Comments (0)