Skip to content

Commit ee582b8

Browse files
committed
Fix: Remove deprecated .path access in Muon optimizer for TF 2.16+ compatibility
1 parent d9e0c3a commit ee582b8

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

keras/src/optimizers/muon.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,20 @@ def __init__(
128128
self.exclude_embeddings = exclude_embeddings
129129
self.exclude_layers = exclude_layers or []
130130

131-
132131
def _should_use_adamw(self, variable):
133132
"""
134133
To use it with 4D convolutional filters,
135134
it works well to just flatten their last 3 dimensions.
136-
any {0,1}-D parameters should all be optimized by adam
135+
any {0,1}-D parameters should all be optimized by adam
137136
"""
138137
# Use Adam for scalar or vector parameters
139-
if not 1 < len(variable.shape) <5:
138+
if not 1 < len(variable.shape) < 5:
140139
return True
141140

142141
# Exclude embedding layers if specified
143-
var_identifier = getattr(variable, "name", "") or getattr(variable, "path", "")
142+
var_identifier = getattr(variable, "name", "") or getattr(
143+
variable, "path", ""
144+
)
144145
if self.exclude_embeddings and "embedding" in var_identifier.lower():
145146
return True
146147

@@ -156,7 +157,6 @@ def _should_use_adamw(self, variable):
156157
# Otherwise, use AdamW
157158
return False
158159

159-
160160
def build(self, var_list):
161161
"""Initialize optimizer variables.
162162

keras/src/optimizers/muon_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from keras.src import backend
44
from keras.src import ops
55
from keras.src import testing
6-
from keras.src.layers import Dense, Embedding
6+
from keras.src.layers import Dense
7+
from keras.src.layers import Embedding
78
from keras.src.optimizers.muon import Muon
89

910

@@ -61,7 +62,9 @@ def test_muon_single_step(self):
6162
var = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
6263
optimizer.build([var])
6364
optimizer._muon_update_step(grads, var, 0.5)
64-
self.assertAllClose(var, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2)
65+
self.assertAllClose(
66+
var, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2
67+
)
6568

6669
def test_clip_norm(self):
6770
optimizer = Muon(clipnorm=1)
@@ -76,7 +79,8 @@ def test_clip_value(self):
7679
self.assertAllClose(clipped_grad[0], [1.0, 1.0])
7780

7881
def test_no_path_attribute_error(self):
79-
"""Ensure compatibility with TF 2.16+ ResourceVariable (no .path)."""
82+
"""Ensure compatibility with TF 2.16+
83+
ResourceVariable (no .path)."""
8084
optimizer = Muon()
8185
var = backend.Variable([1.0, 2.0], name="test_var")
8286
try:

0 commit comments

Comments
 (0)