from keras.src import backend
from keras.src import ops
from keras.src import testing
-from keras.src.layers import Dense
-from keras.src.layers import Embedding
+from keras.src.layers import Dense, Embedding
from keras.src.optimizers.muon import Muon


class MuonTest(testing.TestCase):
    def test_config(self):
-        optimizer = Muon(
-            learning_rate=0.5,
-            epsilon=1e-5,
-        )
+        optimizer = Muon(learning_rate=0.5, epsilon=1e-5)
        self.run_class_serialization_test(optimizer)

    def test_Newton_Schulz(self):
        optimizer = Muon()
        tensor_input = ops.array([[0.2499, 0.9105], [0.2655, 0.8824]])
-        except_output = ops.array([[-0.4422, 0.6457], [0.7285, 0.2968]])
+        expected_output = ops.array([[-0.4422, 0.6457], [0.7285, 0.2968]])
        output = optimizer.zeropower_via_newtonschulz5(tensor_input, 5)
-        self.assertAllClose(output, except_output, rtol=1e-3, atol=1e-3)
+        self.assertAllClose(output, expected_output, rtol=1e-3, atol=1e-3)

    def test_adamw_single_step(self):
        optimizer = Muon()
        grads = ops.array([1.0, 6.0, 7.0, 2.0])
-        vars = backend.Variable([1.0, 2.0, 3.0, 4.0], name="test_vars")
-        optimizer.build([vars])
-        optimizer._adamw_update_step(grads, vars, 0.5)
-        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)
+        var = backend.Variable([1.0, 2.0, 3.0, 4.0], name="test_vars")
+        optimizer.build([var])
+        optimizer._adamw_update_step(grads, var, 0.5)
+        self.assertAllClose(var, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)

    def test_should_use_adamw(self):
-        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        # Excluded layer test
+        var = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
        optimizer = Muon(exclude_layers=["var"])
-        self.assertAllClose(
-            True,
-            optimizer._should_use_adamw(vars),
-        )
-        embeding = Embedding(2, 2)
-        embeding.build()
-        self.assertAllClose(
-            True,
-            optimizer._should_use_adamw(embeding.weights[0]),
-        )
-        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        self.assertTrue(optimizer._should_use_adamw(var))
+
+        # Embedding test
+        embedding = Embedding(2, 2)
+        embedding.build()
+        optimizer = Muon(exclude_embeddings=True)
+        self.assertTrue(optimizer._should_use_adamw(embedding.weights[0]))
+
+        # 2D variable not excluded
+        var2 = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
        optimizer = Muon()
-        self.assertAllClose(
-            False,
-            optimizer._should_use_adamw(vars),
-        )
+        self.assertFalse(optimizer._should_use_adamw(var2))
+
+        # Dense layer
        dense = Dense(2)
        dense.build([None, 2])
-        self.assertAllClose(
-            False,
-            optimizer._should_use_adamw(dense.weights[0]),
-        )
+        self.assertFalse(optimizer._should_use_adamw(dense.weights[0]))
+
+        # Dimension rules
+        v_1d = backend.Variable([1.0, 2.0], name="v1d")
+        v_5d = backend.Variable(np.zeros((2, 2, 2, 2, 2)), name="v5d")
+        self.assertTrue(optimizer._should_use_adamw(v_1d))
+        self.assertTrue(optimizer._should_use_adamw(v_5d))

    def test_muon_single_step(self):
-        optimizer = Muon(
-            learning_rate=0.5,
-            weight_decay=0,
-        )
+        optimizer = Muon(learning_rate=0.5, weight_decay=0)
        grads = ops.array([[1.0, 6.0], [7.0, 2.0]])
-        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
-        optimizer.build([vars])
-        optimizer._muon_update_step(grads, vars, 0.5)
-        self.assertAllClose(
-            vars, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2
-        )
+        var = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        optimizer.build([var])
+        optimizer._muon_update_step(grads, var, 0.5)
+        self.assertAllClose(var, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2)

    def test_clip_norm(self):
        optimizer = Muon(clipnorm=1)
@@ -81,3 +74,13 @@ def test_clip_value(self):
        grad = [np.array([100.0, 100.0])]
        clipped_grad = optimizer._clip_gradients(grad)
        self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+
+    def test_no_path_attribute_error(self):
+        """Ensure compatibility with TF 2.16+ ResourceVariable (no .path)."""
+        optimizer = Muon()
+        var = backend.Variable([1.0, 2.0], name="test_var")
+        try:
+            result = optimizer._should_use_adamw(var)
+            self.assertIn(result, [True, False])
+        except AttributeError as e:
+            self.fail(f"Unexpected AttributeError: {e}")