Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from keras.src.backend.jax.core import shape
from keras.src.backend.jax.core import stop_gradient
from keras.src.backend.jax.core import vectorized_map
from keras.src.backend.jax.nn import adaptive_avg_pool
from keras.src.backend.jax.nn import adaptive_max_pool
from keras.src.backend.jax.rnn import cudnn_ok
from keras.src.backend.jax.rnn import gru
from keras.src.backend.jax.rnn import lstm
Expand Down
365 changes: 365 additions & 0 deletions keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,3 +1464,368 @@ def _pair(x):
# ---- reshape -> (N, C*kH*kW, L) ----
_, CKK, oH, oW = patches.shape
return patches.reshape(N, CKK, oH * oW)


def get_static_window_sizes(input_dim, output_dim):
"""Calculate small and big window sizes for adaptive pooling."""
small_window = math.ceil(input_dim / output_dim)
big_window = small_window + 1
return small_window, big_window


def compute_static_gather_indices(input_dim, output_size, big_window):
"""Compute gather indices for Two-Pool Gather method."""
window_starts = jnp.floor(
(jnp.arange(output_size) * input_dim) / output_size
).astype(jnp.int32)

window_ends = jnp.ceil(
(jnp.arange(1, output_size + 1) * input_dim) / output_size
).astype(jnp.int32)

window_sizes = window_ends - window_starts
is_big_window = window_sizes == big_window

small_window = big_window - 1
small_pool_len = input_dim - small_window + 1

small_indices = window_starts
big_indices = window_starts + small_pool_len

gather_indices = jnp.where(is_big_window, big_indices, small_indices)
return gather_indices.astype(jnp.int32)


# ---------- 1D Adaptive Pooling ----------
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 1D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size,)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC

n, l, c = inputs.shape
out_l = output_size[0]

small_l, big_l = get_static_window_sizes(l, out_l)
gather_l = compute_static_gather_indices(l, out_l, big_l)

small_pool_l = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_l, 1), (1, 1, 1), "valid"
)
small_pool_l = small_pool_l / small_l

big_pool_l = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_l, 1), (1, 1, 1), "valid"
)
big_pool_l = big_pool_l / big_l

combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
pooled_l = jnp.take(combined_l, gather_l, axis=1)

if data_format == "channels_first":
pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL

return pooled_l


def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 1D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size,)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC

n, l, c = inputs.shape
out_l = output_size[0]

small_l, big_l = get_static_window_sizes(l, out_l)
gather_l = compute_static_gather_indices(l, out_l, big_l)

small_pool_l = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, small_l, 1), (1, 1, 1), "valid"
)
big_pool_l = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_l, 1), (1, 1, 1), "valid"
)

combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
pooled_l = jnp.take(combined_l, gather_l, axis=1)

if data_format == "channels_first":
pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL

return pooled_l


# ---------- 2D Adaptive Pooling ----------
def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 2D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC

n, h, w, c = inputs.shape
out_h, out_w = output_size

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_h = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
)
small_pool_h = small_pool_h / small_h

big_pool_h = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
)
big_pool_h = big_pool_h / big_h

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
pooled_h = jnp.take(combined_h, gather_h, axis=1)

small_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
)
small_pool_w = small_pool_w / small_w

big_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
)
big_pool_w = big_pool_w / big_w

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
pooled_w = jnp.take(combined_w, gather_w, axis=2)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW

return pooled_w


def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 2D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC

n, h, w, c = inputs.shape
out_h, out_w = output_size

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_h = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
)
big_pool_h = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
)

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
pooled_h = jnp.take(combined_h, gather_h, axis=1)

small_pool_w = lax.reduce_window(
pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
)
big_pool_w = lax.reduce_window(
pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
)

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
pooled_w = jnp.take(combined_w, gather_w, axis=2)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW

return pooled_w


# ---------- 3D Adaptive Pooling ----------
def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 3D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC

n, d, h, w, c = inputs.shape
out_d, out_h, out_w = output_size

