From 2b65f890de00288acc58e94683f6a7a8fb21ae2b Mon Sep 17 00:00:00 2001 From: Utsab Dahal Date: Sat, 4 Oct 2025 18:00:36 +0545 Subject: [PATCH 1/6] =?UTF-8?q?Simplify=20save=5Fimg:=20remove=20=5Fformat?= =?UTF-8?q?,=20normalize=20jpg=E2=86=92jpeg,=20add=20RGBA=E2=86=92RGB=20ha?= =?UTF-8?q?ndling=20and=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- integration_tests/test_save_img.py | 27 +++++++++++++++++++++++++++ keras/src/utils/image_utils.py | 7 +++++-- 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 integration_tests/test_save_img.py diff --git a/integration_tests/test_save_img.py b/integration_tests/test_save_img.py new file mode 100644 index 000000000000..baec2712bfc2 --- /dev/null +++ b/integration_tests/test_save_img.py @@ -0,0 +1,27 @@ +import os + +import numpy as np +import pytest + +from keras.utils import img_to_array +from keras.utils import load_img +from keras.utils import save_img + + +@pytest.mark.parametrize( + "shape, name", + [ + ((50, 50, 3), "rgb.jpg"), + ((50, 50, 4), "rgba.jpg"), + ], +) +def test_save_jpg(tmp_path, shape, name): + img = np.random.randint(0, 256, size=shape, dtype=np.uint8) + path = tmp_path / name + save_img(path, img, file_format="jpg") + assert os.path.exists(path) + + # Check that the image was saved correctly and converted to RGB if needed. + loaded_img = load_img(path) + loaded_array = img_to_array(loaded_img) + assert loaded_array.shape == (50, 50, 3) \ No newline at end of file diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py index ca8289c9f9b7..a8781a0f46ae 100644 --- a/keras/src/utils/image_utils.py +++ b/keras/src/utils/image_utils.py @@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. """ data_format = backend.standardize_data_format(data_format) + # Normalize jpg → jpeg + if file_format is not None and file_format.lower() == "jpg": + file_format = "jpeg" img = array_to_img(x, data_format=data_format, scale=scale) - if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): + if img.mode == "RGBA" and file_format == "jpeg": warnings.warn( - "The JPG format does not support RGBA images, converting to RGB." + "The JPEG format does not support RGBA images, converting to RGB." ) img = img.convert("RGB") img.save(path, format=file_format, **kwargs) From 6896d1c8e280bc5378738c9b2c771f520cfb3496 Mon Sep 17 00:00:00 2001 From: Utsab Dahal Date: Sat, 25 Oct 2025 11:11:28 +0545 Subject: [PATCH 2/6] Fix: correct effective crop_to_aspect_ratio logic and lint issues --- .../image_preprocessing/resizing.py | 15 ++++++++++- .../image_preprocessing/resizing_test.py | 25 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/resizing.py b/keras/src/layers/preprocessing/image_preprocessing/resizing.py index 83460175ee54..5a5c80c310d0 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/resizing.py +++ b/keras/src/layers/preprocessing/image_preprocessing/resizing.py @@ -101,6 +101,19 @@ def __init__( self.width_axis = -2 def transform_images(self, images, transformation=None, training=True): + # Compute effective crop flag: + # only crop if aspect ratios differ and flag is True + input_height, input_width = transformation + source_aspect_ratio = input_width / input_height + target_aspect_ratio = self.width / self.height + # Use a small epsilon for floating-point comparison + aspect_ratios_match = ( + abs(source_aspect_ratio - target_aspect_ratio) < 1e-6 + ) + effective_crop_to_aspect_ratio = ( + self.crop_to_aspect_ratio and not aspect_ratios_match + ) + size = (self.height, self.width) resized = self.backend.image.resize( images, @@ -108,7 +121,7 @@ def transform_images(self, images, transformation=None, training=True): interpolation=self.interpolation, antialias=self.antialias, data_format=self.data_format, - crop_to_aspect_ratio=self.crop_to_aspect_ratio, + crop_to_aspect_ratio=effective_crop_to_aspect_ratio, pad_to_aspect_ratio=self.pad_to_aspect_ratio, fill_mode=self.fill_mode, fill_value=self.fill_value, diff --git a/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py b/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py index 38dfafbeaab0..6f465adb17ac 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py @@ -124,6 +124,31 @@ def test_crop_to_aspect_ratio(self, data_format): ref_out = ref_out.transpose(0, 3, 1, 2) self.assertAllClose(ref_out, out) + @parameterized.parameters([("channels_first",), ("channels_last",)]) + def test_crop_to_aspect_ratio_no_op_when_aspects_match(self, data_format): + # Test that crop_to_aspect_ratio=True behaves identically to False + # when source and target aspect ratios match (no cropping should occur). + img = np.reshape(np.arange(0, 16), (1, 4, 4, 1)).astype("float32") + if data_format == "channels_first": + img = img.transpose(0, 3, 1, 2) + out_false = layers.Resizing( + height=2, + width=2, + interpolation="nearest", + data_format=data_format, + crop_to_aspect_ratio=False, + )(img) + out_true = layers.Resizing( + height=2, + width=2, + interpolation="nearest", + data_format=data_format, + crop_to_aspect_ratio=True, + )(img) + # Outputs should be identical when aspect ratios match + # (4:4 -> 2:2, both 1:1). + self.assertAllClose(out_false, out_true) + @parameterized.parameters([("channels_first",), ("channels_last",)]) def test_unbatched_image(self, data_format): img = np.reshape(np.arange(0, 16), (4, 4, 1)).astype("float32") From 61bcaeb6dcb1d74d671c5acb8677d3b80d2df78b Mon Sep 17 00:00:00 2001 From: Utsab Dahal Date: Sat, 25 Oct 2025 18:50:42 +0545 Subject: [PATCH 3/6] Fix: correct effective crop_to_aspect_ratio logic and lint issues --- integration_tests/test_save_img.py | 27 --------------------------- keras/src/utils/image_utils.py | 9 +++------ 2 files changed, 3 insertions(+), 33 deletions(-) delete mode 100644 integration_tests/test_save_img.py diff --git a/integration_tests/test_save_img.py b/integration_tests/test_save_img.py deleted file mode 100644 index baec2712bfc2..000000000000 --- a/integration_tests/test_save_img.py +++ /dev/null @@ -1,27 +0,0 @@ -import os - -import numpy as np -import pytest - -from keras.utils import img_to_array -from keras.utils import load_img -from keras.utils import save_img - - -@pytest.mark.parametrize( - "shape, name", - [ - ((50, 50, 3), "rgb.jpg"), - ((50, 50, 4), "rgba.jpg"), - ], -) -def test_save_jpg(tmp_path, shape, name): - img = np.random.randint(0, 256, size=shape, dtype=np.uint8) - path = tmp_path / name - save_img(path, img, file_format="jpg") - assert os.path.exists(path) - - # Check that the image was saved correctly and converted to RGB if needed. - loaded_img = load_img(path) - loaded_array = img_to_array(loaded_img) - assert loaded_array.shape == (50, 50, 3) \ No newline at end of file diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py index a8781a0f46ae..a19a8519f021 100644 --- a/keras/src/utils/image_utils.py +++ b/keras/src/utils/image_utils.py @@ -175,13 +175,10 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. """ data_format = backend.standardize_data_format(data_format) - # Normalize jpg → jpeg - if file_format is not None and file_format.lower() == "jpg": - file_format = "jpeg" img = array_to_img(x, data_format=data_format, scale=scale) - if img.mode == "RGBA" and file_format == "jpeg": + if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): warnings.warn( - "The JPEG format does not support RGBA images, converting to RGB." + "The JPG format does not support RGBA images, converting to RGB." ) img = img.convert("RGB") img.save(path, format=file_format, **kwargs) @@ -457,4 +454,4 @@ def smart_resize( if isinstance(x, np.ndarray): return np.array(img) - return img + return img \ No newline at end of file From 6c3e06973030a8ded213040a3798d4ae42637395 Mon Sep 17 00:00:00 2001 From: Utsab Dahal Date: Sat, 25 Oct 2025 18:51:34 +0545 Subject: [PATCH 4/6] Fix: correct effective crop_to_aspect_ratio logic and lint issues --- keras/src/utils/image_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py index a19a8519f021..ca8289c9f9b7 100644 --- a/keras/src/utils/image_utils.py +++ b/keras/src/utils/image_utils.py @@ -454,4 +454,4 @@ def smart_resize( if isinstance(x, np.ndarray): return np.array(img) - return img \ No newline at end of file + return img From a809ba6fa7a81e5ad35f4444f655aa1a312f2246 Mon Sep 17 00:00:00 2001 From: Utsab Dahal Date: Thu, 6 Nov 2025 17:58:41 +0545 Subject: [PATCH 5/6] Add epsilon to prevent ZeroDivisionError in crop_to_aspect_ratio calculation --- .../layers/preprocessing/image_preprocessing/resizing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/resizing.py b/keras/src/layers/preprocessing/image_preprocessing/resizing.py index 5a5c80c310d0..977fa4cf7694 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/resizing.py +++ b/keras/src/layers/preprocessing/image_preprocessing/resizing.py @@ -104,8 +104,10 @@ def transform_images(self, images, transformation=None, training=True): # Compute effective crop flag: # only crop if aspect ratios differ and flag is True input_height, input_width = transformation - source_aspect_ratio = input_width / input_height - target_aspect_ratio = self.width / self.height + epsilon = self.backend.epsilon() + source_aspect_ratio = input_width / (input_height + epsilon) + target_aspect_ratio = self.width / (self.height + epsilon) + # Use a small epsilon for floating-point comparison aspect_ratios_match = ( abs(source_aspect_ratio - target_aspect_ratio) < 1e-6 From 5ca7a54f900d82439caf02b8998a5660161b79e3 Mon Sep 17 00:00:00 2001 From: Utsab Dahal Date: Tue, 11 Nov 2025 11:28:51 +0545 Subject: [PATCH 6/6] Fix: Add aspect ratio check and epsilon to prevent ZeroDivisionError Changes: - Add epsilon (1e-6) to prevent division by zero in all backend resize functions - Only crop when aspect ratios differ (with epsilon tolerance) - Skip cropping when source and target aspect ratios match - Add epsilon to _transform_boxes_crop_to_aspect_ratio in resizing.py Fixes failing tests by ensuring crop_to_aspect_ratio=True behaves identically to False when aspect ratios already match. Addresses reviewer feedback on PR #21779 --- keras/src/backend/jax/image.py | 27 ++++-- keras/src/backend/numpy/image.py | 27 ++++-- keras/src/backend/tensorflow/image.py | 85 ++++++++++++------- keras/src/backend/torch/image.py | 27 ++++-- .../image_preprocessing/resizing.py | 16 ++-- 5 files changed, 125 insertions(+), 57 deletions(-) diff --git a/keras/src/backend/jax/image.py b/keras/src/backend/jax/image.py index 52e37eed6c45..64531b68ec5d 100644 --- a/keras/src/backend/jax/image.py +++ b/keras/src/backend/jax/image.py @@ -216,12 +216,27 @@ def resize( height, width = shape[-3], shape[-2] else: height, width = shape[-2], shape[-1] - crop_height = int(float(width * target_height) / target_width) - crop_height = max(min(height, crop_height), 1) - crop_width = int(float(height * target_width) / target_height) - crop_width = max(min(width, crop_width), 1) - crop_box_hstart = int(float(height - crop_height) / 2) - crop_box_wstart = int(float(width - crop_width) / 2) + + # Add epsilon to prevent division by zero + epsilon = 1e-6 + source_aspect_ratio = float(width) / (float(height) + epsilon) + target_aspect_ratio = float(target_width) / (float(target_height) + epsilon) + + # Only crop if aspect ratios differ (with epsilon tolerance) + aspect_ratio_diff = abs(source_aspect_ratio - target_aspect_ratio) + if aspect_ratio_diff > epsilon: + crop_height = int(float(width * target_height) / (target_width + epsilon)) + crop_height = max(min(height, crop_height), 1) + crop_width = int(float(height * target_width) / (target_height + epsilon)) + crop_width = max(min(width, crop_width), 1) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + else: + # Skip cropping when aspect ratios match + crop_box_hstart = 0 + crop_box_wstart = 0 + crop_height = height + crop_width = width if data_format == "channels_last": if len(images.shape) == 4: images = images[ diff --git a/keras/src/backend/numpy/image.py b/keras/src/backend/numpy/image.py index 30ce1c9bba4c..0dfb93ead471 100644 --- a/keras/src/backend/numpy/image.py +++ b/keras/src/backend/numpy/image.py @@ -212,12 +212,27 @@ def resize( height, width = shape[-3], shape[-2] else: height, width = shape[-2], shape[-1] - crop_height = int(float(width * target_height) / target_width) - crop_height = max(min(height, crop_height), 1) - crop_width = int(float(height * target_width) / target_height) - crop_width = max(min(width, crop_width), 1) - crop_box_hstart = int(float(height - crop_height) / 2) - crop_box_wstart = int(float(width - crop_width) / 2) + + # Add epsilon to prevent division by zero + epsilon = 1e-6 + source_aspect_ratio = float(width) / (float(height) + epsilon) + target_aspect_ratio = float(target_width) / (float(target_height) + epsilon) + + # Only crop if aspect ratios differ (with epsilon tolerance) + aspect_ratio_diff = abs(source_aspect_ratio - target_aspect_ratio) + if aspect_ratio_diff > epsilon: + crop_height = int(float(width * target_height) / (target_width + epsilon)) + crop_height = max(min(height, crop_height), 1) + crop_width = int(float(height * target_width) / (target_height + epsilon)) + crop_width = max(min(width, crop_width), 1) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + else: + # Skip cropping when aspect ratios match + crop_box_hstart = 0 + crop_box_wstart = 0 + crop_height = height + crop_width = width if data_format == "channels_last": if len(images.shape) == 4: images = images[ diff --git a/keras/src/backend/tensorflow/image.py b/keras/src/backend/tensorflow/image.py index 0c693f4ff243..8e360c2d921d 100644 --- a/keras/src/backend/tensorflow/image.py +++ b/keras/src/backend/tensorflow/image.py @@ -177,40 +177,59 @@ def resize( if crop_to_aspect_ratio: shape = tf.shape(images) - height, width = shape[-3], shape[-2] - target_height, target_width = size - crop_height = tf.cast( - tf.cast(width * target_height, "float32") / target_width, - "int32", - ) - crop_height = tf.maximum(tf.minimum(height, crop_height), 1) - crop_height = tf.cast(crop_height, "int32") - crop_width = tf.cast( - tf.cast(height * target_width, "float32") / target_height, - "int32", - ) - crop_width = tf.maximum(tf.minimum(width, crop_width), 1) - crop_width = tf.cast(crop_width, "int32") + height = tf.cast(shape[-3], "float32") + width = tf.cast(shape[-2], "float32") + target_height = tf.cast(size[0], "float32") + target_width = tf.cast(size[1], "float32") + + # Add epsilon to prevent division by zero + epsilon = tf.constant(1e-6, dtype="float32") + source_aspect_ratio = width / (height + epsilon) + target_aspect_ratio = target_width / (target_height + epsilon) + + # Only crop if aspect ratios differ (with epsilon tolerance) + aspect_ratio_diff = tf.abs(source_aspect_ratio - target_aspect_ratio) + should_crop = aspect_ratio_diff > epsilon + + def apply_crop(): + crop_height = tf.cast( + tf.cast(width * target_height, "float32") / (target_width + epsilon), + "int32", + ) + crop_height = tf.maximum( + tf.minimum(tf.cast(height, "int32"), crop_height), 1 + ) + crop_height = tf.cast(crop_height, "int32") + crop_width = tf.cast( + tf.cast(height * target_width, "float32") / (target_height + epsilon), + "int32", + ) + crop_width = tf.maximum( + tf.minimum(tf.cast(width, "int32"), crop_width), 1 + ) + crop_width = tf.cast(crop_width, "int32") - crop_box_hstart = tf.cast( - tf.cast(height - crop_height, "float32") / 2, "int32" - ) - crop_box_wstart = tf.cast( - tf.cast(width - crop_width, "float32") / 2, "int32" - ) - if len(images.shape) == 4: - images = images[ - :, - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - :, - ] - else: - images = images[ - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - :, - ] + crop_box_hstart = tf.cast( + tf.cast(tf.cast(height, "int32") - crop_height, "float32") / 2, "int32" + ) + crop_box_wstart = tf.cast( + tf.cast(tf.cast(width, "int32") - crop_width, "float32") / 2, "int32" + ) + if len(images.shape) == 4: + return images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + return images[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + + images = tf.cond(should_crop, apply_crop, lambda: images) elif pad_to_aspect_ratio: shape = tf.shape(images) height, width = shape[-3], shape[-2] diff --git a/keras/src/backend/torch/image.py b/keras/src/backend/torch/image.py index b6976dc8569a..5f0a0f8f1ae6 100644 --- a/keras/src/backend/torch/image.py +++ b/keras/src/backend/torch/image.py @@ -253,12 +253,27 @@ def resize( shape = images.shape height, width = shape[-2], shape[-1] target_height, target_width = size - crop_height = int(float(width * target_height) / target_width) - crop_height = max(min(height, crop_height), 1) - crop_width = int(float(height * target_width) / target_height) - crop_width = max(min(width, crop_width), 1) - crop_box_hstart = int(float(height - crop_height) / 2) - crop_box_wstart = int(float(width - crop_width) / 2) + + # Add epsilon to prevent division by zero + epsilon = 1e-6 + source_aspect_ratio = float(width) / (float(height) + epsilon) + target_aspect_ratio = float(target_width) / (float(target_height) + epsilon) + + # Only crop if aspect ratios differ (with epsilon tolerance) + aspect_ratio_diff = abs(source_aspect_ratio - target_aspect_ratio) + if aspect_ratio_diff > epsilon: + crop_height = int(float(width * target_height) / (target_width + epsilon)) + crop_height = max(min(height, crop_height), 1) + crop_width = int(float(height * target_width) / (target_height + epsilon)) + crop_width = max(min(width, crop_width), 1) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + else: + # Skip cropping when aspect ratios match + crop_box_hstart = 0 + crop_box_wstart = 0 + crop_height = height + crop_width = width images = images[ :, :, diff --git a/keras/src/layers/preprocessing/image_preprocessing/resizing.py b/keras/src/layers/preprocessing/image_preprocessing/resizing.py index 977fa4cf7694..c51030559138 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/resizing.py +++ b/keras/src/layers/preprocessing/image_preprocessing/resizing.py @@ -248,17 +248,21 @@ def _transform_boxes_crop_to_aspect_ratio( ): """Transforms bounding boxes for cropping to aspect ratio.""" ops = self.backend - source_aspect_ratio = input_width / input_height - target_aspect_ratio = self.width / self.height + # Add epsilon to prevent division by zero + epsilon = ops.cast(ops.epsilon(), dtype=boxes.dtype) + source_aspect_ratio = input_width / (input_height + epsilon) + target_aspect_ratio = ops.cast( + self.width / (self.height + epsilon), dtype=boxes.dtype + ) new_width = ops.numpy.where( source_aspect_ratio > target_aspect_ratio, - self.height * source_aspect_ratio, - self.width, + ops.cast(self.height, dtype=boxes.dtype) * source_aspect_ratio, + ops.cast(self.width, dtype=boxes.dtype), ) new_height = ops.numpy.where( source_aspect_ratio > target_aspect_ratio, - self.height, - self.width / source_aspect_ratio, + ops.cast(self.height, dtype=boxes.dtype), + ops.cast(self.width, dtype=boxes.dtype) / source_aspect_ratio, ) scale_x = new_width / input_width scale_y = new_height / input_height