Commit d9e0c3a

Fix: Remove deprecated .path access in Muon optimizer for TF 2.16+ compatibility
1 parent 7c052bc commit d9e0c3a
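
In short, the change replaces direct variable.path lookups with an identifier lookup that prefers variable.name and only falls back to variable.path; per the new test's docstring, TF 2.16+ ResourceVariable objects no longer expose .path, so the old code could raise AttributeError. The actual diffs follow below. The snippet here is only a minimal sketch of that fallback idea; the helper name variable_identifier and the PlainResourceVar stub are illustrative assumptions, not Keras API.

    import re


    def variable_identifier(variable):
        # Hypothetical helper mirroring the fallback used in the diff:
        # prefer .name, fall back to .path, and never raise AttributeError.
        return getattr(variable, "name", "") or getattr(variable, "path", "")


    class PlainResourceVar:
        # Stand-in for a variable that exposes .name but not .path.
        name = "token_embedding/embeddings:0"


    var = PlainResourceVar()
    print("embedding" in variable_identifier(var).lower())  # True
    print(bool(re.search("embeddings", variable_identifier(var))))  # True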

File tree: 2 files changed (+74 lines, -55 lines)

keras/src/optimizers/muon.py

Lines changed: 29 additions & 13 deletions
@@ -128,19 +128,35 @@ def __init__(
         self.exclude_embeddings = exclude_embeddings
         self.exclude_layers = exclude_layers or []
 
+
     def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
-        # it works well to just flatten their last 3 dimensions.
-        # any {0,1}-D parameters should all be optimized by adam
-        if not 1 < len(variable.shape) < 4:
+        """
+        To use it with 4D convolutional filters,
+        it works well to just flatten their last 3 dimensions.
+        any {0,1}-D parameters should all be optimized by adam
+        """
+        # Use AdamW for {0,1}-D parameters (and anything above 4-D)
+        if not 1 < len(variable.shape) < 5:
             return True
-        if self.exclude_embeddings and "embedding" in variable.path.lower():
+
+        # Exclude embedding layers if specified
+        var_identifier = getattr(variable, "name", "") or getattr(variable, "path", "")
+        if self.exclude_embeddings and "embedding" in var_identifier.lower():
             return True
-        for keyword in self.exclude_layers:
-            if re.search(keyword, variable.path):
-                return True
+
+        # Exclude variables matching any of the excluded layer patterns
+        for keyword in getattr(self, "exclude_layers", []):
+            try:
+                if re.search(keyword, var_identifier):
+                    return True
+            except re.error:
+                # Skip invalid regex patterns
+                continue
+
+        # Otherwise, use AdamW
         return False
 
+
     def build(self, var_list):
         """Initialize optimizer variables.
@@ -161,13 +177,13 @@ def build(self, var_list):
 
         for var in var_list:
             if not self._overwrite_variable_with_gradient(var):
-                self.adam_momentums[var.path] = (
+                self.adam_momentums[var.name] = (
                     self.add_variable_from_reference(
                         reference_variable=var, name="momentum"
                     )
                 )
                 if self._should_use_adamw(var):
-                    self.adam_velocities[var.path] = (
+                    self.adam_velocities[var.name] = (
                         self.add_variable_from_reference(
                             reference_variable=var, name="velocity"
                         )
@@ -183,7 +199,7 @@ def update_step(self, gradient, variable, learning_rate):
             self._muon_update_step(gradient, variable, learning_rate)
 
     def _muon_update_step(self, gradient, variable, lr):
-        m = self.adam_momentums[variable.path]
+        m = self.adam_momentums[variable.name]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
         shape = variable.shape
         if self.nesterov:
@@ -210,8 +226,8 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
             ops.cast(self.adam_beta_2, variable.dtype), local_step
         )
 
-        m = self.adam_momentums[variable.path]
-        v = self.adam_velocities[variable.path]
+        m = self.adam_momentums[variable.name]
+        v = self.adam_velocities[variable.name]
 
         alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
 
keras/src/optimizers/muon_test.py

Lines changed: 45 additions & 42 deletions
@@ -3,72 +3,65 @@
 from keras.src import backend
 from keras.src import ops
 from keras.src import testing
-from keras.src.layers import Dense
-from keras.src.layers import Embedding
+from keras.src.layers import Dense, Embedding
 from keras.src.optimizers.muon import Muon
 
 
 class MuonTest(testing.TestCase):
     def test_config(self):
-        optimizer = Muon(
-            learning_rate=0.5,
-            epsilon=1e-5,
-        )
+        optimizer = Muon(learning_rate=0.5, epsilon=1e-5)
         self.run_class_serialization_test(optimizer)
 
     def test_Newton_Schulz(self):
         optimizer = Muon()
         tensor_input = ops.array([[0.2499, 0.9105], [0.2655, 0.8824]])
-        except_output = ops.array([[-0.4422, 0.6457], [0.7285, 0.2968]])
+        expected_output = ops.array([[-0.4422, 0.6457], [0.7285, 0.2968]])
         output = optimizer.zeropower_via_newtonschulz5(tensor_input, 5)
-        self.assertAllClose(output, except_output, rtol=1e-3, atol=1e-3)
+        self.assertAllClose(output, expected_output, rtol=1e-3, atol=1e-3)
 
     def test_adamw_single_step(self):
         optimizer = Muon()
         grads = ops.array([1.0, 6.0, 7.0, 2.0])
-        vars = backend.Variable([1.0, 2.0, 3.0, 4.0], name="test_vars")
-        optimizer.build([vars])
-        optimizer._adamw_update_step(grads, vars, 0.5)
-        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)
+        var = backend.Variable([1.0, 2.0, 3.0, 4.0], name="test_vars")
+        optimizer.build([var])
+        optimizer._adamw_update_step(grads, var, 0.5)
+        self.assertAllClose(var, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)
 
     def test_should_use_adamw(self):
-        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        # Excluded layer test
+        var = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
         optimizer = Muon(exclude_layers=["var"])
-        self.assertAllClose(
-            True,
-            optimizer._should_use_adamw(vars),
-        )
-        embeding = Embedding(2, 2)
-        embeding.build()
-        self.assertAllClose(
-            True,
-            optimizer._should_use_adamw(embeding.weights[0]),
-        )
-        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        self.assertTrue(optimizer._should_use_adamw(var))
+
+        # Embedding test
+        embedding = Embedding(2, 2)
+        embedding.build()
+        optimizer = Muon(exclude_embeddings=True)
+        self.assertTrue(optimizer._should_use_adamw(embedding.weights[0]))
+
+        # 2D variable not excluded
+        var2 = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
         optimizer = Muon()
-        self.assertAllClose(
-            False,
-            optimizer._should_use_adamw(vars),
-        )
+        self.assertFalse(optimizer._should_use_adamw(var2))
+
+        # Dense layer
         dense = Dense(2)
         dense.build([None, 2])
-        self.assertAllClose(
-            False,
-            optimizer._should_use_adamw(dense.weights[0]),
-        )
+        self.assertFalse(optimizer._should_use_adamw(dense.weights[0]))
+
+        # Dimension rules
+        v_1d = backend.Variable([1.0, 2.0], name="v1d")
+        v_5d = backend.Variable(np.zeros((2, 2, 2, 2, 2)), name="v5d")
+        self.assertTrue(optimizer._should_use_adamw(v_1d))
+        self.assertTrue(optimizer._should_use_adamw(v_5d))
 
     def test_muon_single_step(self):
-        optimizer = Muon(
-            learning_rate=0.5,
-            weight_decay=0,
-        )
+        optimizer = Muon(learning_rate=0.5, weight_decay=0)
         grads = ops.array([[1.0, 6.0], [7.0, 2.0]])
-        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
-        optimizer.build([vars])
-        optimizer._muon_update_step(grads, vars, 0.5)
-        self.assertAllClose(
-            vars, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2
-        )
+        var = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        optimizer.build([var])
+        optimizer._muon_update_step(grads, var, 0.5)
+        self.assertAllClose(var, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2)
 
     def test_clip_norm(self):
         optimizer = Muon(clipnorm=1)
@@ -81,3 +74,13 @@ def test_clip_value(self):
         grad = [np.array([100.0, 100.0])]
         clipped_grad = optimizer._clip_gradients(grad)
         self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+
+    def test_no_path_attribute_error(self):
+        """Ensure compatibility with TF 2.16+ ResourceVariable (no .path)."""
+        optimizer = Muon()
+        var = backend.Variable([1.0, 2.0], name="test_var")
+        try:
+            result = optimizer._should_use_adamw(var)
+            self.assertIn(result, [True, False])
+        except AttributeError as e:
+            self.fail(f"Unexpected AttributeError: {e}")
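
Usage note (not part of the diff): assuming a Keras development checkout with pytest installed, the new regression test can be run on its own; the keyword filter below is just one way to select it.

    import pytest

    # Run only the TF 2.16+ compatibility test added in this commit.
    pytest.main(
        [
            "keras/src/optimizers/muon_test.py",
            "-k",
            "no_path_attribute_error",
            "-v",
        ]
    )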
