Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions keras/src/backend/jax/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,27 @@ def resize(
height, width = shape[-3], shape[-2]
else:
height, width = shape[-2], shape[-1]
crop_height = int(float(width * target_height) / target_width)
crop_height = max(min(height, crop_height), 1)
crop_width = int(float(height * target_width) / target_height)
crop_width = max(min(width, crop_width), 1)
crop_box_hstart = int(float(height - crop_height) / 2)
crop_box_wstart = int(float(width - crop_width) / 2)

# Crop only when the source and target aspect ratios differ; when they
# already match, resizing alone preserves geometry and cropping is a
# no-op. epsilon guards the ratio divisions against a zero-sized input
# dimension and doubles as the comparison tolerance.
epsilon = 1e-6
source_aspect_ratio = float(width) / (float(height) + epsilon)
target_aspect_ratio = float(target_width) / (float(target_height) + epsilon)

aspect_ratio_diff = abs(source_aspect_ratio - target_aspect_ratio)
if aspect_ratio_diff > epsilon:
    # Divide by the target dimensions directly: they are positive
    # integers, so no epsilon guard is needed here, and adding one
    # truncates exact divisions (e.g. 8 / 2.000001 -> 3.999... -> 3
    # after int()), shrinking the crop window by one pixel.
    crop_height = int(float(width * target_height) / target_width)
    crop_height = max(min(height, crop_height), 1)
    crop_width = int(float(height * target_width) / target_height)
    crop_width = max(min(width, crop_width), 1)
    # Center the crop window inside the source image.
    crop_box_hstart = int(float(height - crop_height) / 2)
    crop_box_wstart = int(float(width - crop_width) / 2)
else:
    # Aspect ratios match: keep the whole image.
    crop_box_hstart = 0
    crop_box_wstart = 0
    crop_height = height
    crop_width = width
if data_format == "channels_last":
if len(images.shape) == 4:
images = images[
Expand Down
27 changes: 21 additions & 6 deletions keras/src/backend/numpy/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,27 @@ def resize(
height, width = shape[-3], shape[-2]
else:
height, width = shape[-2], shape[-1]
crop_height = int(float(width * target_height) / target_width)
crop_height = max(min(height, crop_height), 1)
crop_width = int(float(height * target_width) / target_height)
crop_width = max(min(width, crop_width), 1)
crop_box_hstart = int(float(height - crop_height) / 2)
crop_box_wstart = int(float(width - crop_width) / 2)

# Crop only when the source and target aspect ratios differ; when they
# already match, resizing alone preserves geometry and cropping is a
# no-op. epsilon guards the ratio divisions against a zero-sized input
# dimension and doubles as the comparison tolerance.
epsilon = 1e-6
source_aspect_ratio = float(width) / (float(height) + epsilon)
target_aspect_ratio = float(target_width) / (float(target_height) + epsilon)

aspect_ratio_diff = abs(source_aspect_ratio - target_aspect_ratio)
if aspect_ratio_diff > epsilon:
    # Divide by the target dimensions directly: they are positive
    # integers, so no epsilon guard is needed here, and adding one
    # truncates exact divisions (e.g. 8 / 2.000001 -> 3.999... -> 3
    # after int()), shrinking the crop window by one pixel.
    crop_height = int(float(width * target_height) / target_width)
    crop_height = max(min(height, crop_height), 1)
    crop_width = int(float(height * target_width) / target_height)
    crop_width = max(min(width, crop_width), 1)
    # Center the crop window inside the source image.
    crop_box_hstart = int(float(height - crop_height) / 2)
    crop_box_wstart = int(float(width - crop_width) / 2)
else:
    # Aspect ratios match: keep the whole image.
    crop_box_hstart = 0
    crop_box_wstart = 0
    crop_height = height
    crop_width = width
if data_format == "channels_last":
if len(images.shape) == 4:
images = images[
Expand Down
85 changes: 52 additions & 33 deletions keras/src/backend/tensorflow/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,40 +177,59 @@ def resize(

if crop_to_aspect_ratio:
shape = tf.shape(images)
height, width = shape[-3], shape[-2]
target_height, target_width = size
crop_height = tf.cast(
tf.cast(width * target_height, "float32") / target_width,
"int32",
)
crop_height = tf.maximum(tf.minimum(height, crop_height), 1)
crop_height = tf.cast(crop_height, "int32")
crop_width = tf.cast(
tf.cast(height * target_width, "float32") / target_height,
"int32",
)
crop_width = tf.maximum(tf.minimum(width, crop_width), 1)
crop_width = tf.cast(crop_width, "int32")
height = tf.cast(shape[-3], "float32")
width = tf.cast(shape[-2], "float32")
target_height = tf.cast(size[0], "float32")
target_width = tf.cast(size[1], "float32")

# Crop only when the source and target aspect ratios actually differ;
# when they match, plain resizing preserves geometry and cropping is a
# no-op. epsilon guards the ratio divisions against a zero-sized input
# dimension and doubles as the comparison tolerance.
epsilon = tf.constant(1e-6, dtype="float32")
source_aspect_ratio = width / (height + epsilon)
target_aspect_ratio = target_width / (target_height + epsilon)
aspect_ratio_diff = tf.abs(source_aspect_ratio - target_aspect_ratio)
should_crop = aspect_ratio_diff > epsilon

def apply_crop():
    # Divide by the target dimensions directly: they are >= 1, so no
    # epsilon guard is needed, and adding one truncates exact divisions
    # (e.g. 8 / 2.000001 -> 3.999... -> 3 after the int32 cast),
    # shrinking the crop window by one pixel. All operands are already
    # float32, so no inner casts are required.
    crop_height = tf.cast(width * target_height / target_width, "int32")
    crop_height = tf.maximum(
        tf.minimum(tf.cast(height, "int32"), crop_height), 1
    )
    crop_width = tf.cast(height * target_width / target_height, "int32")
    crop_width = tf.maximum(
        tf.minimum(tf.cast(width, "int32"), crop_width), 1
    )
    # Center the crop window; the differences are non-negative after
    # the clamps above, so floor division matches float-truncation.
    crop_box_hstart = (tf.cast(height, "int32") - crop_height) // 2
    crop_box_wstart = (tf.cast(width, "int32") - crop_width) // 2
    if len(images.shape) == 4:
        return images[
            :,
            crop_box_hstart : crop_box_hstart + crop_height,
            crop_box_wstart : crop_box_wstart + crop_width,
            :,
        ]
    else:
        return images[
            crop_box_hstart : crop_box_hstart + crop_height,
            crop_box_wstart : crop_box_wstart + crop_width,
            :,
        ]

images = tf.cond(should_crop, apply_crop, lambda: images)
elif pad_to_aspect_ratio:
shape = tf.shape(images)
height, width = shape[-3], shape[-2]
Expand Down
27 changes: 21 additions & 6 deletions keras/src/backend/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,27 @@ def resize(
shape = images.shape
height, width = shape[-2], shape[-1]
target_height, target_width = size
crop_height = int(float(width * target_height) / target_width)
crop_height = max(min(height, crop_height), 1)
crop_width = int(float(height * target_width) / target_height)
crop_width = max(min(width, crop_width), 1)
crop_box_hstart = int(float(height - crop_height) / 2)
crop_box_wstart = int(float(width - crop_width) / 2)

# Crop only when the source and target aspect ratios differ; when they
# already match, resizing alone preserves geometry and cropping is a
# no-op. epsilon guards the ratio divisions against a zero-sized input
# dimension and doubles as the comparison tolerance.
epsilon = 1e-6
source_aspect_ratio = float(width) / (float(height) + epsilon)
target_aspect_ratio = float(target_width) / (float(target_height) + epsilon)

aspect_ratio_diff = abs(source_aspect_ratio - target_aspect_ratio)
if aspect_ratio_diff > epsilon:
    # Divide by the target dimensions directly: they are positive
    # integers, so no epsilon guard is needed here, and adding one
    # truncates exact divisions (e.g. 8 / 2.000001 -> 3.999... -> 3
    # after int()), shrinking the crop window by one pixel.
    crop_height = int(float(width * target_height) / target_width)
    crop_height = max(min(height, crop_height), 1)
    crop_width = int(float(height * target_width) / target_height)
    crop_width = max(min(width, crop_width), 1)
    # Center the crop window inside the source image.
    crop_box_hstart = int(float(height - crop_height) / 2)
    crop_box_wstart = int(float(width - crop_width) / 2)
else:
    # Aspect ratios match: keep the whole image.
    crop_box_hstart = 0
    crop_box_wstart = 0
    crop_height = height
    crop_width = width
images = images[
:,
:,
Expand Down
33 changes: 26 additions & 7 deletions keras/src/layers/preprocessing/image_preprocessing/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,29 @@ def __init__(
self.width_axis = -2

def transform_images(self, images, transformation=None, training=True):
# Compute effective crop flag:
# only crop if aspect ratios differ and flag is True
input_height, input_width = transformation
epsilon = self.backend.epsilon()
source_aspect_ratio = input_width / (input_height + epsilon)
target_aspect_ratio = self.width / (self.height + epsilon)

# Use a small epsilon for floating-point comparison
aspect_ratios_match = (
abs(source_aspect_ratio - target_aspect_ratio) < 1e-6
)
effective_crop_to_aspect_ratio = (
self.crop_to_aspect_ratio and not aspect_ratios_match
)

size = (self.height, self.width)
resized = self.backend.image.resize(
images,
size=size,
interpolation=self.interpolation,
antialias=self.antialias,
data_format=self.data_format,
crop_to_aspect_ratio=self.crop_to_aspect_ratio,
crop_to_aspect_ratio=effective_crop_to_aspect_ratio,
pad_to_aspect_ratio=self.pad_to_aspect_ratio,
fill_mode=self.fill_mode,
fill_value=self.fill_value,
Expand Down Expand Up @@ -233,17 +248,21 @@ def _transform_boxes_crop_to_aspect_ratio(
):
"""Transforms bounding boxes for cropping to aspect ratio."""
ops = self.backend
source_aspect_ratio = input_width / input_height
target_aspect_ratio = self.width / self.height
# Add epsilon to prevent division by zero
epsilon = ops.cast(ops.epsilon(), dtype=boxes.dtype)
source_aspect_ratio = input_width / (input_height + epsilon)
target_aspect_ratio = ops.cast(
self.width / (self.height + epsilon), dtype=boxes.dtype
)
new_width = ops.numpy.where(
source_aspect_ratio > target_aspect_ratio,
self.height * source_aspect_ratio,
self.width,
ops.cast(self.height, dtype=boxes.dtype) * source_aspect_ratio,
ops.cast(self.width, dtype=boxes.dtype),
)
new_height = ops.numpy.where(
source_aspect_ratio > target_aspect_ratio,
self.height,
self.width / source_aspect_ratio,
ops.cast(self.height, dtype=boxes.dtype),
ops.cast(self.width, dtype=boxes.dtype) / source_aspect_ratio,
)
scale_x = new_width / input_width
scale_y = new_height / input_height
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,31 @@ def test_crop_to_aspect_ratio(self, data_format):
ref_out = ref_out.transpose(0, 3, 1, 2)
self.assertAllClose(ref_out, out)

@parameterized.parameters([("channels_first",), ("channels_last",)])
def test_crop_to_aspect_ratio_no_op_when_aspects_match(self, data_format):
    """crop_to_aspect_ratio must be a no-op when the ratios already match."""
    # A 4x4 source resized to 2x2: both are 1:1, so no cropping should
    # occur and the flag must not change the output.
    image = np.arange(0, 16, dtype="float32").reshape((1, 4, 4, 1))
    if data_format == "channels_first":
        image = image.transpose(0, 3, 1, 2)
    outputs = []
    for crop_flag in (False, True):
        layer = layers.Resizing(
            height=2,
            width=2,
            interpolation="nearest",
            data_format=data_format,
            crop_to_aspect_ratio=crop_flag,
        )
        outputs.append(layer(image))
    # Identical results with and without the flag.
    self.assertAllClose(outputs[0], outputs[1])

@parameterized.parameters([("channels_first",), ("channels_last",)])
def test_unbatched_image(self, data_format):
img = np.reshape(np.arange(0, 16), (4, 4, 1)).astype("float32")
Expand Down
Loading