small_d, big_d = get_static_window_sizes(d, out_d)
gather_d = compute_static_gather_indices(d, out_d, big_d)

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_d = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_d = small_pool_d / small_d

big_pool_d = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_d = big_pool_d / big_d

combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
pooled_d = jnp.take(combined_d, gather_d, axis=1)

small_pool_h = lax.reduce_window(
pooled_d, 0.0, lax.add, (1, 1, small_h, 1, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_h = small_pool_h / small_h

big_pool_h = lax.reduce_window(
pooled_d, 0.0, lax.add, (1, 1, big_h, 1, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_h = big_pool_h / big_h

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
pooled_h = jnp.take(combined_h, gather_h, axis=2)

small_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, 1, small_w, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_w = small_pool_w / small_w

big_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, 1, big_w, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_w = big_pool_w / big_w

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
pooled_w = jnp.take(combined_w, gather_w, axis=3)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW

return pooled_w


def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 3D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC

n, d, h, w, c = inputs.shape
out_d, out_h, out_w = output_size

small_d, big_d = get_static_window_sizes(d, out_d)
gather_d = compute_static_gather_indices(d, out_d, big_d)

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_d = lax.reduce_window(
inputs,
-jnp.inf,
lax.max,
(1, small_d, 1, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_d = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)

combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
pooled_d = jnp.take(combined_d, gather_d, axis=1)

small_pool_h = lax.reduce_window(
pooled_d,
-jnp.inf,
lax.max,
(1, 1, small_h, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_h = lax.reduce_window(
pooled_d,
-jnp.inf,
lax.max,
(1, 1, big_h, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
pooled_h = jnp.take(combined_h, gather_h, axis=2)

small_pool_w = lax.reduce_window(
pooled_h,
-jnp.inf,
lax.max,
(1, 1, 1, small_w, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_w = lax.reduce_window(
pooled_h,
-jnp.inf,
lax.max,
(1, 1, 1, big_w, 1),
(1, 1, 1, 1, 1),
"valid",
)

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
pooled_w = jnp.take(combined_w, gather_w, axis=3)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW

return pooled_w


# ---------- Dispatcher ----------
def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
"""Dispatcher for adaptive average pooling (1D, 2D, or 3D)."""
ndims = inputs.ndim - 2
if ndims == 1:
return adaptive_avg_pool1d(inputs, output_size, data_format)
elif ndims == 2:
return adaptive_avg_pool2d(inputs, output_size, data_format)
elif ndims == 3:
return adaptive_avg_pool3d(inputs, output_size, data_format)
else:
raise ValueError(
"adaptive_avg_pool supports 1D, 2D, or 3D inputs only."
)


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
"""Dispatcher for adaptive max pooling (1D, 2D, or 3D)."""
ndims = inputs.ndim - 2
if ndims == 1:
return adaptive_max_pool1d(inputs, output_size, data_format)
elif ndims == 2:
return adaptive_max_pool2d(inputs, output_size, data_format)
elif ndims == 3:
return adaptive_max_pool3d(inputs, output_size, data_format)
else:
raise ValueError(
"adaptive_max_pool supports 1D, 2D, or 3D inputs only."
)
16 changes: 16 additions & 0 deletions keras/src/backend/numpy/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,3 +1237,19 @@ def _pair(x):

# ---- reshape -> (N, C*kH*kW, L) ----
return patches.reshape(N, C * k[0] * k[1], -1)


def adaptive_max_pool(inputs, output_size, data_format=None):
"""Adaptive max pooling - Numpy backend not yet supported."""
raise NotImplementedError(
"Adaptive pooling not implemented for Numpy. "
"Use JAX, Torch or Tensorflow backend."
)


def adaptive_avg_pool(inputs, output_size, data_format=None):
"""Adaptive average pooling - Numpy backend not yet supported."""
raise NotImplementedError(
"Adaptive pooling not implemented for Numpy. "
"Use JAX, Torch or Tensorflow backend."
)
Loading
Loading