diff --git a/keras/src/layers/preprocessing/normalization.py b/keras/src/layers/preprocessing/normalization.py
index 8ea0d439b31..e39778973c0 100644
--- a/keras/src/layers/preprocessing/normalization.py
+++ b/keras/src/layers/preprocessing/normalization.py
@@ -6,7 +6,9 @@
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.layers.preprocessing.data_layer import DataLayer
+from keras.src.trainers.data_adapters import get_data_adapter
 from keras.src.utils.module_utils import tensorflow as tf
+from keras.utils import PyDataset
 
 
 @keras_export("keras.layers.Normalization")
@@ -229,6 +231,19 @@ def adapt(self, data):
                 # Batch dataset if it isn't batched
                 data = data.batch(128)
             input_shape = tuple(data.element_spec.shape)
+        elif isinstance(data, PyDataset):
+            adapter = get_data_adapter(data)
+            tf_dataset = adapter.get_tf_dataset()
+            # args will either be (samples,), (samples, labels) pairs,
+            # or (samples, labels, sample_weights) tuples
+            data = tf_dataset.map(lambda *args: args[0])
+            input_shape = data.element_spec.shape
+        else:
+            raise TypeError(
+                f"Unsupported data type: {type(data)}. `adapt` supports "
+                f"`np.ndarray`, backend tensors, `tf.data.Dataset`, and "
+                f"`keras.utils.PyDataset`."
+            )
 
         if not self.built:
             self.build(input_shape)
diff --git a/keras/src/layers/preprocessing/normalization_test.py b/keras/src/layers/preprocessing/normalization_test.py
index 70dea378700..e5cf9c48ae8 100644
--- a/keras/src/layers/preprocessing/normalization_test.py
+++ b/keras/src/layers/preprocessing/normalization_test.py
@@ -169,3 +169,35 @@ def test_normalization_with_scalar_mean_var(self):
         input_data = np.array([[1, 2, 3]], dtype="float32")
         layer = layers.Normalization(mean=3.0, variance=2.0)
         layer(input_data)
+
+    @parameterized.parameters([("x",), ("x_and_y",), ("x_y_and_weights")])
+    def test_adapt_pydataset_compat(self, pydataset_type):
+        import keras
+
+        class CustomDataset(keras.utils.PyDataset):
+            def __len__(self):
+                return 100
+
+            def __getitem__(self, idx):
+                x = np.random.rand(32, 32, 3)
+                y = np.random.randint(0, 10, size=(1,))
+                weights = np.random.randint(0, 10, size=(1,))
+                if pydataset_type == "x":
+                    return x
+                elif pydataset_type == "x_and_y":
+                    return x, y
+                elif pydataset_type == "x_y_and_weights":
+                    return x, y, weights
+                else:
+                    raise NotImplementedError(pydataset_type)
+
+        normalizer = keras.layers.Normalization()
+        normalizer.adapt(CustomDataset())
+        self.assertTrue(normalizer.built)
+        self.assertIsNotNone(normalizer.mean)
+        self.assertIsNotNone(normalizer.variance)
+        self.assertEqual(normalizer.mean.shape[-1], 3)
+        self.assertEqual(normalizer.variance.shape[-1], 3)
+        sample_input = np.random.rand(1, 32, 32, 3)
+        output = normalizer(sample_input)
+        self.assertEqual(output.shape, (1, 32, 32, 3))
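
For context, a minimal usage sketch of the code path this patch enables (not part of the change itself): `SimpleDataset`, its length, and the batch shape below are illustrative placeholders, and the `adapt()` call on a `PyDataset` only works once this patch is applied.

```python
import numpy as np
import keras


# Illustrative PyDataset; under this patch, any keras.utils.PyDataset whose
# __getitem__ returns x, (x, y), or (x, y, sample_weight) batches is accepted
# by Normalization.adapt(), which keeps only the samples (args[0]).
class SimpleDataset(keras.utils.PyDataset):
    def __len__(self):
        return 10  # number of batches

    def __getitem__(self, idx):
        # Each batch: 8 samples with 4 features (shape chosen for the example).
        return np.random.rand(8, 4).astype("float32")


layer = keras.layers.Normalization()
layer.adapt(SimpleDataset())  # computes mean/variance over the PyDataset
print(layer.mean.shape, layer.variance.shape)
```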