From 06bb3bb77d8ddcda55bd0753a6216d0a1689f5f3 Mon Sep 17 00:00:00 2001
From: Suhana
Date: Tue, 28 Oct 2025 10:03:29 +0530
Subject: [PATCH 1/6] Adding tensor layout for TP autosharding

---
 keras/src/backend/jax/core.py              |  58 ++++++-
 keras/src/backend/jax/core_test.py         |  78 +++++++++
 .../tensor_parallel/tensor_layout.py       |  43 +++++
 .../tensor_parallel/tensor_layout_test.py  | 163 ++++++++++++++++++
 4 files changed, 341 insertions(+), 1 deletion(-)
 create mode 100644 keras/src/distribution/tensor_parallel/tensor_layout.py
 create mode 100644 keras/src/distribution/tensor_parallel/tensor_layout_test.py

diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py
index 7dc5a98fb8d5..aee30a3deadd 100644
--- a/keras/src/backend/jax/core.py
+++ b/keras/src/backend/jax/core.py
@@ -1,5 +1,6 @@
 import jax
 import jax.experimental.sparse as jax_sparse
+import jax.lax as lax
 import jax.numpy as jnp
 import ml_dtypes
 import numpy as np
@@ -529,6 +530,61 @@ def remat(f):
     return jax.checkpoint(f)
 
 
+def all_reduce(x, op="sum", axis_name="model"):
+    """
+    Performs an **all-reduce** operation across all replicas in the specified
+    distribution axis.
+
+    The all-reduce operation computes a reduction (like sum, mean, or product)
+    of the input tensor `x` across all devices/replicas in the `axis_name`
+    group, and then broadcasts the result back to all participating devices.
+
+    Args:
+        x: The tensor to reduce.
+        op: The reduction operation to perform. Common options include "sum",
+            "mean", or "product". Defaults to "sum".
+        axis_name: The name of the distribution axis (e.g., "model",
+            "data") over which to perform the reduction. Defaults to "model".
+
+    Returns:
+        The result of the all-reduce operation, with the same shape as the
+        input `x`.
+    """
+    if op == "sum":
+        return lax.psum(x, axis_name=axis_name)
+    elif op == "mean":
+        return lax.pmean(x, axis_name=axis_name)
+    else:
+        raise ValueError(
+            f"Unsupported reduction operation: {op}. "
+            "Supported options are 'sum' and 'mean'."
+        )
+
+
+def all_gather(x, axis, axis_name="model"):
+    """
+    Performs an all-gather operation across all replicas in the specified
+    distribution axis.
+
+    The all-gather operation collects the input tensor `x` from all devices
+    in the `axis_name` group and concatenates them along the specified `axis`.
+    This is often used in tensor parallelism to combine parts of a tensor
+    distributed across devices.
+
+    Args:
+        x: The tensor to gather.
+        axis: The dimension along which to concatenate the gathered tensors.
+        axis_name: The name of the distribution axis (e.g., "model",
+            "data") over which to perform the gather.
+            Defaults to "model".
+
+    Returns:
+        The gathered tensor, which will have a larger size along `axis`
+        dimension.
+ """ + return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True) + + class name_scope(base_name_scope): def __init__(self, name, **kwargs): super().__init__(name, **kwargs) @@ -571,4 +627,4 @@ def device_scope(device_name): ) else: jax_device = device_name - return jax.default_device(jax_device) + return jax.default_device(jax_device) \ No newline at end of file diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index 792cf25e67f0..79eecad18063 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -1,3 +1,4 @@ +import functools import os import jax @@ -9,6 +10,8 @@ from keras.src import backend from keras.src import testing from keras.src.backend.config import is_nnx_enabled +from keras.src.backend.jax.core import all_gather +from keras.src.backend.jax.core import all_reduce if is_nnx_enabled(): from flax import nnx @@ -66,3 +69,78 @@ def test_keras_variable_nnx_split_merge_sync(self): state = jax.tree.map(lambda x: x + 1, state) variable2 = nnx.merge(graphdef, state) self.assertEqual(variable2._value, variable2.value) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="JAX backend specific test for collective operations.", +) +@pytest.mark.skipif( + jax.local_device_count() < 2, + reason="Requires multiple local devices for testing.", +) +class JaxCollectiveOpsTest(testing.TestCase): + def test_all_reduce_sum(self): + """Tests the all_reduce operation with the 'sum' reduction.""" + num_devices = jax.local_device_count() + local_value = 10.0 + + local_inputs = jax.numpy.array([local_value] * num_devices) + + @functools.partial( + jax.pmap, axis_name="all", devices=jax.devices("cpu") + ) + def reduce_sum_fn(x): + return all_reduce(x, op="sum", axis_name="all") + + result = reduce_sum_fn(local_inputs) + expected_sum = local_value * num_devices + + self.assertTrue(np.allclose(result, expected_sum)) + self.assertEqual(result.shape, (num_devices,)) + + def test_all_reduce_mean(self): + """Tests the all_reduce operation with the 'mean' reduction.""" + num_devices = jax.local_device_count() + local_value = 10.0 + + local_inputs = jax.numpy.array([local_value] * num_devices) + + @functools.partial( + jax.pmap, axis_name="all", devices=jax.devices("cpu") + ) + def reduce_mean_fn(x): + return all_reduce(x, op="mean", axis_name="all") + + result = reduce_mean_fn(local_inputs) + expected_mean = local_value + + self.assertTrue(np.allclose(result, expected_mean)) + self.assertEqual(result.shape, (num_devices,)) + + def test_all_gather(self): + """Tests the all_gather operation.""" + num_devices = jax.local_device_count() + local_data = np.arange(5) + + local_inputs = jax.numpy.stack( + [local_data + (i * 5) for i in range(num_devices)] + ) + + @functools.partial( + jax.pmap, axis_name="all", devices=jax.devices("cpu") + ) + def gather_fn(x): + return all_gather(x, axis=0, axis_name="all") + + result_array_on_devices = gather_fn(local_inputs) + + expected_shape = (num_devices, num_devices * local_data.shape[0]) + self.assertEqual(result_array_on_devices.shape, expected_shape) + + expected_gathered_data = np.arange(num_devices * local_data.shape[0]) + + for i in range(num_devices): + self.assertTrue( + np.allclose(result_array_on_devices[i], expected_gathered_data) + ) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py new file mode 100644 index 000000000000..ff6b4eff920b --- /dev/null +++ 
+++ b/keras/src/distribution/tensor_parallel/tensor_layout.py
@@ -0,0 +1,43 @@
+import collections
+
+from keras.src import ops
+
+
+def split_tensor_for_parallelism(tensor, index, device_count, dim):
+    """Calculates a slice of a tensor along a specified dimension for a
+    given index.
+
+    This utility is used in tensor parallelism API to distribute a
+    tensor across multiple devices.
+
+    Args:
+        tensor: The full tensor to be sharded.
+        index: The index of the device/shard to return (e.g., 0, 1, 2...).
+        device_count: The total number of parallel devices or splits.
+        dim: The dimension along which to split the tensor. If -1, the
+            last dimension is used.
+
+    Returns:
+        A tensor slice corresponding to the given `index`.
+    """
+    if dim == -1:
+        static_shape = getattr(tensor, "shape", None)
+        if static_shape is not None:
+            rank = len(static_shape)
+        else:
+            rank = None
+
+        if rank is not None:
+            split_dim = rank - 1
+        else:
+            split_dim = ops.ndim(tensor) - 1
+    else:
+        split_dim = dim
+
+    splits = ops.array_split(
+        tensor, indices_or_sections=device_count, axis=split_dim
+    )
+    return splits[index]
+
+
+LayoutMap = collections.namedtuple("LayoutMap", ["state_rules", "output_rules"])
\ No newline at end of file
diff --git a/keras/src/distribution/tensor_parallel/tensor_layout_test.py b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
new file mode 100644
index 000000000000..d30f6a1b4495
--- /dev/null
+++ b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
@@ -0,0 +1,163 @@
+from keras.src import ops
+from keras.src import testing
+from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
+from keras.src.distribution.tensor_parallel.tensor_layout import (
+    split_tensor_for_parallelism,
+)
+
+
+class LayoutTest(testing.TestCase):
+    """Test suite for tensor layout actions and mappings."""
+
+    def test_split_with_even_division(self):
+        """Tests splitting a tensor that divides evenly among workers."""
+        device_count = 4
+        dim = 0
+        tensor = ops.reshape(ops.arange(16, dtype="float32"), (8, 2))
+
+        expected_shard_0 = ops.array([[0.0, 1.0], [2.0, 3.0]])
+        expected_shard_2 = ops.array([[8.0, 9.0], [10.0, 11.0]])
+
+        shard_0 = split_tensor_for_parallelism(
+            tensor, index=0, device_count=device_count, dim=dim
+        )
+        shard_2 = split_tensor_for_parallelism(
+            tensor, index=2, device_count=device_count, dim=dim
+        )
+
+        self.assertAllClose(shard_0, expected_shard_0)
+        self.assertAllClose(shard_2, expected_shard_2)
+        self.assertEqual(shard_0.shape, (2, 2))
+
+    def test_split_with_uneven_division(self):
+        """Tests splitting tensor where remainder is distributed correctly."""
+        device_count = 3
+        dim = 0
+        tensor = ops.reshape(ops.arange(10, dtype="float32"), (10, 1))
+
+        shard_0 = split_tensor_for_parallelism(
+            tensor, index=0, device_count=device_count, dim=dim
+        )
+        self.assertEqual(shard_0.shape, (4, 1))
+        self.assertAllClose(shard_0, ops.array([[0.0], [1.0], [2.0], [3.0]]))
+
+        shard_1 = split_tensor_for_parallelism(
+            tensor, index=1, device_count=device_count, dim=dim
+        )
+        self.assertEqual(shard_1.shape, (3, 1))
+        self.assertAllClose(shard_1, ops.array([[4.0], [5.0], [6.0]]))
+
+        shard_2 = split_tensor_for_parallelism(
+            tensor, index=2, device_count=device_count, dim=dim
+        )
+        self.assertEqual(shard_2.shape, (3, 1))
+        self.assertAllClose(shard_2, ops.array([[7.0], [8.0], [9.0]]))
+
+    def test_split_and_undo_cycle_even_removed(self):
+        """
+        Confirms that the original tensor can be reconstructed.
+ """ + device_count = 2 + dim = 0 + original_tensor = ops.reshape(ops.arange(12, dtype="float32"), (6, 2)) + + shards = [ + split_tensor_for_parallelism( + original_tensor, index=i, device_count=device_count, dim=dim + ) + for i in range(device_count) + ] + + reconstructed_tensor = ops.concatenate(shards, axis=dim) + + self.assertAllClose(original_tensor, reconstructed_tensor) + + def test_split_and_undo_cycle_uneven_removed(self): + """ + Confirms that original tensor can be reconstructed with uneven split. + """ + device_count = 4 + dim = 0 + original_tensor = ops.reshape(ops.arange(22, dtype="float32"), (11, 2)) + + shards = [ + split_tensor_for_parallelism( + original_tensor, index=i, device_count=device_count, dim=dim + ) + for i in range(device_count) + ] + + self.assertEqual(shards[0].shape, (3, 2)) + self.assertEqual(shards[1].shape, (3, 2)) + self.assertEqual(shards[2].shape, (3, 2)) + self.assertEqual(shards[3].shape, (2, 2)) + + reconstructed_tensor = ops.concatenate(shards, axis=dim) + self.assertAllClose(original_tensor, reconstructed_tensor) + + def test_split_last_dimension(self): + """Tests splitting on the last dimension using dim=-1.""" + device_count = 3 + dim = -1 + original_tensor = ops.reshape( + ops.arange(30, dtype="float32"), (2, 5, 3) + ) + + shards = [ + split_tensor_for_parallelism( + original_tensor, index=i, device_count=device_count, dim=dim + ) + for i in range(device_count) + ] + + self.assertEqual(shards[0].shape, (2, 5, 1)) + self.assertEqual(shards[1].shape, (2, 5, 1)) + self.assertEqual(shards[2].shape, (2, 5, 1)) + + def test_split_with_sharding_type_hint(self): + """Tests using 'row' and 'column' sharding hints for 2D tensors.""" + device_count = 2 + tensor = ops.reshape(ops.arange(16, dtype="float32"), (4, 4)) + + row_dim = 0 + shard_row_0 = split_tensor_for_parallelism( + tensor, index=0, device_count=device_count, dim=row_dim + ) + self.assertAllClose(shard_row_0, tensor[:2, :]) + + col_dim = 1 + shard_col_0 = split_tensor_for_parallelism( + tensor, index=0, device_count=device_count, dim=col_dim + ) + self.assertAllClose(shard_col_0, tensor[:, :2]) + + def test_layout_map_namedtuple_behavior(self): + """Tests basic behavior of the LayoutMap namedtuple.""" + + def rule_kernel(tensor, index): + return split_tensor_for_parallelism( + tensor, index=index, device_count=2, dim=0 + ) + + def rule_output(tensor, index): + return split_tensor_for_parallelism( + tensor, index=index, device_count=2, dim=-1 + ) + + state_rules = {"kernel": rule_kernel} + output_rules = {"output": rule_output} + + layout_map = LayoutMap( + state_rules=state_rules, output_rules=output_rules + ) + + self.assertIs(layout_map.state_rules, state_rules) + self.assertIs(layout_map.output_rules, output_rules) + + self.assertIs(layout_map[0], state_rules) + self.assertIs(layout_map[1], output_rules) + + with self.assertRaises(AttributeError): + layout_map.state_rules = {} + + self.assertTrue(callable(layout_map.state_rules["kernel"])) \ No newline at end of file From 41f80258302f813be32ef3b947203ba0c4f777cf Mon Sep 17 00:00:00 2001 From: Suhana Date: Tue, 28 Oct 2025 10:08:30 +0530 Subject: [PATCH 2/6] formatting files --- keras/src/backend/jax/core.py | 2 +- keras/src/backend/jax/core_test.py | 2 +- keras/src/distribution/tensor_parallel/tensor_layout.py | 2 +- keras/src/distribution/tensor_parallel/tensor_layout_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index aee30a3deadd..f55fd23e502d 
From 41f80258302f813be32ef3b947203ba0c4f777cf Mon Sep 17 00:00:00 2001
From: Suhana
Date: Tue, 28 Oct 2025 10:08:30 +0530
Subject: [PATCH 2/6] formatting files

---
 keras/src/backend/jax/core.py                                | 2 +-
 keras/src/backend/jax/core_test.py                           | 2 +-
 keras/src/distribution/tensor_parallel/tensor_layout.py      | 2 +-
 keras/src/distribution/tensor_parallel/tensor_layout_test.py | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py
index aee30a3deadd..f55fd23e502d 100644
--- a/keras/src/backend/jax/core.py
+++ b/keras/src/backend/jax/core.py
@@ -627,4 +627,4 @@ def device_scope(device_name):
         )
     else:
         jax_device = device_name
-    return jax.default_device(jax_device)
\ No newline at end of file
+    return jax.default_device(jax_device)
diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py
index 79eecad18063..2e7c312aa33e 100644
--- a/keras/src/backend/jax/core_test.py
+++ b/keras/src/backend/jax/core_test.py
@@ -143,4 +143,4 @@ def gather_fn(x):
         for i in range(num_devices):
             self.assertTrue(
                 np.allclose(result_array_on_devices[i], expected_gathered_data)
-            )
\ No newline at end of file
+            )
diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py
index ff6b4eff920b..00f766434b34 100644
--- a/keras/src/distribution/tensor_parallel/tensor_layout.py
+++ b/keras/src/distribution/tensor_parallel/tensor_layout.py
@@ -40,4 +40,4 @@ def split_tensor_for_parallelism(tensor, index, device_count, dim):
     return splits[index]
 
 
-LayoutMap = collections.namedtuple("LayoutMap", ["state_rules", "output_rules"])
\ No newline at end of file
+LayoutMap = collections.namedtuple("LayoutMap", ["state_rules", "output_rules"])
diff --git a/keras/src/distribution/tensor_parallel/tensor_layout_test.py b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
index d30f6a1b4495..7a8f3b61d8e4 100644
--- a/keras/src/distribution/tensor_parallel/tensor_layout_test.py
+++ b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
@@ -160,4 +160,4 @@ def rule_output(tensor, index):
         with self.assertRaises(AttributeError):
             layout_map.state_rules = {}
 
-        self.assertTrue(callable(layout_map.state_rules["kernel"]))
\ No newline at end of file
+        self.assertTrue(callable(layout_map.state_rules["kernel"]))

From e74eab2a8a68b562f4ee65d01dcdba86446a35ee Mon Sep 17 00:00:00 2001
From: Suhana
Date: Tue, 28 Oct 2025 10:41:52 +0530
Subject: [PATCH 3/6] Updating the docstring

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 keras/src/backend/jax/core.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py
index f55fd23e502d..d8d2db89135b 100644
--- a/keras/src/backend/jax/core.py
+++ b/keras/src/backend/jax/core.py
@@ -535,14 +535,14 @@ def all_reduce(x, op="sum", axis_name="model"):
     Performs an **all-reduce** operation across all replicas in the specified
     distribution axis.
 
-    The all-reduce operation computes a reduction (like sum, mean, or product)
+    The all-reduce operation computes a reduction (like sum or mean)
     of the input tensor `x` across all devices/replicas in the `axis_name`
     group, and then broadcasts the result back to all participating devices.
 
     Args:
         x: The tensor to reduce.
-        op: The reduction operation to perform. Common options include "sum",
-            "mean", or "product". Defaults to "sum".
+        op: The reduction operation to perform. Common options include "sum"
+            and "mean". Defaults to "sum".
         axis_name: The name of the distribution axis (e.g., "model",
             "data") over which to perform the reduction. Defaults to "model".
 
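
For reference, a minimal sketch of driving the collectives added in
PATCH 1/6 (and documented above) from `jax.pmap`, along the lines of the
tests in core_test.py. It assumes the patches are applied and that at
least two local devices are visible, e.g. via
XLA_FLAGS=--xla_force_host_platform_device_count=2 on CPU.

    import functools

    import jax
    import jax.numpy as jnp

    from keras.src.backend.jax.core import all_gather
    from keras.src.backend.jax.core import all_reduce

    n = jax.local_device_count()
    x = jnp.ones((n, 4))  # one (4,) slice per device

    @functools.partial(jax.pmap, axis_name="model")
    def step(v):
        # psum keeps the per-device shape; all_gather concatenates along axis 0.
        total = all_reduce(v, op="sum", axis_name="model")
        merged = all_gather(v, axis=0, axis_name="model")
        return total, merged

    total, merged = step(x)
    # total.shape == (n, 4); merged.shape == (n, n * 4)
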
From 2cddf39134ad4acdc73deb483a202807bbc89c77 Mon Sep 17 00:00:00 2001
From: Suhana
Date: Tue, 28 Oct 2025 10:53:12 +0530
Subject: [PATCH 4/6] refactoring the code

---
 .../src/distribution/tensor_parallel/tensor_layout.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py
index 00f766434b34..5635d7de2df6 100644
--- a/keras/src/distribution/tensor_parallel/tensor_layout.py
+++ b/keras/src/distribution/tensor_parallel/tensor_layout.py
@@ -21,16 +21,7 @@ def split_tensor_for_parallelism(tensor, index, device_count, dim):
         A tensor slice corresponding to the given `index`.
     """
     if dim == -1:
-        static_shape = getattr(tensor, "shape", None)
-        if static_shape is not None:
-            rank = len(static_shape)
-        else:
-            rank = None
-
-        if rank is not None:
-            split_dim = rank - 1
-        else:
-            split_dim = ops.ndim(tensor) - 1
+        split_dim = ops.ndim(tensor) - 1
     else:
         split_dim = dim
 

From 5365f1483f3932c6586f9a69baef586a67dfc3da Mon Sep 17 00:00:00 2001
From: Suhana
Date: Thu, 6 Nov 2025 13:45:42 +0530
Subject: [PATCH 5/6] fixing test

---
 .../src/distribution/tensor_parallel/tensor_layout_test.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/keras/src/distribution/tensor_parallel/tensor_layout_test.py b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
index 7a8f3b61d8e4..9ba09d904b34 100644
--- a/keras/src/distribution/tensor_parallel/tensor_layout_test.py
+++ b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
@@ -96,9 +96,11 @@ def test_split_and_undo_cycle_uneven_removed(self):
         self.assertAllClose(original_tensor, reconstructed_tensor)
 
     def test_split_last_dimension(self):
-        """Tests splitting on the last dimension using dim=-1."""
+        """Tests splitting on the last dimension."""
         device_count = 3
-        dim = -1
+        # Change dim from -1 to 2 (the explicit index of the last dimension)
+        # to avoid backend-specific issues with dynamic shape resolution.
+        dim = 2
         original_tensor = ops.reshape(
             ops.arange(30, dtype="float32"), (2, 5, 3)
         )

From bc4d09461d0abcc85fb4c705e36fbd306277cd8f Mon Sep 17 00:00:00 2001
From: Suhana
Date: Thu, 6 Nov 2025 13:46:06 +0530
Subject: [PATCH 6/6] fixing test

---
 keras/src/distribution/tensor_parallel/tensor_layout_test.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/keras/src/distribution/tensor_parallel/tensor_layout_test.py b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
index 9ba09d904b34..72b21b4912aa 100644
--- a/keras/src/distribution/tensor_parallel/tensor_layout_test.py
+++ b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
@@ -98,8 +98,6 @@ def test_split_last_dimension(self):
         """Tests splitting on the last dimension."""
         device_count = 3
-        # Change dim from -1 to 2 (the explicit index of the last dimension)
-        # to avoid backend-specific issues with dynamic shape resolution.
         dim = 2
         original_tensor = ops.reshape(
             ops.arange(30, dtype="float32"), (2, 5, 3)