Muon optimizer: keep per-variable optimizer state in lists indexed by
`_get_variable_index` instead of dicts keyed by `variable.path`.

Text before the first "diff --git" header is commentary; `git apply` and
`patch` skip it.

Summary of the change:
  * `build` pre-fills `adam_momentums` / `adam_velocities` with `None`,
    one slot per entry of `var_list`, and creates a velocity slot only for
    variables for which `_should_use_adamw` returns True.
  * `update_step` no longer calls `_should_use_adamw`; it selects the AdamW
    step when the variable's velocity slot is not `None`.
  * `_should_use_adamw`, after the rank check, returns False for objects
    that have no `.path` attribute, and skips `exclude_layers` patterns
    that raise `re.error` or `TypeError`.
  * The `muon_momentums` / `muon_velocities` dicts are removed.

Scope: keras/src/optimizers/muon.py only.

diff --git a/keras/src/optimizers/muon.py b/keras/src/optimizers/muon.py
index 88d0dde3ee92..740cf548afa4 100644
--- a/keras/src/optimizers/muon.py
+++ b/keras/src/optimizers/muon.py
@@ -134,47 +134,57 @@ def _should_use_adamw(self, variable):
         # any {0,1}-D parameters should all be optimized by adam
         if not 1 < len(variable.shape) < 4:
             return True
-        if self.exclude_embeddings and "embedding" in variable.path.lower():
+        # Check .path only during build (where we have keras.Variable)
+        var_path = variable.path if hasattr(variable, "path") else None
+        if var_path is None:
+            return False
+        if self.exclude_embeddings and "embedding" in var_path.lower():
             return True
-        for keyword in self.exclude_layers:
-            if re.search(keyword, variable.path):
-                return True
+        # Exclude any user-specified layer patterns
+        for pattern in self.exclude_layers:
+            try:
+                if re.search(pattern, var_path):
+                    return True
+            except (re.error, TypeError):
+                # Skip invalid regex patterns in exclude_layers
+                continue
         return False
 
     def build(self, var_list):
         """Initialize optimizer variables.
 
-        Adam optimizer has 3 types of variables: momentums, velocities and
-        velocity_hat (only set when amsgrad is applied),
+        Muon optimizer has 2 types of variables: momentums and velocities.
+        Velocities are only set when using AdamW update step.
 
         Args:
-            var_list: list of model variables to build Adam variables on.
+            var_list: list of model variables to build Muon variables on.
         """
         if self.built:
             return
         super().build(var_list)
-        self.adam_momentums = {}
-        self.adam_velocities = {}
-
-        self.muon_momentums = {}
-        self.muon_velocities = {}
+        # Initialize lists with None for all variables
+        self.adam_momentums = [None] * len(var_list)
+        self.adam_velocities = [None] * len(var_list)
 
         for var in var_list:
             if not self._overwrite_variable_with_gradient(var):
-                self.adam_momentums[var.path] = (
+                var_idx = self._get_variable_index(var)
+                self.adam_momentums[var_idx] = (
                     self.add_variable_from_reference(
                         reference_variable=var, name="momentum"
                     )
                 )
                 if self._should_use_adamw(var):
-                    self.adam_velocities[var.path] = (
+                    self.adam_velocities[var_idx] = (
                         self.add_variable_from_reference(
                             reference_variable=var, name="velocity"
                         )
                     )
 
     def update_step(self, gradient, variable, learning_rate):
-        if self._should_use_adamw(variable):
+        var_idx = self._get_variable_index(variable)
+        # Check if velocity exists to determine if we should use AdamW
+        if self.adam_velocities[var_idx] is not None:
             # It should be noted that lr is one-tenth when using adamw.
             self._adamw_update_step(
                 gradient, variable, learning_rate * self.adam_lr_ratio
@@ -183,7 +193,8 @@ def update_step(self, gradient, variable, learning_rate):
             self._muon_update_step(gradient, variable, learning_rate)
 
     def _muon_update_step(self, gradient, variable, lr):
-        m = self.adam_momentums[variable.path]
+        var_idx = self._get_variable_index(variable)
+        m = self.adam_momentums[var_idx]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
         shape = variable.shape
         if self.nesterov:
@@ -210,8 +221,9 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
             ops.cast(self.adam_beta_2, variable.dtype), local_step
         )
 
-        m = self.adam_momentums[variable.path]
-        v = self.adam_velocities[variable.path]
+        var_idx = self._get_variable_index(variable)
+        m = self.adam_momentums[var_idx]
+        v = self.adam_velocities[var_idx]
 
         alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
 