From 97285b4ed2bc71e26f259bb256388e98b992cfc6 Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Tue, 28 Oct 2025 01:12:00 +0530 Subject: [PATCH 01/10] Fix: added backend validation for dataset adapters across backends --- Issues/Keras_with_pytorch_backend.py | 22 +++++++++++++++++++ .../data_adapters/py_dataset_adapter.py | 18 +++++++++++++-- .../data_adapters/tf_dataset_adapter.py | 19 ++++++++++++---- .../torch_data_loader_adapter.py | 11 ++++++++++ 4 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 Issues/Keras_with_pytorch_backend.py diff --git a/Issues/Keras_with_pytorch_backend.py b/Issues/Keras_with_pytorch_backend.py new file mode 100644 index 000000000000..09cb8cf3c79b --- /dev/null +++ b/Issues/Keras_with_pytorch_backend.py @@ -0,0 +1,22 @@ +import keras +from keras import ops + +keras.config.set_backend("torch") + +def tensor_operations_example(): + # Create tensors + x = ops.array([[1, 2, 3], [4, 5, 6]]) + y = ops.array([[10, 20, 30], [40, 50, 60]]) + + # Perform elementwise operations + z_add = ops.add(x, y) + z_mul = ops.multiply(x, y) + z_mean = ops.mean(z_mul) + z_norm = ops.sqrt(ops.sum(ops.square(x))) + + print("x + y =\n", z_add) + print("x * y =\n", z_mul) + print("Mean(x * y) =", z_mean) + print("L2 norm of x =", z_norm) + +tensor_operations_example() diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py index 18865af026cf..04b2400111ab 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -12,7 +12,7 @@ from keras.src.api_export import keras_export from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.data_adapters.data_adapter import DataAdapter - +from keras import backend @keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"]) class PyDataset: @@ -94,7 +94,21 @@ def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10): self._workers = workers self._use_multiprocessing = use_multiprocessing self._max_queue_size = max_queue_size - + backend_name = backend.backend() + if backend_name not in ("torch", "jax", "tensorflow"): + raise ValueError( + f"PyDataset is only supported for PyTorch, JAX, or TensorFlow backends. " + f"Received unsupported backend: '{backend_name}'." + ) + # Optionally warn if using TF (since tf.data.Dataset is better) + if backend_name == "tensorflow": + import warnings + warnings.warn( + "You are using PyDataset with the TensorFlow backend. " + "Consider using `tf.data.Dataset` for better performance.", + stacklevel=2, + ) + def _warn_if_super_not_called(self): warn = False if not hasattr(self, "_workers"): diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter.py b/keras/src/trainers/data_adapters/tf_dataset_adapter.py index 492deb764c3e..fd9357a970b7 100644 --- a/keras/src/trainers/data_adapters/tf_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/tf_dataset_adapter.py @@ -18,14 +18,25 @@ def __init__(self, dataset, class_weight=None, distribution=None): instance. """ from keras.src.utils.module_utils import tensorflow as tf + import keras + from keras.src.utils.module_utils import tensorflow as tf - if not isinstance( - dataset, (tf.data.Dataset, tf.distribute.DistributedDataset) - ): + # --- ✅ Backend compatibility check --- + backend = keras.backend.backend() + if backend != "tensorflow": raise ValueError( - "Expected argument `dataset` to be a tf.data.Dataset. " + f"Incompatible backend '{backend}' for TFDatasetAdapter. " + "This adapter only supports the TensorFlow backend." + ) + + # --- ✅ Dataset type validation --- + if not isinstance(dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)): + raise ValueError( + "Expected argument `dataset` to be a tf.data.Dataset or " + "tf.distribute.DistributedDataset. " f"Received: {dataset}" ) + if class_weight is not None: dataset = dataset.map( make_class_weight_map_fn(class_weight) diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py index f0b2f524f4dd..f1a4e8d415b4 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -12,6 +12,17 @@ class TorchDataLoaderAdapter(DataAdapter): def __init__(self, dataloader): import torch + import keras + + # --- ✅ Backend compatibility check --- + backend = keras.backend.backend() + if backend != "torch": + raise ValueError( + f"Incompatible backend '{backend}' for TorchDataLoaderAdapter. " + "This adapter only supports the PyTorch backend. " + "If you are using TensorFlow or JAX, please use the " + "corresponding DatasetAdapter instead." + ) if not isinstance(dataloader, torch.utils.data.DataLoader): raise ValueError( From 8cd14f288bd9a0dfcbb2980d6ae8de2f5efacb61 Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 12:39:21 +0530 Subject: [PATCH 02/10] Fix: shorten ValueError line to satisfy Ruff E501 --- keras/src/trainers/data_adapters/py_dataset_adapter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py index 04b2400111ab..4f7beaa97c51 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -9,10 +9,11 @@ import numpy as np +from keras import backend from keras.src.api_export import keras_export from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.data_adapters.data_adapter import DataAdapter -from keras import backend + @keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"]) class PyDataset: @@ -97,7 +98,7 @@ def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10): backend_name = backend.backend() if backend_name not in ("torch", "jax", "tensorflow"): raise ValueError( - f"PyDataset is only supported for PyTorch, JAX, or TensorFlow backends. " + f"PyDataset supports tf,torch,jax backend" f"Received unsupported backend: '{backend_name}'." ) # Optionally warn if using TF (since tf.data.Dataset is better) From 0d0a2241b80461d502bcde13cc6bd6aa4d191c95 Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 13:28:49 +0530 Subject: [PATCH 03/10] Fix backend compatibility and clean up old issue file --- Issues/Keras_with_pytorch_backend.py | 22 ------------------- .../data_adapters/tf_dataset_adapter.py | 10 +++++---- .../torch_data_loader_adapter.py | 12 +++++----- 3 files changed, 13 insertions(+), 31 deletions(-) delete mode 100644 Issues/Keras_with_pytorch_backend.py diff --git a/Issues/Keras_with_pytorch_backend.py b/Issues/Keras_with_pytorch_backend.py deleted file mode 100644 index 09cb8cf3c79b..000000000000 --- a/Issues/Keras_with_pytorch_backend.py +++ /dev/null @@ -1,22 +0,0 @@ -import keras -from keras import ops - -keras.config.set_backend("torch") - -def tensor_operations_example(): - # Create tensors - x = ops.array([[1, 2, 3], [4, 5, 6]]) - y = ops.array([[10, 20, 30], [40, 50, 60]]) - - # Perform elementwise operations - z_add = ops.add(x, y) - z_mul = ops.multiply(x, y) - z_mean = ops.mean(z_mul) - z_norm = ops.sqrt(ops.sum(ops.square(x))) - - print("x + y =\n", z_add) - print("x * y =\n", z_mul) - print("Mean(x * y) =", z_mean) - print("L2 norm of x =", z_norm) - -tensor_operations_example() diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter.py b/keras/src/trainers/data_adapters/tf_dataset_adapter.py index fd9357a970b7..c76f77b319e4 100644 --- a/keras/src/trainers/data_adapters/tf_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/tf_dataset_adapter.py @@ -17,20 +17,22 @@ def __init__(self, dataset, class_weight=None, distribution=None): shard the input dataset into per worker/process dataset instance. """ - from keras.src.utils.module_utils import tensorflow as tf import keras from keras.src.utils.module_utils import tensorflow as tf # --- ✅ Backend compatibility check --- backend = keras.backend.backend() - if backend != "tensorflow": + if backend not in ("tensorflow","numpy","torch","jax"): raise ValueError( f"Incompatible backend '{backend}' for TFDatasetAdapter. " - "This adapter only supports the TensorFlow backend." + "This adapter only supports the TensorFlow , numpy , torch ," \ + " jax backend." ) # --- ✅ Dataset type validation --- - if not isinstance(dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)): + if not isinstance( + dataset, (tf.data.Dataset, tf.distribute.DistributedDataset) + ): raise ValueError( "Expected argument `dataset` to be a tf.data.Dataset or " "tf.distribute.DistributedDataset. " diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py index f1a4e8d415b4..79a9f158dcf1 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -1,26 +1,28 @@ import itertools import numpy as np +import torch +import keras from keras.src import tree from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + class TorchDataLoaderAdapter(DataAdapter): """Adapter that handles `torch.utils.data.DataLoader`.""" def __init__(self, dataloader): - import torch - import keras # --- ✅ Backend compatibility check --- backend = keras.backend.backend() - if backend != "torch": + if backend not in ("torch","tensorflow"): raise ValueError( f"Incompatible backend '{backend}' for TorchDataLoaderAdapter. " - "This adapter only supports the PyTorch backend. " - "If you are using TensorFlow or JAX, please use the " + "This adapter only supports the PyTorch, tensorflow backend. " + "If you are using TensorFlow, please use the " "corresponding DatasetAdapter instead." ) From f842a5f620766a03a4c258e4f7c399eb690e6880 Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 13:33:19 +0530 Subject: [PATCH 04/10] Auto-format imports and code via Ruff --- keras/src/trainers/data_adapters/py_dataset_adapter.py | 3 ++- keras/src/trainers/data_adapters/tf_dataset_adapter.py | 6 +++--- .../trainers/data_adapters/torch_data_loader_adapter.py | 7 ++----- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py index 4f7beaa97c51..bfdaa91f797f 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -104,12 +104,13 @@ def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10): # Optionally warn if using TF (since tf.data.Dataset is better) if backend_name == "tensorflow": import warnings + warnings.warn( "You are using PyDataset with the TensorFlow backend. " "Consider using `tf.data.Dataset` for better performance.", stacklevel=2, ) - + def _warn_if_super_not_called(self): warn = False if not hasattr(self, "_workers"): diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter.py b/keras/src/trainers/data_adapters/tf_dataset_adapter.py index c76f77b319e4..aed249812dd6 100644 --- a/keras/src/trainers/data_adapters/tf_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/tf_dataset_adapter.py @@ -22,17 +22,17 @@ def __init__(self, dataset, class_weight=None, distribution=None): # --- ✅ Backend compatibility check --- backend = keras.backend.backend() - if backend not in ("tensorflow","numpy","torch","jax"): + if backend not in ("tensorflow", "numpy", "torch", "jax"): raise ValueError( f"Incompatible backend '{backend}' for TFDatasetAdapter. " - "This adapter only supports the TensorFlow , numpy , torch ," \ + "This adapter only supports the TensorFlow , numpy , torch ," " jax backend." ) # --- ✅ Dataset type validation --- if not isinstance( dataset, (tf.data.Dataset, tf.distribute.DistributedDataset) - ): + ): raise ValueError( "Expected argument `dataset` to be a tf.data.Dataset or " "tf.distribute.DistributedDataset. " diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py index 79a9f158dcf1..9bf8526e2589 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -2,23 +2,20 @@ import numpy as np import torch -import keras +import keras from keras.src import tree from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.data_adapters.data_adapter import DataAdapter - - class TorchDataLoaderAdapter(DataAdapter): """Adapter that handles `torch.utils.data.DataLoader`.""" def __init__(self, dataloader): - # --- ✅ Backend compatibility check --- backend = keras.backend.backend() - if backend not in ("torch","tensorflow"): + if backend not in ("torch", "tensorflow"): raise ValueError( f"Incompatible backend '{backend}' for TorchDataLoaderAdapter. " "This adapter only supports the PyTorch, tensorflow backend. " From 17fa6959356364bfa5e623b2eee9fac0f70541a8 Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 13:43:20 +0530 Subject: [PATCH 05/10] Fixed torch import issue --- .../src/trainers/data_adapters/torch_data_loader_adapter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py index 9bf8526e2589..6586a1c41d8b 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -1,8 +1,6 @@ import itertools import numpy as np -import torch - import keras from keras.src import tree from keras.src.trainers.data_adapters import data_adapter_utils @@ -14,6 +12,9 @@ class TorchDataLoaderAdapter(DataAdapter): def __init__(self, dataloader): # --- ✅ Backend compatibility check --- + import keras + import torch + backend = keras.backend.backend() if backend not in ("torch", "tensorflow"): raise ValueError( From 69652ded4a675a63b5c9d7fa2f4bc89fe6a929ad Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 13:44:20 +0530 Subject: [PATCH 06/10] Fixed torch import issue --- .../src/trainers/data_adapters/torch_data_loader_adapter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py index 6586a1c41d8b..9512aec5c47b 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -1,7 +1,7 @@ import itertools import numpy as np -import keras + from keras.src import tree from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.data_adapters.data_adapter import DataAdapter @@ -12,9 +12,10 @@ class TorchDataLoaderAdapter(DataAdapter): def __init__(self, dataloader): # --- ✅ Backend compatibility check --- - import keras import torch + import keras + backend = keras.backend.backend() if backend not in ("torch", "tensorflow"): raise ValueError( From 70c8577eff777966d6f3516aea13ed18ed52592f Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 14:42:43 +0530 Subject: [PATCH 07/10] Fixed numpy, jax backened issue in pytorch dataset and tf dataset --- .../trainers/data_adapters/torch_data_loader_adapter.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py index 9512aec5c47b..5c12923a773b 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -15,14 +15,13 @@ def __init__(self, dataloader): import torch import keras - + backend = keras.backend.backend() - if backend not in ("torch", "tensorflow"): + if backend not in ("torch", "tensorflow","numpy","jax"): raise ValueError( f"Incompatible backend '{backend}' for TorchDataLoaderAdapter. " - "This adapter only supports the PyTorch, tensorflow backend. " - "If you are using TensorFlow, please use the " - "corresponding DatasetAdapter instead." + "This adapter only supports the PyTorch, tensorflow, jax, numpy" \ + " backend. " ) if not isinstance(dataloader, torch.utils.data.DataLoader): From 2a5f2d44ea0024ae291f79dfab82fc607b04b1a3 Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 14:48:32 +0530 Subject: [PATCH 08/10] Update API stubs after running api-gen --- keras/src/trainers/data_adapters/torch_data_loader_adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py index 5c12923a773b..2600c9b58308 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -17,10 +17,10 @@ def __init__(self, dataloader): import keras backend = keras.backend.backend() - if backend not in ("torch", "tensorflow","numpy","jax"): + if backend not in ("torch", "tensorflow", "numpy", "jax"): raise ValueError( f"Incompatible backend '{backend}' for TorchDataLoaderAdapter. " - "This adapter only supports the PyTorch, tensorflow, jax, numpy" \ + "This adapter only supports the PyTorch, tensorflow, jax, numpy" " backend. " ) From f374ea2bf29feedcdae8b2383e8e99d0e5458733 Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 14:53:08 +0530 Subject: [PATCH 09/10] Update API stubs after running api-gen, corrected numpy backend issue in pydataset --- keras/src/trainers/data_adapters/py_dataset_adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py index bfdaa91f797f..17fc2d63b6c2 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -96,9 +96,9 @@ def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10): self._use_multiprocessing = use_multiprocessing self._max_queue_size = max_queue_size backend_name = backend.backend() - if backend_name not in ("torch", "jax", "tensorflow"): + if backend_name not in ("torch", "jax", "tensorflow","numpy"): raise ValueError( - f"PyDataset supports tf,torch,jax backend" + f"PyDataset supports tf, torch, jax, numpy backend" f"Received unsupported backend: '{backend_name}'." ) # Optionally warn if using TF (since tf.data.Dataset is better) From e19485d6841abd6cab12f73ae94c8e41520cda1a Mon Sep 17 00:00:00 2001 From: Shekar77 Date: Thu, 30 Oct 2025 14:53:49 +0530 Subject: [PATCH 10/10] Update API stubs after running api-gen, corrected numpy backend issue in pydataset --- keras/src/trainers/data_adapters/py_dataset_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py index 17fc2d63b6c2..1599cf61f1b8 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -96,7 +96,7 @@ def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10): self._use_multiprocessing = use_multiprocessing self._max_queue_size = max_queue_size backend_name = backend.backend() - if backend_name not in ("torch", "jax", "tensorflow","numpy"): + if backend_name not in ("torch", "jax", "tensorflow", "numpy"): raise ValueError( f"PyDataset supports tf, torch, jax, numpy backend" f"Received unsupported backend: '{backend_name}'."