Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions keras/src/optimizers/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,47 +134,57 @@ def _should_use_adamw(self, variable):
# any {0,1}-D parameters should all be optimized by adam
if not 1 < len(variable.shape) < 4:
return True
if self.exclude_embeddings and "embedding" in variable.path.lower():
# Check .path only during build (where we have keras.Variable)
var_path = variable.path if hasattr(variable, "path") else None
if var_path is None:
return False
if self.exclude_embeddings and "embedding" in var_path.lower():
return True
for keyword in self.exclude_layers:
if re.search(keyword, variable.path):
return True
# Exclude any user-specified layer patterns
for pattern in self.exclude_layers:
try:
if re.search(pattern, var_path):
return True
except (re.error, TypeError):
# Skip invalid regex patterns in exclude_layers
continue
return False

def build(self, var_list):
"""Initialize optimizer variables.

Adam optimizer has 3 types of variables: momentums, velocities and
velocity_hat (only set when amsgrad is applied),
Muon optimizer has 2 types of variables: momentums and velocities.
Velocities are only set when using AdamW update step.

Args:
var_list: list of model variables to build Adam variables on.
var_list: list of model variables to build Muon variables on.
"""
if self.built:
return
super().build(var_list)
self.adam_momentums = {}
self.adam_velocities = {}

self.muon_momentums = {}
self.muon_velocities = {}
# Initialize lists with None for all variables
self.adam_momentums = [None] * len(var_list)
self.adam_velocities = [None] * len(var_list)

for var in var_list:
if not self._overwrite_variable_with_gradient(var):
self.adam_momentums[var.path] = (
var_idx = self._get_variable_index(var)
self.adam_momentums[var_idx] = (
self.add_variable_from_reference(
reference_variable=var, name="momentum"
)
)
if self._should_use_adamw(var):
self.adam_velocities[var.path] = (
self.adam_velocities[var_idx] = (
self.add_variable_from_reference(
reference_variable=var, name="velocity"
)
)

def update_step(self, gradient, variable, learning_rate):
if self._should_use_adamw(variable):
var_idx = self._get_variable_index(variable)
# Check if velocity exists to determine if we should use AdamW
if self.adam_velocities[var_idx] is not None:
# It should be noted that lr is one-tenth when using adamw.
self._adamw_update_step(
gradient, variable, learning_rate * self.adam_lr_ratio
Expand All @@ -183,7 +193,8 @@ def update_step(self, gradient, variable, learning_rate):
self._muon_update_step(gradient, variable, learning_rate)

def _muon_update_step(self, gradient, variable, lr):
m = self.adam_momentums[variable.path]
var_idx = self._get_variable_index(variable)
m = self.adam_momentums[var_idx]
self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
shape = variable.shape
if self.nesterov:
Expand All @@ -210,8 +221,9 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
ops.cast(self.adam_beta_2, variable.dtype), local_step
)

m = self.adam_momentums[variable.path]
v = self.adam_velocities[variable.path]
var_idx = self._get_variable_index(variable)
m = self.adam_momentums[var_idx]
v = self.adam_velocities[var_idx]

alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)

Expand Down
16 changes: 1 addition & 15 deletions keras/src/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,19 +180,7 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
if file_format is None and isinstance(path, (str, pathlib.Path)):
file_format = pathlib.Path(path).suffix[1:].lower()

# Normalize jpg → jpeg for Pillow compatibility
if file_format and file_format.lower() == "jpg":
file_format = "jpeg"

img = array_to_img(x, data_format=data_format, scale=scale)

# Handle RGBA → RGB conversion for JPEG
if img.mode == "RGBA" and file_format == "jpeg":
warnings.warn(
"The JPEG format does not support RGBA images, converting to RGB."
)
img = img.convert("RGB")

img.save(path, format=file_format, **kwargs)


Expand Down Expand Up @@ -464,6 +452,4 @@ def smart_resize(
img, size=size, interpolation=interpolation, data_format=data_format
)

if isinstance(x, np.ndarray):
return np.array(img)
return img
return img
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert this file

2 changes: 1 addition & 1 deletion keras/src/utils/image_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def test_save_jpg(self, shape, name, file_format, use_explicit_format):
# Verify saved image is correctly converted to RGB if needed
loaded_img = load_img(path)
loaded_array = img_to_array(loaded_img)
self.assertEqual(loaded_array.shape, (50, 50, 3))
self.assertEqual(loaded_array.shape, (50, 50, 3))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert this file

Loading