From 6a9aa5f0e93f2f8435ad28f9c7ab5e42f9254bb2 Mon Sep 17 00:00:00 2001
From: "daniel.cahall"
Date: Sun, 2 Nov 2025 21:20:23 -0500
Subject: [PATCH 1/4] support pydataset in `adapt` for norm layers

---
 .../src/layers/preprocessing/normalization.py | 20 +++++++++++++++
 .../preprocessing/normalization_test.py       | 25 +++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/keras/src/layers/preprocessing/normalization.py b/keras/src/layers/preprocessing/normalization.py
index 8ea0d439b31b..1ad6ddf360cb 100644
--- a/keras/src/layers/preprocessing/normalization.py
+++ b/keras/src/layers/preprocessing/normalization.py
@@ -6,7 +6,9 @@
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.layers.preprocessing.data_layer import DataLayer
+from keras.src.trainers.data_adapters import get_data_adapter
 from keras.src.utils.module_utils import tensorflow as tf
+from keras.utils import PyDataset
 
 
 @keras_export("keras.layers.Normalization")
@@ -229,6 +231,24 @@ def adapt(self, data):
                 # Batch dataset if it isn't batched
                 data = data.batch(128)
             input_shape = tuple(data.element_spec.shape)
+        elif isinstance(data, PyDataset):
+            # as PyDatasets returns tuples of input/annotation pairs
+            adapter = get_data_adapter(data)
+            tf_dataset = adapter.get_tf_dataset()
+            if len(tf_dataset.element_spec) == 1:
+                # just x
+                data = tf_dataset.map(lambda x: x)
+                input_shape = data.element_spec.shape
+            elif len(tf_dataset.element_spec) == 2:
+                # (x, y) pairs
+                data = tf_dataset.map(lambda x, y: x)
+                input_shape = data.element_spec.shape
+            elif len(tf_dataset.element_spec) == 3:
+                # (x, y, sample_weight) tuples
+                data = tf_dataset.map(lambda x, y, z: x)
+                input_shape = data.element_spec.shape
+        else:
+            raise NotImplementedError(f"Unsupported data type: {type(data)}")
 
         if not self.built:
             self.build(input_shape)
diff --git a/keras/src/layers/preprocessing/normalization_test.py b/keras/src/layers/preprocessing/normalization_test.py
index 70dea3787002..09a6c5a6ab18 100644
--- a/keras/src/layers/preprocessing/normalization_test.py
+++ b/keras/src/layers/preprocessing/normalization_test.py
@@ -169,3 +169,28 @@ def test_normalization_with_scalar_mean_var(self):
         input_data = np.array([[1, 2, 3]], dtype="float32")
         layer = layers.Normalization(mean=3.0, variance=2.0)
         layer(input_data)
+
+    @parameterized.parameters([("x",), ("x_and_y",), ("x_y_and_weights")])
+    def test_adapt_pydataset_compat(self, pydataset_type):
+        import keras
+
+        class CustomDataset(keras.utils.PyDataset):
+            def __len__(self):
+                return 100
+
+            def __getitem__(self, idx):
+                x = np.random.rand(32, 32, 3)
+                y = np.random.randint(0, 10, size=(1,))
+                weights = np.random.randint(0, 10, size=(1,))
+                if pydataset_type == "x":
+                    return x
+                elif pydataset_type == "x_and_y":
+                    return x, y
+                elif pydataset_type == "x_y_and_weights":
+                    return x, y, weights
+                else:
+                    raise NotImplementedError(pydataset_type)
+
+        normalizer = keras.layers.Normalization()
+        normalizer.adapt(CustomDataset())
+        self.assertTrue(normalizer.built)

From 944082fa313380036d90e80c4615c68b8eef25e8 Mon Sep 17 00:00:00 2001
From: "daniel.cahall"
Date: Sun, 2 Nov 2025 21:25:35 -0500
Subject: [PATCH 2/4] address gemini comments around type error and duplication

---
 keras/src/layers/preprocessing/normalization.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/keras/src/layers/preprocessing/normalization.py b/keras/src/layers/preprocessing/normalization.py
index 1ad6ddf360cb..cfedf49391e3 100644
--- a/keras/src/layers/preprocessing/normalization.py
+++ b/keras/src/layers/preprocessing/normalization.py
@@ -238,17 +238,19 @@ def adapt(self, data):
             if len(tf_dataset.element_spec) == 1:
                 # just x
                 data = tf_dataset.map(lambda x: x)
-                input_shape = data.element_spec.shape
             elif len(tf_dataset.element_spec) == 2:
                 # (x, y) pairs
                 data = tf_dataset.map(lambda x, y: x)
-                input_shape = data.element_spec.shape
             elif len(tf_dataset.element_spec) == 3:
                 # (x, y, sample_weight) tuples
                 data = tf_dataset.map(lambda x, y, z: x)
-                input_shape = data.element_spec.shape
+            input_shape = data.element_spec.shape
         else:
-            raise NotImplementedError(f"Unsupported data type: {type(data)}")
+            raise TypeError(
+                f"Unsupported data type: {type(data)}. `adapt` supports "
+                f"`np.ndarray`, backend tensors, `tf.data.Dataset`, and "
+                f"`keras.utils.PyDataset`."
+            )
 
         if not self.built:
             self.build(input_shape)

From ac460d9b43d23cfb00cb896a76bfc1dc2c8ec1c4 Mon Sep 17 00:00:00 2001
From: "daniel.cahall"
Date: Sun, 2 Nov 2025 21:38:31 -0500
Subject: [PATCH 3/4] add some slightly more robust checks

---
 keras/src/layers/preprocessing/normalization_test.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/keras/src/layers/preprocessing/normalization_test.py b/keras/src/layers/preprocessing/normalization_test.py
index 09a6c5a6ab18..e5cf9c48ae8a 100644
--- a/keras/src/layers/preprocessing/normalization_test.py
+++ b/keras/src/layers/preprocessing/normalization_test.py
@@ -194,3 +194,10 @@ def __getitem__(self, idx):
         normalizer = keras.layers.Normalization()
         normalizer.adapt(CustomDataset())
         self.assertTrue(normalizer.built)
+        self.assertIsNotNone(normalizer.mean)
+        self.assertIsNotNone(normalizer.variance)
+        self.assertEqual(normalizer.mean.shape[-1], 3)
+        self.assertEqual(normalizer.variance.shape[-1], 3)
+        sample_input = np.random.rand(1, 32, 32, 3)
+        output = normalizer(sample_input)
+        self.assertEqual(output.shape, (1, 32, 32, 3))

From 87db5b335512338a30b763bfa5e43a5c560151b0 Mon Sep 17 00:00:00 2001
From: "daniel.cahall"
Date: Sun, 9 Nov 2025 11:55:41 -0500
Subject: [PATCH 4/4] simplify logic for pydataset support

---
 keras/src/layers/preprocessing/normalization.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/keras/src/layers/preprocessing/normalization.py b/keras/src/layers/preprocessing/normalization.py
index cfedf49391e3..e39778973c0e 100644
--- a/keras/src/layers/preprocessing/normalization.py
+++ b/keras/src/layers/preprocessing/normalization.py
@@ -232,18 +232,11 @@ def adapt(self, data):
             data = data.batch(128)
             input_shape = tuple(data.element_spec.shape)
         elif isinstance(data, PyDataset):
-            # as PyDatasets returns tuples of input/annotation pairs
             adapter = get_data_adapter(data)
             tf_dataset = adapter.get_tf_dataset()
-            if len(tf_dataset.element_spec) == 1:
-                # just x
-                data = tf_dataset.map(lambda x: x)
-            elif len(tf_dataset.element_spec) == 2:
-                # (x, y) pairs
-                data = tf_dataset.map(lambda x, y: x)
-            elif len(tf_dataset.element_spec) == 3:
-                # (x, y, sample_weight) tuples
-                data = tf_dataset.map(lambda x, y, z: x)
+            # args will either be (samples,), (samples, labels) pairs,
+            # or (samples, labels, sample_weights) tuples
+            data = tf_dataset.map(lambda *args: args[0])
            input_shape = data.element_spec.shape
         else:
             raise TypeError(
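
Usage sketch: a minimal example of the behavior this series enables, adapting a
`Normalization` layer directly on a `keras.utils.PyDataset`. The `DemoDataset`
name, the shapes, and the batch count below are illustrative assumptions, not
code taken from the patches.

    import numpy as np
    import keras

    class DemoDataset(keras.utils.PyDataset):
        # Yields (samples, labels) pairs, as a typical data generator would.
        def __len__(self):
            return 10  # number of batches

        def __getitem__(self, idx):
            x = np.random.rand(32, 32, 3).astype("float32")
            y = np.random.randint(0, 10, size=(1,))
            return x, y

    normalizer = keras.layers.Normalization()
    # With the series applied, adapt() maps each element through
    # `lambda *args: args[0]`, so only the samples are consumed and any
    # labels or sample weights are ignored.
    normalizer.adapt(DemoDataset())
    print(normalizer.mean.shape, normalizer.variance.shape)

The single `lambda *args: args[0]` map from PATCH 4/4 handles all three element
structures, (samples,), (samples, labels), and (samples, labels, sample_weights),
without the per-arity branches introduced in PATCH 1/4.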