From e8f1a5afe4ccfa7e328338b82f190ed76f0be053 Mon Sep 17 00:00:00 2001 From: Suhana Date: Tue, 7 Oct 2025 10:08:36 +0530 Subject: [PATCH 1/8] Adding parameter sharding and test --- .../parameter_shardimg_test.py | 141 +++++ .../tensor_parallel/parameter_sharding.py | 514 ++++++++++++++++++ 2 files changed, 655 insertions(+) create mode 100644 keras/src/distribution/tensor_parallel/parameter_shardimg_test.py create mode 100644 keras/src/distribution/tensor_parallel/parameter_sharding.py diff --git a/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py b/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py new file mode 100644 index 000000000000..348350299865 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py @@ -0,0 +1,141 @@ +import os + +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" +import re + +import numpy as np +import pytest + +import keras +from keras import distribution +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.parameter_sharding import ( + ShardedWeight, +) +from keras.src.distribution.tensor_parallel.parameter_sharding import ( + make_parameter_sharded_model, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +from keras.src.testing import TestCase + + +@pytest.mark.skipif( + keras.backend.backend() != "jax", + reason="This test is JAX-specific.", +) +def _create_simple_mlp(): + """Creates a simple, unsharded Keras MLP model for testing.""" + inputs = keras.Input(shape=(16,), name="input") + x = keras.layers.Dense(32, use_bias=True, name="up_proj")(inputs) + x = keras.layers.Activation("relu")(x) + outputs = keras.layers.Dense(8, use_bias=False, name="down_proj")(x) + return keras.Model(inputs=inputs, outputs=outputs, name="simple_mlp") + + +class ParameterShardingTest(TestCase): + def setUp(self): + super().setUp() + import logging + + logging.getLogger().setLevel(logging.ERROR) + + self.world_size = 2 + all_devices = distribution.list_devices() + self.devices = all_devices[: self.world_size] + if len(self.devices) < self.world_size: + self.skipTest( + f"""Not enough devices to run TP test. 
+ Found {len(self.devices)}, need {self.world_size}""" + ) + + self.original_model = _create_simple_mlp() + self.original_model.build(input_shape=(None, 16)) + + self.tp_config = ConfigKeras( + state_rules={ + re.escape("simple_mlp.up_proj.kernel"): SplitKeras( + self.world_size, dim=1 + ), + re.escape("simple_mlp.down_proj.kernel"): SplitKeras( + self.world_size, dim=0 + ), + }, + output_rules={}, + ) + self.input_data = np.random.rand(4, 16).astype("float32") + self.labels = np.random.rand(4, 8).astype("float32") + + def test_model_sharding_creation_and_weight_counts(self): + sharded_models = [] + for rank in range(self.world_size): + with keras.device(self.devices[rank]): + sharded_model, modified_params = make_parameter_sharded_model( + self.original_model, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + self.assertIsInstance(sharded_model, keras.Model) + self.assertIn("simple_mlp.up_proj.kernel", modified_params) + self.assertIn("simple_mlp.down_proj.kernel", modified_params) + sharded_models.append(sharded_model) + self.assertEqual( + len(self.original_model.weights), len(sharded_models[0].weights) + ) + + def test_sharded_weight_shapes(self): + rank = 0 + with keras.device(self.devices[rank]): + sharded_model, _ = make_parameter_sharded_model( + self.original_model, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + original_weights_dict = {w.path: w for w in self.original_model.weights} + sharded_weights_dict = { + w.name if isinstance(w, ShardedWeight) else w.path: w + for w in sharded_model.weights + } + orig_up_kernel = original_weights_dict["up_proj/kernel"] + shard_up_kernel = sharded_weights_dict["simple_mlp.up_proj.kernel"] + self.assertEqual(shard_up_kernel.shape[0], orig_up_kernel.shape[0]) + self.assertEqual( + shard_up_kernel.shape[1], + orig_up_kernel.shape[1] // self.world_size, + ) + orig_down_kernel = original_weights_dict["down_proj/kernel"] + shard_down_kernel = sharded_weights_dict["simple_mlp.down_proj.kernel"] + self.assertEqual( + shard_down_kernel.shape[0], + orig_down_kernel.shape[0] // self.world_size, + ) + self.assertEqual(shard_down_kernel.shape[1], orig_down_kernel.shape[1]) + + def test_forward_pass_correctness(self): + expected_output = self.original_model(self.input_data) + sharded_outputs = [] + original_weights = self.original_model.get_weights() + for rank in range(self.world_size): + with keras.device(self.devices[rank]): + cloned_original = keras.models.clone_model(self.original_model) + cloned_original.set_weights(original_weights) + sharded_model, _ = make_parameter_sharded_model( + cloned_original, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + output = sharded_model(self.input_data) + sharded_outputs.append(output) + reconstructed_output = ( + keras.ops.sum(keras.ops.stack(sharded_outputs), axis=0) + / self.world_size + ) + + self.assertAllClose( + expected_output, reconstructed_output, atol=1e-5, rtol=1e-5 + ) diff --git a/keras/src/distribution/tensor_parallel/parameter_sharding.py b/keras/src/distribution/tensor_parallel/parameter_sharding.py new file mode 100644 index 000000000000..500843c0ec1d --- /dev/null +++ b/keras/src/distribution/tensor_parallel/parameter_sharding.py @@ -0,0 +1,514 @@ +import re +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np + +from 
keras.src.distribution.tensor_parallel.communications import ( + TensorParallelCommunicator, +) +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.state_action_keras import ( + StateActionKeras, +) + + +class ShardedWeight: + def __init__(self, tensor_shard, name, trainable=True): + import keras + + self._variable = keras.Variable( + initializer=tensor_shard, trainable=trainable, name=name + ) + self.regularizer = None + + @property + def name(self): + """Returns the name of the underlying variable.""" + return self._variable.name + + @property + def trainable(self): + """Returns whether the variable is trainable.""" + return self._variable.trainable + + @property + def shape(self): + """Returns the shape of the variable.""" + return self._variable.shape + + @property + def dtype(self): + """Returns the dtype of the underlying variable.""" + return self._variable.dtype + + @property + def variable(self): + """Provides direct access to the underlying tf.Variable.""" + return self._variable + + def numpy(self): + """Returns the value of the variable as a NumPy array.""" + return self._variable.numpy() + + def num_elements(self): + """Returns the total number of elements in the tensor.""" + import keras + + return keras.ops.size(self._variable) + + def __repr__(self): + return ( + f"" + ) + + +class ParameterShardingStrategy: + """ + Parameter-level sharding strategy that works with any Keras model. + Instead of rebuilding the model, we shard only the weights and handle + communication during forward/backward passes. + """ + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + self.sharded_weights = {} + self.original_weights = {} + self.weight_mapping = {} + self.sharded_weights_by_id = {} + + def shard_model_parameters( + self, + model, + config: ConfigKeras, + communicator: TensorParallelCommunicator, + device_id: Any, + ): + """ + Shard model parameters without rebuilding the model structure. 
+ """ + ParameterShardedModel = _define_parameter_sharded_model() + + self._store_original_weights(model) + modified_parameters = set() + + for pattern, action in config.state_rules.items(): + if isinstance(action, StateActionKeras): + matching_params = self._find_matching_parameters(model, pattern) + + for param_name, param in matching_params: + try: + param_id = id(param.experimental_ref()) + except AttributeError: + param_id = id(param) + + if param_id in self.sharded_weights_by_id: + self.sharded_weights[param_name] = ( + self.sharded_weights_by_id[param_id] + ) + + existing_param_name = "unknown" + for name, shard in self.sharded_weights.items(): + if shard is self.sharded_weights_by_id[param_id]: + existing_param_name = name + break + + self.weight_mapping[param_name] = self.weight_mapping[ + existing_param_name + ] + modified_parameters.add(param_name) + continue + + sharded_param = action(param, self.rank) + + self.sharded_weights[param_name] = sharded_param + self.sharded_weights_by_id[param_id] = sharded_param + + self.weight_mapping[param_name] = { + "original_shape": param.shape, + "sharded_shape": sharded_param.shape, + "action": action, + } + + modified_parameters.add(param_name) + + sharded_model = ParameterShardedModel( + original_model=model, + sharding_strategy=self, + communicator=communicator, + config=config, + device_id=device_id, + ) + + return sharded_model, modified_parameters + + def _store_original_weights(self, model): + """Store original weights for reference.""" + from keras.src import layers + + def find_weights_recursive( + current_layer: layers.Layer, prefix: str = "" + ): + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if hasattr(current_layer, "weights") and current_layer.weights: + for weight in current_layer.weights: + cleaned_name = weight.name.split("/")[-1].split(":")[0] + param_name = f"{full_name}.{cleaned_name}" + self.original_weights[param_name] = weight.numpy() + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + find_weights_recursive(sub_layer, full_name) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and attr_name.endswith("__"): + continue + try: + attr = getattr(current_layer, attr_name) + except Exception: + continue + if isinstance(attr, layers.Layer) and attr is not current_layer: + find_weights_recursive(attr, full_name) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + find_weights_recursive(item, full_name) + + for layer in model.layers: + find_weights_recursive(layer, prefix="") + + def _find_matching_parameters( + self, model, pattern: str + ) -> List[Tuple[str, Any]]: + """ + Find parameters that match the given pattern using smart recursion. 
+ """ + from keras.src import layers + + matching_params = [] + processed_layers = set() + + def search_layer_recursive( + current_layer: layers.Layer, prefix: str = "" + ): + if id(current_layer) in processed_layers: + return + processed_layers.add(id(current_layer)) + + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if hasattr(current_layer, "weights") and current_layer.weights: + for weight in current_layer.weights: + cleaned_weight_name = weight.name.split("/")[-1].split(":")[ + 0 + ] + param_name = f"{full_name}.{cleaned_weight_name}" + + if re.match(pattern, param_name): + matching_params.append((param_name, weight)) + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + search_layer_recursive(sub_layer, full_name) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and attr_name.endswith("__"): + continue + + try: + attr = getattr(current_layer, attr_name) + except Exception: + continue + + if isinstance(attr, layers.Layer) and attr is not current_layer: + search_layer_recursive(attr, full_name) + + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + search_layer_recursive(item, full_name) + + search_layer_recursive(model, prefix="") + + return matching_params + + def get_sharded_weight(self, param_name: str) -> Optional[np.ndarray]: + """Get sharded weight for a parameter.""" + if param_name in self.sharded_weights: + return self.sharded_weights[param_name].numpy() + return None + + def get_weight_info(self, param_name: str) -> Optional[Dict]: + """Get information about a sharded weight.""" + return self.weight_mapping.get(param_name) + + +def _define_parameter_sharded_model(): + """ + Factory function to define and return the ParameterShardedModel class. + """ + from keras.src.models import Model + + class ParameterShardedModel(Model): + def __init__( + self, + original_model: Model, + sharding_strategy: ParameterShardingStrategy, + communicator: TensorParallelCommunicator, + config: ConfigKeras, + device_id: Any, + ): + super().__init__() + + self.original_model = original_model + self.sharding_strategy = sharding_strategy + self.config = config + self.communicator = communicator + self._device = device_id + + self._build_and_cache_weights() + + if original_model.inputs: + self.build(original_model.inputs[0].shape) + + @property + def device(self): + return self._device + + def train_step(self, data): + """ + Custom train_step for the parameter-sharded model. + + This override includes a gradient synchronization (all-reduce) step, + which is essential for the backward pass in tensor parallelism. 
+ """ + import tensorflow as tf + + import keras + + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) + + with tf.GradientTape() as tape: + y_pred = self(x, training=True) + loss = self.compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight + ) + + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + + synced_gradients = self.communicator.all_reduce( + gradients, op="sum", axis_name="model" + ) + self.optimizer.apply_gradients( + zip(synced_gradients, trainable_vars) + ) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + + return {m.name: m.result() for m in self.metrics} + + def _build_and_cache_weights(self): + weights_list = [] + + sharded_weight_ids = set( + self.sharding_strategy.sharded_weights_by_id.keys() + ) + + for ( + param_name, + sharded_tensor, + ) in self.sharding_strategy.sharded_weights.items(): + weights_list.append(ShardedWeight(sharded_tensor, param_name)) + + unsharded_count = 0 + for weight in self.original_model.weights: + try: + weight_id = id(weight.experimental_ref()) + except AttributeError: + weight_id = id(weight) + + if weight_id not in sharded_weight_ids: + weights_list.append(weight) + unsharded_count += 1 + + self._weights_list = weights_list + + @property + def weights(self): + return self._weights_list + + def call(self, inputs, training=None, mask=None): + from keras.src import layers + + tensor_cache = {} + current_tensor = inputs + + for layer in self.original_model.layers: + if isinstance(layer, layers.InputLayer): + continue + + if isinstance(layer, layers.Add): + try: + if "feedforward_output" in layer.name: + residual_source_name = layer.name.replace( + "feedforward_output", "self_attention_output" + ) + elif "self_attention_output" in layer.name: + residual_source_name = layer.name.replace( + "self_attention_output", "input_layer_norm" + ) + else: + residual_source_name = None + + if ( + residual_source_name + and residual_source_name in tensor_cache + ): + layer_inputs = [ + current_tensor, + tensor_cache[residual_source_name], + ] + else: + layer_inputs = [current_tensor, current_tensor] + except Exception: + layer_inputs = [current_tensor, current_tensor] + else: + layer_inputs = current_tensor + + if ( + "attention_output" in layer.name + or "feedforward_output" in layer.name + ): + tensor_cache[layer.name] = current_tensor + + current_tensor = layer(layer_inputs, training=training) + + layer_path = layer.path + + output_rule = None + for pattern, rule in self.config.output_rules.items(): + if re.search(pattern, layer_path): + output_rule = rule.get(0) + break + + if output_rule: + current_tensor = self._apply_communication( + current_tensor, layer.name, output_rule + ) + + return current_tensor + + def _apply_communication(self, sharded_output, layer_name, rule): + """Applies communication using the high-level communicator.""" + op_name = str(rule).lower() + + if "sum" in op_name or "allreduce" in op_name: + return self.communicator.forward_row_parallel( + sharded_output, op="sum", axis_name="model" + ) + + elif "gather" in op_name: + try: + dim = int(op_name.split(" ")[-1]) + except (ValueError, IndexError): + dim = -1 + return self.communicator.forward_column_parallel( + sharded_output, dim=dim, axis_name="model" + ) + + elif hasattr(rule, "__call__"): + return rule(sharded_output) + + else: + return sharded_output + + def get_config(self): + """Get model configuration.""" + return self.original_model.get_config() + + @classmethod + def from_config(cls, 
config, custom_objects=None): + """Create model from config.""" + return cls(**config) + + return ParameterShardedModel + + +def make_parameter_sharded_model( + module, config: ConfigKeras, rank: int, world_size: int, device_id: Any +): + """ + Create a parameter-sharded version of a Keras model. + """ + communicator = TensorParallelCommunicator(world_size=world_size, rank=rank) + sharding_strategy = ParameterShardingStrategy(world_size, rank) + + sharded_model, modified_parameters = ( + sharding_strategy.shard_model_parameters( + module, config, communicator, device_id + ) + ) + + return sharded_model, modified_parameters + + +def apply_parameter_sharding_to_existing_model( + model, config: ConfigKeras, rank: int, world_size: int +): + """ + Apply parameter sharding to an existing model without creating a new one. + """ + + sharding_strategy = ParameterShardingStrategy(world_size, rank) + for pattern, action in config.state_rules.items(): + if isinstance(action, StateActionKeras): + matching_params = sharding_strategy._find_matching_parameters( + model, pattern + ) + + for param_name, param in matching_params: + try: + param_id = id(param.experimental_ref()) + except AttributeError: + param_id = id(param) + + if param_id in sharding_strategy.sharded_weights_by_id: + sharding_strategy.sharded_weights[param_name] = ( + sharding_strategy.sharded_weights_by_id[param_id] + ) + existing_param_name = next( + k + for k, v in sharding_strategy.sharded_weights.items() + if v + is sharding_strategy.sharded_weights_by_id[param_id] + ) + sharding_strategy.weight_mapping[param_name] = ( + sharding_strategy.weight_mapping[existing_param_name] + ) + continue + + sharded_param = action(param, rank) + + sharding_strategy.sharded_weights[param_name] = sharded_param + sharding_strategy.sharded_weights_by_id[param_id] = ( + sharded_param + ) + + sharding_strategy.weight_mapping[param_name] = { + "original_shape": param.shape, + "sharded_shape": sharded_param.shape, + "action": action, + } + + model._tensor_parallel_sharding = sharding_strategy + + return model From c01ec8a81c0c28bacbe9054591bf699b20050f18 Mon Sep 17 00:00:00 2001 From: Suhana Date: Tue, 7 Oct 2025 10:12:12 +0530 Subject: [PATCH 2/8] Added docstrings to parameter_shardimg --- .../tensor_parallel/parameter_sharding.py | 231 +++++++++++++++--- 1 file changed, 193 insertions(+), 38 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/parameter_sharding.py b/keras/src/distribution/tensor_parallel/parameter_sharding.py index 500843c0ec1d..30a16e9c63fe 100644 --- a/keras/src/distribution/tensor_parallel/parameter_sharding.py +++ b/keras/src/distribution/tensor_parallel/parameter_sharding.py @@ -17,6 +17,18 @@ class ShardedWeight: + """A wrapper class for a sharded Keras Variable. + + This class holds a shard of a model weight as a `keras.Variable` and + provides an interface similar to the original variable, allowing it to be + seamlessly integrated into the Keras ecosystem. + + Args: + tensor_shard: The tensor slice (shard) of the weight. + name (str): The name for the underlying `keras.Variable`. + trainable (bool): Whether the variable is trainable. 
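+
+    Example (illustrative sketch; the shard value and variable name below
+    are hypothetical):
+
+        import numpy as np
+
+        shard = ShardedWeight(np.zeros((16, 16)), name="up_proj_kernel")
+        shard.shape      # local shard shape, e.g. (16, 16)
+        shard.trainable  # True by default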
+ """ + def __init__(self, tensor_shard, name, trainable=True): import keras @@ -26,41 +38,42 @@ def __init__(self, tensor_shard, name, trainable=True): self.regularizer = None @property - def name(self): + def name(self) -> str: """Returns the name of the underlying variable.""" return self._variable.name @property - def trainable(self): + def trainable(self) -> bool: """Returns whether the variable is trainable.""" return self._variable.trainable @property - def shape(self): + def shape(self) -> Tuple[int, ...]: """Returns the shape of the variable.""" return self._variable.shape @property - def dtype(self): + def dtype(self) -> any: """Returns the dtype of the underlying variable.""" return self._variable.dtype @property def variable(self): - """Provides direct access to the underlying tf.Variable.""" + """Provides direct access to the underlying `keras.Variable`.""" return self._variable - def numpy(self): + def numpy(self) -> np.ndarray: """Returns the value of the variable as a NumPy array.""" return self._variable.numpy() - def num_elements(self): + def num_elements(self) -> int: """Returns the total number of elements in the tensor.""" import keras return keras.ops.size(self._variable) - def __repr__(self): + def __repr__(self) -> str: + """Provides a developer-friendly string representation.""" return ( f"" @@ -68,19 +81,26 @@ def __repr__(self): class ParameterShardingStrategy: - """ - Parameter-level sharding strategy that works with any Keras model. - Instead of rebuilding the model, we shard only the weights and handle - communication during forward/backward passes. + """Manages the sharding of model parameters for tensor parallelism. + + This strategy identifies weights in a Keras model based on configuration + rules, shards them, and stores the sharded weights and metadata. It's + designed to modify a model's parameters without altering its architecture. + + Args: + world_size (int): The total number of devices (workers) in the + parallel computation group. + rank (int): The unique identifier for the current device (worker), + from 0 to `world_size - 1`. """ def __init__(self, world_size: int, rank: int): self.world_size = world_size self.rank = rank - self.sharded_weights = {} - self.original_weights = {} - self.weight_mapping = {} - self.sharded_weights_by_id = {} + self.sharded_weights = {} # Maps param name to its sharded tensor + self.original_weights = {} # Stores a copy of original weights + self.weight_mapping = {} # Maps param name to sharding info + self.sharded_weights_by_id = {} # Maps original weight ID to shard def shard_model_parameters( self, @@ -88,9 +108,26 @@ def shard_model_parameters( config: ConfigKeras, communicator: TensorParallelCommunicator, device_id: Any, - ): - """ - Shard model parameters without rebuilding the model structure. + ) -> Tuple[Any, set]: + """Shards model parameters and wraps the model for tensor parallelism. + + This method iterates through the model's parameters, applies sharding + rules defined in the config, and creates a `ParameterShardedModel` + which handles the forward pass with necessary communication primitives. + + Args: + model: The original Keras model to be sharded. + config (ConfigKeras): The configuration object containing sharding + rules (`state_rules` and `output_rules`). + communicator (TensorParallelCommunicator): The communicator for + handling cross-device data transfer (e.g., all-gather). + device_id (Any): The device identifier where the model will run. 
+ + Returns: + A tuple containing: + - ParameterShardedModel: The new model wrapped for tensor + parallelism. + - set: A set of names of the parameters that were sharded. """ ParameterShardedModel = _define_parameter_sharded_model() @@ -148,12 +185,13 @@ def shard_model_parameters( return sharded_model, modified_parameters def _store_original_weights(self, model): - """Store original weights for reference.""" + """Recursively traverses the model and stores original weights.""" from keras.src import layers def find_weights_recursive( current_layer: layers.Layer, prefix: str = "" ): + """Helper to recursively find and store weights.""" name = current_layer.name full_name = f"{prefix}.{name}" if prefix else name @@ -187,8 +225,19 @@ def find_weights_recursive( def _find_matching_parameters( self, model, pattern: str ) -> List[Tuple[str, Any]]: - """ - Find parameters that match the given pattern using smart recursion. + """Finds model parameters whose names match a given regex pattern. + + This method recursively searches through the model's layers and + sub-layers to find all weights, then filters them based on the pattern. + + Args: + model: The Keras model to search within. + pattern (str): A regular expression to match against parameter + names. + + Returns: + A list of tuples, where each tuple contains the parameter's full + name and the parameter object itself. """ from keras.src import layers @@ -198,6 +247,7 @@ def _find_matching_parameters( def search_layer_recursive( current_layer: layers.Layer, prefix: str = "" ): + """Helper to recursively find matching parameters.""" if id(current_layer) in processed_layers: return processed_layers.add(id(current_layer)) @@ -241,23 +291,61 @@ def search_layer_recursive( return matching_params def get_sharded_weight(self, param_name: str) -> Optional[np.ndarray]: - """Get sharded weight for a parameter.""" + """Retrieves the sharded weight for a given parameter name. + + Args: + param_name (str): The name of the parameter. + + Returns: + The sharded weight as a NumPy array if it exists, otherwise None. + """ if param_name in self.sharded_weights: return self.sharded_weights[param_name].numpy() return None def get_weight_info(self, param_name: str) -> Optional[Dict]: - """Get information about a sharded weight.""" + """Retrieves sharding information for a specific parameter. + + Args: + param_name (str): The name of the parameter. + + Returns: + A dictionary containing metadata about the sharding (e.g., + original shape, sharded shape, action) if it exists, + otherwise None. + """ return self.weight_mapping.get(param_name) def _define_parameter_sharded_model(): - """ - Factory function to define and return the ParameterShardedModel class. + """Factory function to define and return the ParameterShardedModel class. + + This approach encapsulates the class definition and avoids potential + circular dependencies, while also keeping the related logic organized. + + Returns: + The `ParameterShardedModel` class. """ from keras.src.models import Model class ParameterShardedModel(Model): + """A Keras Model wrapper for executing a parameter-sharded model. + + This model overrides the `call` and `train_step` methods to inject + the necessary communication operations (e.g., all-reduce, all-gather) + for tensor parallelism during the forward and backward passes. + + Args: + original_model (Model): The original, non-sharded Keras model. + sharding_strategy (ParameterShardingStrategy): The strategy + instance that holds the sharded weights and metadata. 
+ communicator (TensorParallelCommunicator): The object responsible + for inter-device communication. + config (ConfigKeras): The configuration with sharding and + communication rules. + device_id (Any): The identifier of the device this model runs on. + """ + def __init__( self, original_model: Model, @@ -281,14 +369,22 @@ def __init__( @property def device(self): + """Returns the device identifier for this model instance.""" return self._device def train_step(self, data): - """ - Custom train_step for the parameter-sharded model. + """Custom training step for the parameter-sharded model. - This override includes a gradient synchronization (all-reduce) step, - which is essential for the backward pass in tensor parallelism. + This method performs a standard forward and backward pass but + adds a crucial gradient synchronization step (`all_reduce`) before + applying gradients. This ensures that each device updates its + local weight shards using gradients computed from all devices. + + Args: + data: A tuple of (x, y, sample_weight) as passed by `fit()`. + + Returns: + A dictionary mapping metric names to their current values. """ import tensorflow as tf @@ -317,6 +413,12 @@ def train_step(self, data): return {m.name: m.result() for m in self.metrics} def _build_and_cache_weights(self): + """Constructs a unified list of weights for the model. + + This list includes the custom `ShardedWeight` objects for parameters + that were sharded, and the original `keras.Variable` objects for + those that were not. + """ weights_list = [] sharded_weight_ids = set( @@ -344,9 +446,25 @@ def _build_and_cache_weights(self): @property def weights(self): + """Returns the combined list of sharded and non-sharded weights.""" return self._weights_list def call(self, inputs, training=None, mask=None): + """Defines the forward pass of the model. + + This method executes the layers of the original model sequentially. + After each layer's execution, it checks if an output communication + rule applies (e.g., for row-parallel or column-parallel layers) + and triggers the corresponding communication operation. + + Args: + inputs: Input tensor(s). + training (bool): Indicates if the model is in training mode. + mask: A mask or list of masks. + + Returns: + The output tensor of the model. + """ from keras.src import layers tensor_cache = {} @@ -408,7 +526,16 @@ def call(self, inputs, training=None, mask=None): return current_tensor def _apply_communication(self, sharded_output, layer_name, rule): - """Applies communication using the high-level communicator.""" + """Applies a communication primitive based on a rule. + + Args: + sharded_output: The output tensor from a layer. + layer_name (str): The name of the layer. + rule: The communication rule from the config. + + Returns: + The tensor after the communication operation has been applied. 
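+
+        For example, a rule whose string form contains `"sum"` or
+        `"allreduce"` triggers a row-parallel all-reduce, while a rule such
+        as `"gather -1"` triggers a column-parallel all-gather along the
+        last dimension; callable rules are applied directly to the output.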
+ """ op_name = str(rule).lower() if "sum" in op_name or "allreduce" in op_name: @@ -432,12 +559,12 @@ def _apply_communication(self, sharded_output, layer_name, rule): return sharded_output def get_config(self): - """Get model configuration.""" + """Serializes the model's configuration.""" return self.original_model.get_config() @classmethod def from_config(cls, config, custom_objects=None): - """Create model from config.""" + """Creates a model from its configuration.""" return cls(**config) return ParameterShardedModel @@ -445,9 +572,23 @@ def from_config(cls, config, custom_objects=None): def make_parameter_sharded_model( module, config: ConfigKeras, rank: int, world_size: int, device_id: Any -): - """ - Create a parameter-sharded version of a Keras model. +) -> Tuple[Any, set]: + """Creates a parameter-sharded version of a Keras model. + + This is a high-level factory function that orchestrates the creation of + the communicator, the sharding strategy, and the final sharded model. + + Args: + module: The Keras model to be sharded. + config (ConfigKeras): Configuration object with sharding rules. + rank (int): The rank of the current process. + world_size (int): The total number of processes. + device_id (Any): The device on which the model will be placed. + + Returns: + A tuple containing: + - The newly created `ParameterShardedModel`. + - A set of names for the parameters that were modified. """ communicator = TensorParallelCommunicator(world_size=world_size, rank=rank) sharding_strategy = ParameterShardingStrategy(world_size, rank) @@ -464,8 +605,22 @@ def make_parameter_sharded_model( def apply_parameter_sharding_to_existing_model( model, config: ConfigKeras, rank: int, world_size: int ): - """ - Apply parameter sharding to an existing model without creating a new one. + """Applies parameter sharding directly to an existing model instance. + + This function modifies a model in-place. Instead of returning a new + wrapped model, it shards the weights and attaches the sharding strategy + to the original model object. This is useful when the model's execution + logic is handled externally. + + Args: + model: The Keras model to modify. + config (ConfigKeras): Configuration object with sharding rules. + rank (int): The rank of the current process. + world_size (int): The total number of processes. + + Returns: + The modified model with an attached `_tensor_parallel_sharding` + strategy attribute. """ sharding_strategy = ParameterShardingStrategy(world_size, rank) From c22facf65d05ff52fee4a69982d3f92719292c1f Mon Sep 17 00:00:00 2001 From: Suhana Date: Tue, 7 Oct 2025 10:15:29 +0530 Subject: [PATCH 3/8] Added docstrings --- .../parameter_shardimg_test.py | 60 +++++++++++++++++-- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py b/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py index 348350299865..681d6724c325 100644 --- a/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py +++ b/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py @@ -24,7 +24,15 @@ reason="This test is JAX-specific.", ) def _create_simple_mlp(): - """Creates a simple, unsharded Keras MLP model for testing.""" + """Creates a simple, unsharded Keras MLP model for testing. + + This model serves as the baseline for sharding tests. It consists of + an input layer, a hidden dense layer with a ReLU activation, and an + output dense layer. + + Returns: + A `keras.Model` instance. 
+ """ inputs = keras.Input(shape=(16,), name="input") x = keras.layers.Dense(32, use_bias=True, name="up_proj")(inputs) x = keras.layers.Activation("relu")(x) @@ -33,39 +41,58 @@ def _create_simple_mlp(): class ParameterShardingTest(TestCase): + """Test suite for parameter sharding functionality. + + This class tests the creation of sharded models, the correctness of + sharded weight shapes, and the numerical accuracy of the forward pass + of a sharded model compared to its original, unsharded counterpart. + """ + def setUp(self): + """Sets up the testing environment before each test case.""" super().setUp() - import logging - - logging.getLogger().setLevel(logging.ERROR) self.world_size = 2 all_devices = distribution.list_devices() self.devices = all_devices[: self.world_size] if len(self.devices) < self.world_size: self.skipTest( - f"""Not enough devices to run TP test. + f"""Not enough devices to run TP test. Found {len(self.devices)}, need {self.world_size}""" ) + # Create the original model and the sharding configuration. self.original_model = _create_simple_mlp() self.original_model.build(input_shape=(None, 16)) self.tp_config = ConfigKeras( state_rules={ + # Rule to split the first dense layer's kernel along the output + # dimension (column-wise). re.escape("simple_mlp.up_proj.kernel"): SplitKeras( self.world_size, dim=1 ), + # Rule to split the second dense layer's kernel along the input + # dimension (row-wise). re.escape("simple_mlp.down_proj.kernel"): SplitKeras( self.world_size, dim=0 ), }, output_rules={}, ) + # Generate dummy data for testing forward passes. self.input_data = np.random.rand(4, 16).astype("float32") self.labels = np.random.rand(4, 8).astype("float32") def test_model_sharding_creation_and_weight_counts(self): + """Tests if sharded models are created correctly. + + Verifies that: + 1. `make_parameter_sharded_model` returns a valid Keras model. + 2. The set of modified parameters correctly identifies sharded layers. + 3. The total number of weights in the sharded model matches the + original model, ensuring no weights are lost or added. + """ sharded_models = [] for rank in range(self.world_size): with keras.device(self.devices[rank]): @@ -80,11 +107,19 @@ def test_model_sharding_creation_and_weight_counts(self): self.assertIn("simple_mlp.up_proj.kernel", modified_params) self.assertIn("simple_mlp.down_proj.kernel", modified_params) sharded_models.append(sharded_model) + + # The sharded model should have the same number of weight objects. self.assertEqual( len(self.original_model.weights), len(sharded_models[0].weights) ) def test_sharded_weight_shapes(self): + """Validates the shapes of the weights after sharding. + + This test ensures that the dimensions specified in the sharding rules + are correctly split by the world size, while other dimensions remain + unchanged. + """ rank = 0 with keras.device(self.devices[rank]): sharded_model, _ = make_parameter_sharded_model( @@ -94,11 +129,14 @@ def test_sharded_weight_shapes(self): world_size=self.world_size, device_id=self.devices[rank], ) + original_weights_dict = {w.path: w for w in self.original_model.weights} sharded_weights_dict = { w.name if isinstance(w, ShardedWeight) else w.path: w for w in sharded_model.weights } + + # Check the shape of the column-split kernel. 
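+        # With world_size=2, the original (16, 32) up_proj kernel is expected
+        # to be split column-wise into two (16, 16) shards, one per device.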
orig_up_kernel = original_weights_dict["up_proj/kernel"] shard_up_kernel = sharded_weights_dict["simple_mlp.up_proj.kernel"] self.assertEqual(shard_up_kernel.shape[0], orig_up_kernel.shape[0]) @@ -106,6 +144,8 @@ def test_sharded_weight_shapes(self): shard_up_kernel.shape[1], orig_up_kernel.shape[1] // self.world_size, ) + + # Check the shape of the row-split kernel. orig_down_kernel = original_weights_dict["down_proj/kernel"] shard_down_kernel = sharded_weights_dict["simple_mlp.down_proj.kernel"] self.assertEqual( @@ -115,13 +155,22 @@ def test_sharded_weight_shapes(self): self.assertEqual(shard_down_kernel.shape[1], orig_down_kernel.shape[1]) def test_forward_pass_correctness(self): + """Checks if the sharded model's output matches the original. + + This test performs a forward pass on both the original model and the + sharded models. It then reconstructs the output from the sharded + models and asserts that it is numerically close to the original + model's output. This serves as an end-to-end correctness check. + """ expected_output = self.original_model(self.input_data) sharded_outputs = [] original_weights = self.original_model.get_weights() + for rank in range(self.world_size): with keras.device(self.devices[rank]): cloned_original = keras.models.clone_model(self.original_model) cloned_original.set_weights(original_weights) + sharded_model, _ = make_parameter_sharded_model( cloned_original, self.tp_config, @@ -131,6 +180,7 @@ def test_forward_pass_correctness(self): ) output = sharded_model(self.input_data) sharded_outputs.append(output) + reconstructed_output = ( keras.ops.sum(keras.ops.stack(sharded_outputs), axis=0) / self.world_size From 276530a4f61c50c246403d8d378e24dee7b2a8c3 Mon Sep 17 00:00:00 2001 From: Suhana Date: Tue, 7 Oct 2025 10:56:24 +0530 Subject: [PATCH 4/8] Torch backend added --- keras/src/backend/__init__.py | 4 + keras/src/backend/torch/__init__.py | 1 + .../src/backend/torch/distributed_backend.py | 257 ++++++++++++++++++ .../backend/torch/distributed_backend_test.py | 133 +++++++++ 4 files changed, 395 insertions(+) create mode 100644 keras/src/backend/torch/distributed_backend.py create mode 100644 keras/src/backend/torch/distributed_backend_test.py diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..fe393cf08abd 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -37,6 +37,8 @@ if backend() == "tensorflow": from keras.src.backend.tensorflow import * # noqa: F403 from keras.src.backend.tensorflow.core import Variable as BackendVariable + + distributed_backend = None elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable @@ -50,11 +52,13 @@ from keras.src.backend.numpy.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None elif backend() == "openvino": from keras.src.backend.openvino import * # noqa: F403 from keras.src.backend.openvino.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras/src/backend/torch/__init__.py b/keras/src/backend/torch/__init__.py index 371a62cd0f52..c8095d01654e 100644 --- a/keras/src/backend/torch/__init__.py +++ b/keras/src/backend/torch/__init__.py @@ -16,6 +16,7 @@ from keras.src.backend.common.name_scope import name_scope from keras.src.backend.torch import core +from keras.src.backend.torch import 
distributed_backend from keras.src.backend.torch import image from keras.src.backend.torch import linalg from keras.src.backend.torch import math diff --git a/keras/src/backend/torch/distributed_backend.py b/keras/src/backend/torch/distributed_backend.py new file mode 100644 index 000000000000..16a5ced75d8e --- /dev/null +++ b/keras/src/backend/torch/distributed_backend.py @@ -0,0 +1,257 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Literal + +import torch +import torch.distributed as dist + + +def compute_gradients( + loss: torch.Tensor, trainable_vars: List[torch.Tensor] +) -> List[torch.Tensor]: + """Computes gradients of the loss with respect to trainable variables. + + This function leverages PyTorch's `autograd.grad` for a stateless, + functional approach similar to `jax.grad`. + + Args: + loss (torch.Tensor): The loss value for which to compute gradients. + trainable_vars (List[torch.Tensor]): A list of variables (tensors with + `requires_grad=True`) to compute gradients with respect to. + + Returns: + List[torch.Tensor]: A list of gradients corresponding to the + trainable variables. + """ + return list(torch.autograd.grad(loss, trainable_vars)) + + +def apply_gradients( + gradients: List[torch.Tensor], + trainable_vars: List[torch.Tensor], + learning_rate: float = 0.001, +) -> List[torch.Tensor]: + """Applies gradients and returns the updated variables. + + Updates are performed in-place within a `torch.no_grad()` context + to prevent the update operation from being part of the computation graph. + """ + with torch.no_grad(): + updated_vars = [] + for grad, var in zip(gradients, trainable_vars): + if grad is not None: + var.sub_(learning_rate * grad) + updated_vars.append(var) + return updated_vars + + +def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]: + """Creates a configuration dictionary for a PyTorch optimizer. + + This function returns a dictionary containing the optimizer's configuration, + maintaining a consistent interface with the JAX backend. The user is + expected to instantiate the optimizer from this config. + + Args: + optimizer_class (str): The name of the optimizer to create (e.g., + `"adam"`, `"sgd"`). + **kwargs: Keyword arguments for the optimizer (e.g., `learning_rate`). + + Returns: + Dict[str, Any]: A dictionary representing the optimizer configuration. + """ + config = kwargs.copy() + config["name"] = optimizer_class.lower() + config.setdefault("learning_rate", 0.001) + return config + + +def get_device_info() -> Dict[str, Any]: + """Retrieves information about the available PyTorch devices. + + Returns: + Dict[str, Any]: A dictionary containing the backend name, a list of + available device strings, and the total device count. + """ + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + devices = [torch.cuda.get_device_name(i) for i in range(device_count)] + else: + device_count = 1 + devices = ["cpu"] + return { + "backend": "pytorch", + "devices": devices, + "device_count": device_count, + } + + +def is_multi_device_capable() -> bool: + """Checks if more than one CUDA device is available. + + Returns: + bool: `True` if PyTorch reports more than one CUDA device, `False` + otherwise. + """ + return torch.cuda.device_count() > 1 + + +def get_communication_ops() -> Dict[str, Callable]: + """Provides a dictionary of PyTorch collective communication operations. + + These operations rely on the `torch.distributed` package. 
They are + designed to work in a multi-process, multi-device environment. If the + distributed package is not initialized, they provide a sensible fallback + for single-device execution. + + Returns: + Dict[str, Callable]: A dictionary mapping operation names to their + PyTorch implementations. + """ + + def _is_distributed() -> bool: + """Checks if the default process group is initialized.""" + return dist.is_available() and dist.is_initialized() + + def all_reduce( + x: torch.Tensor, + op: Literal["sum", "mean"] = "sum", + ) -> torch.Tensor: + """Reduces a tensor across all devices. + + Args: + x (torch.Tensor): The tensor to reduce. + op (Literal["sum", "mean"], optional): The reduction operation. + Defaults to "sum". + + Returns: + torch.Tensor: The reduced tensor. + """ + if not _is_distributed(): + world_size = ( + torch.cuda.device_count() if torch.cuda.is_available() else 1 + ) + if world_size <= 1: + return x + if op == "sum": + return x * float(world_size) + elif op == "mean": + return x + else: + raise ValueError(f"Unsupported all_reduce op: {op}") + + reduce_op = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.AVG}.get( + op + ) + if reduce_op is None: + raise ValueError(f"Unsupported all_reduce op: {op}") + + result = x.clone() + dist.all_reduce(result, op=reduce_op) + return result + + def all_gather(x: torch.Tensor, axis: int = 0) -> torch.Tensor: + """Gathers tensors from all devices and concatenates them. + + Args: + x (torch.Tensor): The local tensor to gather. + axis (int, optional): The axis along which to concatenate. + Defaults to 0. + + Returns: + torch.Tensor: The concatenated tensor from all devices. + """ + if not _is_distributed(): + world_size = ( + torch.cuda.device_count() if torch.cuda.is_available() else 1 + ) + if world_size <= 1: + return x + return torch.cat([x] * world_size, dim=axis) + + world_size = dist.get_world_size() + tensor_list = [torch.empty_like(x) for _ in range(world_size)] + dist.all_gather(tensor_list, x) + return torch.cat(tensor_list, dim=axis) + + def broadcast(x: torch.Tensor, root: int = 0) -> torch.Tensor: + """Broadcasts a tensor from a root device to all other devices. + + Args: + x (torch.Tensor): The tensor to broadcast. + root (int, optional): The rank of the source device. Defaults to 0. + + Returns: + torch.Tensor: The tensor received from the root device. + """ + if not _is_distributed(): + return x + + # `dist.broadcast` is in-place. + dist.broadcast(x, src=root) + return x + + def scatter( + x: torch.Tensor, + root: int = 0, + axis: int = 0, + ) -> torch.Tensor: + """Scatters a tensor from a root device to all devices. + + Note: The current implementation of `dist.scatter` requires the input + tensor `x` to be organized differently for the root process. This + wrapper simplifies it by handling the splitting automatically on the + root process. + + Args: + x (torch.Tensor): The tensor on the root device to be scattered. + root (int, optional): The rank of the device holding the tensor. + Defaults to 0. + axis (int, optional): The axis along which to split the tensor. + Defaults to 0. + + Returns: + torch.Tensor: The chunk of the tensor for the local device. + """ + if not _is_distributed(): + world_size = ( + torch.cuda.device_count() if torch.cuda.is_available() else 1 + ) + if world_size <= 1: + return x + if x.shape[axis] % world_size != 0: + raise ValueError( + f"Tensor with shape {x.shape} cannot be scattered along " + f"axis {axis} across {world_size} devices." 
+ ) + return torch.chunk(x, world_size, dim=axis)[0] + + world_size = dist.get_world_size() + rank = dist.get_rank() + + if x.shape[axis] % world_size != 0: + raise ValueError( + f"Tensor with shape {x.shape} cannot be scattered along " + f"axis {axis} across {world_size} devices." + ) + + if rank == root: + scatter_list = list(torch.chunk(x, world_size, dim=axis)) + else: + scatter_list = None + + chunk_shape = list(x.shape) + chunk_shape[axis] //= world_size + local_chunk = torch.empty(chunk_shape, dtype=x.dtype, device=x.device) + + dist.scatter(local_chunk, scatter_list, src=root) + return local_chunk + + return { + "all_reduce": all_reduce, + "all_gather": all_gather, + "broadcast": broadcast, + "scatter": scatter, + } diff --git a/keras/src/backend/torch/distributed_backend_test.py b/keras/src/backend/torch/distributed_backend_test.py new file mode 100644 index 000000000000..9aaf68f8a9f9 --- /dev/null +++ b/keras/src/backend/torch/distributed_backend_test.py @@ -0,0 +1,133 @@ +import os + +os.environ["JAX_PLATFORM_NAME"] = "cpu" + +import pytest +import torch + +from keras.src import backend +from keras.src.backend import distributed_backend + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="Jax Backend specific test", +) +class TestPytorchDistributedFunctions: + """Unit tests for the PyTorch distributed backend standalone functions.""" + + def test_compute_gradients_computes_correctly(self): + """Test that compute_gradients returns correct gradients.""" + w = torch.tensor([2.0, 3.0], requires_grad=True) + b = torch.tensor(1.0, requires_grad=True) + x = torch.tensor([4.0, 5.0]) + y_true = torch.tensor(25.0) + + # loss = (w.x + b - y_true)^2 = ((2*4 + 3*5 + 1) - 25)^2 = (24-25)^2 = 1 + y_pred = torch.dot(w, x) + b + loss = (y_pred - y_true) ** 2 + + trainable_vars = [w, b] + gradients = distributed_backend.compute_gradients(loss, trainable_vars) + + # d_loss/d_w = 2*(y_pred - y_true)*x = 2*(-1)*[4, 5] = [-8, -10] + # d_loss/d_b = 2*(y_pred - y_true)*1 = 2*(-1)*1 = -2 + expected_grad_w = torch.tensor([-8.0, -10.0]) + expected_grad_b = torch.tensor(-2.0) + + assert len(gradients) == 2 + torch.testing.assert_close(gradients[0], expected_grad_w) + torch.testing.assert_close(gradients[1], expected_grad_b) + + def test_apply_gradients(self): + """Test the application of gradients to PyTorch tensors.""" + var1 = torch.tensor([1.0, 2.0], requires_grad=True) + var2 = torch.tensor(5.0, requires_grad=True) + trainable_vars = [var1, var2] + grad1 = torch.tensor([0.1, 0.2]) + grad2 = torch.tensor(0.5) + gradients = [grad1, grad2] + learning_rate = 0.1 + + original_var1 = var1.clone() + original_var2 = var2.clone() + + updated_vars = distributed_backend.apply_gradients( + gradients, trainable_vars, learning_rate + ) + + assert updated_vars[0] is var1 + assert updated_vars[1] is var2 + + expected_var1 = original_var1 - (grad1 * learning_rate) + expected_var2 = original_var2 - (grad2 * learning_rate) + torch.testing.assert_close(updated_vars[0], expected_var1) + torch.testing.assert_close(updated_vars[1], expected_var2) + + def test_create_optimizer(self): + """Test optimizer configuration creation.""" + adam_config = distributed_backend.create_optimizer( + "adam", learning_rate=0.01 + ) + assert isinstance(adam_config, dict) + assert adam_config["name"] == "adam" + assert adam_config["learning_rate"] == 0.01 + + sgd_config = distributed_backend.create_optimizer( + "sgd", learning_rate=0.1, momentum=0.9 + ) + assert isinstance(sgd_config, dict) + assert sgd_config["name"] == "sgd" + 
assert sgd_config["learning_rate"] == 0.1 + assert sgd_config["momentum"] == 0.9 + + def test_get_device_info(self): + """Test retrieving device information from the PyTorch backend.""" + info = distributed_backend.get_device_info() + assert info["backend"] == "pytorch" + assert isinstance(info["devices"], list) + assert isinstance(info["device_count"], int) + assert info["device_count"] > 0 + assert len(info["devices"]) == info["device_count"] + if torch.cuda.is_available(): + assert info["device_count"] == torch.cuda.device_count() + else: + assert info["device_count"] == 1 + assert info["devices"] == ["cpu"] + + def test_is_multi_device_capable(self): + """Test the boolean check for multi-device capability.""" + assert isinstance(distributed_backend.is_multi_device_capable(), bool) + + def test_communication_ops_simulation_logic(self): + """Test the simulated communication ops in a single-device context.""" + comm_ops = distributed_backend.get_communication_ops() + device_info = distributed_backend.get_device_info() + world_size = device_info.get("device_count", 1) + + # Test all_reduce + x_reduce = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + reduced = comm_ops["all_reduce"](x_reduce, op="sum") + expected_reduce = ( + x_reduce * float(world_size) if world_size > 1 else x_reduce + ) + torch.testing.assert_close(reduced, expected_reduce) + + # Test all_gather + x_gather = torch.tensor([[1.0, 2.0]]) + gathered = comm_ops["all_gather"](x_gather, axis=0) + expected_gather = torch.cat([x_gather] * world_size, dim=0) + torch.testing.assert_close(gathered, expected_gather) + + # Test broadcast + x_broadcast = torch.tensor([5.0, 6.0]) + broadcasted = comm_ops["broadcast"](x_broadcast) + torch.testing.assert_close(broadcasted, x_broadcast) + + # Test scatter + if world_size > 0: + scatter_data = torch.arange(world_size * 4, dtype=torch.float32) + x_scatter = scatter_data.reshape(world_size * 2, 2) + scattered = comm_ops["scatter"](x_scatter, axis=0) + expected_scatter = torch.chunk(x_scatter, world_size, dim=0)[0] + torch.testing.assert_close(scattered, expected_scatter) From 6c0189e4e5da183a6624715224db3580419531be Mon Sep 17 00:00:00 2001 From: Suhana Date: Tue, 7 Oct 2025 22:58:47 +0530 Subject: [PATCH 5/8] Added torch backend and fixed circular import issue --- keras/src/backend/__init__.py | 1 - keras/src/backend/torch/__init__.py | 1 + .../src/backend/torch/distributed_backend.py | 61 +-- .../backend/torch/distributed_backend_test.py | 4 - keras/src/backend/torch/distribution_lib.py | 413 ++++++++++++++++++ .../backend/torch/distribution_lib_test.py | 160 +++++++ .../tensor_parallel/autoconfig_test.py | 245 +++++++++++ .../tensor_parallel/communications_test.py | 84 ++++ .../tensor_parallel/config_test.py | 96 ++++ .../coordinated_optimizer_test.py | 180 ++++++++ ...img_test.py => parameter_sharding_test.py} | 65 +-- .../state_action_keras_test.py | 109 +++++ keras/src/layers/layer.py | 5 +- 13 files changed, 1312 insertions(+), 112 deletions(-) create mode 100644 keras/src/backend/torch/distribution_lib.py create mode 100644 keras/src/backend/torch/distribution_lib_test.py create mode 100644 keras/src/distribution/tensor_parallel/autoconfig_test.py create mode 100644 keras/src/distribution/tensor_parallel/communications_test.py create mode 100644 keras/src/distribution/tensor_parallel/config_test.py create mode 100644 keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py rename keras/src/distribution/tensor_parallel/{parameter_shardimg_test.py => parameter_sharding_test.py} 
(68%) create mode 100644 keras/src/distribution/tensor_parallel/state_action_keras_test.py diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index fe393cf08abd..dc93944dfd47 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -46,7 +46,6 @@ from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable - distribution_lib = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable diff --git a/keras/src/backend/torch/__init__.py b/keras/src/backend/torch/__init__.py index c8095d01654e..a7d0405a5567 100644 --- a/keras/src/backend/torch/__init__.py +++ b/keras/src/backend/torch/__init__.py @@ -17,6 +17,7 @@ from keras.src.backend.common.name_scope import name_scope from keras.src.backend.torch import core from keras.src.backend.torch import distributed_backend +from keras.src.backend.torch import distribution_lib from keras.src.backend.torch import image from keras.src.backend.torch import linalg from keras.src.backend.torch import math diff --git a/keras/src/backend/torch/distributed_backend.py b/keras/src/backend/torch/distributed_backend.py index 16a5ced75d8e..432760339102 100644 --- a/keras/src/backend/torch/distributed_backend.py +++ b/keras/src/backend/torch/distributed_backend.py @@ -118,17 +118,9 @@ def _is_distributed() -> bool: def all_reduce( x: torch.Tensor, op: Literal["sum", "mean"] = "sum", + axis_name: str = None, ) -> torch.Tensor: - """Reduces a tensor across all devices. - - Args: - x (torch.Tensor): The tensor to reduce. - op (Literal["sum", "mean"], optional): The reduction operation. - Defaults to "sum". - - Returns: - torch.Tensor: The reduced tensor. - """ + """Reduces a tensor across all devices.""" if not _is_distributed(): world_size = ( torch.cuda.device_count() if torch.cuda.is_available() else 1 @@ -152,17 +144,10 @@ def all_reduce( dist.all_reduce(result, op=reduce_op) return result - def all_gather(x: torch.Tensor, axis: int = 0) -> torch.Tensor: - """Gathers tensors from all devices and concatenates them. - - Args: - x (torch.Tensor): The local tensor to gather. - axis (int, optional): The axis along which to concatenate. - Defaults to 0. - - Returns: - torch.Tensor: The concatenated tensor from all devices. - """ + def all_gather( + x: torch.Tensor, axis: int = 0, axis_name: str = None + ) -> torch.Tensor: + """Gathers tensors from all devices and concatenates them.""" if not _is_distributed(): world_size = ( torch.cuda.device_count() if torch.cuda.is_available() else 1 @@ -176,20 +161,13 @@ def all_gather(x: torch.Tensor, axis: int = 0) -> torch.Tensor: dist.all_gather(tensor_list, x) return torch.cat(tensor_list, dim=axis) - def broadcast(x: torch.Tensor, root: int = 0) -> torch.Tensor: - """Broadcasts a tensor from a root device to all other devices. - - Args: - x (torch.Tensor): The tensor to broadcast. - root (int, optional): The rank of the source device. Defaults to 0. - - Returns: - torch.Tensor: The tensor received from the root device. - """ + def broadcast( + x: torch.Tensor, root: int = 0, axis_name: str = None + ) -> torch.Tensor: + """Broadcasts a tensor from a root device to all other devices.""" if not _is_distributed(): return x - # `dist.broadcast` is in-place. 
dist.broadcast(x, src=root) return x @@ -197,24 +175,9 @@ def scatter( x: torch.Tensor, root: int = 0, axis: int = 0, + axis_name: str = None, ) -> torch.Tensor: - """Scatters a tensor from a root device to all devices. - - Note: The current implementation of `dist.scatter` requires the input - tensor `x` to be organized differently for the root process. This - wrapper simplifies it by handling the splitting automatically on the - root process. - - Args: - x (torch.Tensor): The tensor on the root device to be scattered. - root (int, optional): The rank of the device holding the tensor. - Defaults to 0. - axis (int, optional): The axis along which to split the tensor. - Defaults to 0. - - Returns: - torch.Tensor: The chunk of the tensor for the local device. - """ + """Scatters a tensor from a root device to all devices.""" if not _is_distributed(): world_size = ( torch.cuda.device_count() if torch.cuda.is_available() else 1 diff --git a/keras/src/backend/torch/distributed_backend_test.py b/keras/src/backend/torch/distributed_backend_test.py index 9aaf68f8a9f9..d6dce5977b81 100644 --- a/keras/src/backend/torch/distributed_backend_test.py +++ b/keras/src/backend/torch/distributed_backend_test.py @@ -1,7 +1,3 @@ -import os - -os.environ["JAX_PLATFORM_NAME"] = "cpu" - import pytest import torch diff --git a/keras/src/backend/torch/distribution_lib.py b/keras/src/backend/torch/distribution_lib.py new file mode 100644 index 000000000000..0d8c18de4bf7 --- /dev/null +++ b/keras/src/backend/torch/distribution_lib.py @@ -0,0 +1,413 @@ +"""Utilities for distribution strategy with Torch backend. + +This file contains the core Torch distribution primitives from Keras, +along with higher-level device management and auto-configuration utilities. +This version does not use try-except blocks for error handling. +""" + +import logging +import os +from typing import Dict +from typing import List +from typing import Optional + +import numpy as np +import torch +import torch.distributed as dist + +from keras.src.backend.common import global_state +from keras.src.random import seed_generator +from keras.src.utils import rng_utils + +logger = logging.getLogger(__name__) + + +def list_devices(device_type=None): + """Return all the available devices based on the device type. + + Note that this should return the global devices in a distributed setting. + + Args: + device_type: string of `"cpu"`, `"gpu"`. Defaults to `"gpu"` if + available when device_type is not provided. Otherwise will return + the `"cpu"` devices. `"tpu"` is not supported by the default + torch backend. + + Return: + List of devices that are available for distribute computation. + """ + if device_type: + device_type = device_type.lower() + else: + device_type = "cuda" if torch.cuda.is_available() else "cpu" + + if device_type in ("gpu", "cuda"): + if not torch.cuda.is_available(): + return [] + return [f"cuda:{i}" for i in range(torch.cuda.device_count())] + elif device_type == "cpu": + return ["cpu:0"] + elif device_type == "tpu": + logger.warning( + "TPU device type is not supported by the default " + "PyTorch backend. Use the `torch_xla` package." + ) + return [] + raise ValueError(f"Unknown device type: {device_type}") + + +def get_device_info(device_id: str) -> Dict[str, any]: + """ + Get detailed information about a specific device. 
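The device discovery helper above returns plain device strings rather than backend device objects. A minimal usage sketch, assuming the torch backend is selected via KERAS_BACKEND before Keras is imported; the printed lists depend on the machine:

import os

os.environ.setdefault("KERAS_BACKEND", "torch")

from keras.src.backend.torch import distribution_lib

print(distribution_lib.list_devices())        # e.g. ["cuda:0", "cuda:1"] or ["cpu:0"]
print(distribution_lib.list_devices("cpu"))   # always ["cpu:0"]
print(distribution_lib.list_devices("tpu"))   # [] plus a warning; use torch_xla for TPUs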
+ + Args: + device_id: Device identifier (e.g., 'cuda:0', 'cpu:0') + + Returns: + Dictionary containing device information + """ + device_info = { + "id": device_id, + "type": None, + "index": None, + "memory": None, + "capabilities": None, + } + + device_type, device_index = device_id.split(":") + device_type_map = {"cuda": "GPU", "cpu": "CPU"} + device_info["type"] = device_type_map.get(device_type, device_type.upper()) + device_info["index"] = int(device_index) + + return device_info + + +def get_best_devices(count: int = 1) -> List[str]: + """ + Get the best available devices for tensor parallelism. + + Args: + count: Number of devices needed + + Returns: + List of best device identifiers + """ + all_devices = list_devices("cuda") + if not all_devices: + all_devices = list_devices("cpu") + + if count <= 0: + return [] + + if count > len(all_devices): + logger.warning( + f"Requested {count} devices but only {len(all_devices)} available" + ) + count = len(all_devices) + + return all_devices[:count] + + +def get_device_backend(device_type: str) -> str: + """ + Get the recommended backend for a device type. + + Args: + device_type: Device type ('tpu', 'gpu', 'cpu') + + Returns: + Recommended backend name + """ + backend_mapping = {"gpu": "torch", "cuda": "torch", "cpu": "torch"} + + return backend_mapping.get(device_type.lower(), "torch") + + +def validate_device_placement(device_id: str) -> bool: + """ + Validate if a device can be used for tensor operations. + + Args: + device_id: Device identifier + + Returns: + True if device is valid and available + """ + if ":" not in device_id: + return False + + device_type = device_id.split(":")[0] + known_device_types = ("cpu", "gpu", "cuda", "tpu") + if device_type not in known_device_types: + return False + + all_devices = list_devices(device_type) + return device_id in all_devices + + +def get_device_memory_info(device_id: str) -> Optional[Dict[str, any]]: + """ + Get memory information for a device (if available). + + Args: + device_id: Device identifier + + Returns: + Memory information dictionary or None if not available + """ + if device_id.startswith("cuda:"): + return { + "type": "GPU", + "index": int(device_id.split(":")[1]), + "memory": "Available", + } + elif device_id.startswith("cpu:"): + return { + "type": "CPU", + "index": int(device_id.split(":")[1]), + "memory": "System RAM", + } + + return None + + +def auto_configure_tensor_parallel( + world_size: int = None, backend: str = None +) -> Dict[str, any]: + """ + Automatically configure tensor parallelism with the best available devices. + + Args: + world_size: Number of devices to use (if None, uses all available GPUs) + backend: Backend to use (if None, will be set to 'torch') + + Returns: + Configuration dictionary with devices, backend, and other settings + """ + all_devices = list_devices() + + if not all_devices: + raise RuntimeError("No devices available for tensor parallelism") + + if world_size is None: + world_size = len(all_devices) + else: + world_size = min(world_size, len(all_devices)) + + selected_devices = all_devices[:world_size] + + recommended_backend = "torch" + + config = { + "devices": selected_devices, + "world_size": world_size, + "backend": recommended_backend, + } + + logger.info(f"Auto-configured tensor parallelism: {config}") + return config + + +def distribute_variable(value, layout): + """Create a distributed variable for PyTorch. + + This function creates a `torch.Tensor` distributed according to the given + layout. 
In PyTorch, variables and tensors are unified in the `Tensor` class. + + Args: + value: The initial value of the variable as a `torch.Tensor`. + layout: `TensorLayout` for the created variable, or a PyTorch-supported + layout instance (e.g., a list of `Placement` types). + + Returns: + `torch.Tensor` which is the distributed variable. + """ + return distribute_tensor(value, layout) + + +def distribute_tensor(tensor, layout): + """Distribute the tensor based on the layout. + + Args: + tensor: `torch.Tensor` that needs to be distributed. + layout: `TensorLayout` for the created variable, or a PyTorch-supported + layout instance (e.g., a list of `Placement` types). + + Returns: + Distributed `torch.Tensor`. + """ + # Avoid circular imports. + from keras.src.distribution import TensorLayout + + if isinstance(layout, TensorLayout): + placements = layout.backend_layout + device_mesh = layout.device_mesh.backend_mesh + else: + raise ValueError( + "Directly passing backend layout is not yet supported for torch. " + "Please provide a `keras.distribution.TensorLayout` instance." + ) + + return dist.dtensor.distribute_tensor( + tensor.to("cpu"), device_mesh, placements + ) + + +def distribute_data_input(per_process_batch, layout, batch_dim_name): + """Distribute the input data with the corresponding layout. + + Note that the input here is a local worker batch. PyTorch's `from_local` + is used to construct a global DTensor from these local shards. + + Args: + per_process_batch: `torch.Tensor` that is local shard for this process. + layout: `TensorLayout` for the distribution information. + + Returns: + A global batch distributed according to `layout`. + """ + from keras.src.distribution import TensorLayout + + if not isinstance(layout, TensorLayout): + raise ValueError( + "A `keras.distribution.TensorLayout` instance is required." + ) + + placements = layout.backend_layout + device_mesh = layout.device_mesh.backend_mesh + return dist.dtensor.from_local( + per_process_batch, device_mesh, placements, run_check=True + ) + + +def initialize_rng(): + """Initializes the global random number generator across processes. + + This is required for consistent initialization in multi-host settings. + It works by generating a seed on rank 0 and broadcasting it to all other + processes. 
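The seed synchronization described here is a standard rank-0 broadcast. A minimal, self-contained sketch of that pattern, assuming `torch.distributed` is already initialized (for example via torchrun); `broadcast_seed_from_rank_zero` is an illustrative helper, not part of this patch:

import torch
import torch.distributed as dist


def broadcast_seed_from_rank_zero():
    # Rank 0 draws the seed; every other rank receives the same value.
    if dist.get_rank() == 0:
        seed = torch.randint(0, 2**31 - 1, (1,), dtype=torch.int64)
    else:
        seed = torch.empty(1, dtype=torch.int64)
    dist.broadcast(seed, src=0)
    return int(seed.item())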
+ """ + global_seed = rng_utils.get_random_seed() + if global_seed is None: + if not dist.is_initialized(): + seed = seed_generator.make_default_seed() + else: + if process_id() == 0: + seed = seed_generator.make_default_seed() + seed_tensor = torch.tensor( + seed, dtype=torch.int64, device="cpu" + ) + else: + seed_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + dist.broadcast(seed_tensor, src=0) + seed = seed_tensor.item() + global_seed = seed + rng_utils.set_random_seed(global_seed) + + global_seed_generator = global_state.get_global_attribute( + "global_seed_generator" + ) + if global_seed_generator is not None and global_seed_generator.seed is None: + global_state.set_global_attribute( + "global_seed_generator", + seed_generator.SeedGenerator( + seed=global_seed, + name=global_seed_generator.name, + backend=global_seed_generator.backend, + ), + ) + + +def initialize(job_addresses, num_processes, process_id): + """Initializes the distributed process group in PyTorch.""" + os.environ["RANK"] = str(process_id) + os.environ["WORLD_SIZE"] = str(num_processes) + + if "," in job_addresses: + master_addr = job_addresses.split(",")[0] + else: + master_addr = job_addresses + + if ":" not in master_addr: + raise ValueError( + "Invalid `job_addresses`. Expected format `hostname:port`, " + f"but got {master_addr}" + ) + + master_host, master_port = master_addr.split(":") + os.environ["MASTER_ADDR"] = master_host + os.environ["MASTER_PORT"] = master_port + + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend) + + initialize_rng() + + +def num_processes(): + """Return the number of processes for the current distribution setting.""" + if dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def process_id(): + """Return the current process ID for the distribution setting.""" + if dist.is_initialized(): + return dist.get_rank() + return 0 + + +def _to_backend_device(device_name): + if isinstance(device_name, torch.device): + return device_name + return torch.device(device_name) + + +def _to_backend_mesh(device_mesh): + """Convert the DeviceMesh to Torch backend specific Mesh. + + Args: + device_mesh: DeviceMesh instance to convert. + + Returns: + A `torch.distributed.DeviceMesh` instance. + """ + mesh_shape = device_mesh.devices.shape + mesh_devices = np.array(device_mesh.devices.flatten()).reshape(mesh_shape) + return dist.DeviceMesh( + device_type="cuda" if torch.cuda.is_available() else "cpu", + mesh=mesh_devices, + ) + + +def _to_backend_layout(tensor_layout): + """Convert the TensorLayout to Torch backend specific placement. + + Args: + tensor_layout: TensorLayout instance to convert. + + Returns: + A list of `torch.distributed.placement_types.Placement` instances. + """ + if tensor_layout.device_mesh is None: + raise ValueError( + "Cannot create sharding when device mesh is not set " + "for TensorLayout." + ) + + mesh_axes = tensor_layout.device_mesh.axis_names + placements = [] + for axis in tensor_layout.axes: + if axis is None: + placements.append(dist.Replicate()) + else: + try: + mesh_dim = mesh_axes.index(axis) + placements.append(dist.Shard(mesh_dim)) + except ValueError: + raise ValueError( + f"Tensor axis `{axis}` is not found in the " + f"device mesh axes `{mesh_axes}`." 
+ ) from None + return placements diff --git a/keras/src/backend/torch/distribution_lib_test.py b/keras/src/backend/torch/distribution_lib_test.py new file mode 100644 index 000000000000..2897b022a0d4 --- /dev/null +++ b/keras/src/backend/torch/distribution_lib_test.py @@ -0,0 +1,160 @@ +import os + +import numpy as np +import pytest +import torch +import torch.distributed as dist + +from keras.src import backend +from keras.src.backend import distribution_lib +from keras.src.distribution import DeviceMesh +from keras.src.distribution import TensorLayout + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="Backend specific test", +) +def setup_torch_distributed(): + """ + A fixture to initialize the distributed process group if not already done. + This allows test file to be run directly with `pytest` for single-process + checks, while also working correctly when launched with `torchrun`. + """ + if not dist.is_available() or dist.is_initialized(): + return + + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + dist.init_process_group(backend="gloo") + + +@pytest.mark.skipif( + not torch.distributed.is_available(), + reason="PyTorch distributed components are not available.", +) +class TestTorchDistributionLibLive: + """ + Tests for the Torch distribution library without using mocks. + These tests will reflect the capabilities of environment they are run in. + """ + + def test_device_listing_and_info(self): + """Tests device discovery functions against the runtime environment.""" + if torch.cuda.is_available(): + gpu_devices = distribution_lib.list_devices("gpu") + assert len(gpu_devices) == torch.cuda.device_count() + assert gpu_devices[0] == "cuda:0" + else: + assert distribution_lib.list_devices("gpu") == [] + + cpu_devices = distribution_lib.list_devices("cpu") + assert cpu_devices == ["cpu:0"] + + with pytest.raises(ValueError, match="Unknown device type"): + distribution_lib.list_devices("unsupported_device") + + def test_device_helpers(self): + """Tests validation, backend, and memory info functions.""" + device_str = "cpu:0" + if torch.cuda.is_available(): + device_str = "cuda:0" + + assert distribution_lib.validate_device_placement(device_str) is True + assert distribution_lib.validate_device_placement("invalid:0") is False + + assert distribution_lib.get_device_backend("cpu") == "torch" + assert distribution_lib.get_device_backend("gpu") == "torch" + + mem_info = distribution_lib.get_device_memory_info(device_str) + assert mem_info is not None + assert "type" in mem_info + assert mem_info["index"] == 0 + + def test_process_discovery(self): + """Tests process_id and num_processes in the live environment.""" + rank = distribution_lib.process_id() + world_size = distribution_lib.num_processes() + + if dist.is_initialized(): + assert rank == dist.get_rank() + assert world_size == dist.get_world_size() + else: + assert rank == 0 + assert world_size == 1 + + def test_backend_conversions(self): + """Tests the conversion of Keras objects to Torch backend objects.""" + world_size = distribution_lib.num_processes() + if world_size < 2: + pytest.skip( + "Skipping conversion tests in a single-process environment." 
+ ) + + devices = [f"cpu:{i}" for i in range(world_size)] + shape = (world_size,) + axis_names = ("data",) + keras_mesh = DeviceMesh(shape, axis_names, devices) + + torch_mesh = distribution_lib._to_backend_mesh(keras_mesh) + assert isinstance(torch_mesh, dist.DeviceMesh) + assert torch_mesh.mesh.shape == shape + + keras_layout = TensorLayout(axes=("data",), device_mesh=keras_mesh) + placements = distribution_lib._to_backend_layout(keras_layout) + assert isinstance(placements[0], dist.Shard) + + keras_layout_replicated = TensorLayout( + axes=(None,), device_mesh=keras_mesh + ) + placements_replicated = distribution_lib._to_backend_layout( + keras_layout_replicated + ) + assert isinstance(placements_replicated[0], dist.Replicate) + + def test_tensor_distribution(self): + """Tests the distribution of a tensor into a DTensor.""" + if not dist.is_initialized() or distribution_lib.num_processes() < 2: + pytest.skip( + "Tensor distribution test requires a multi-process environment." + ) + + world_size = distribution_lib.num_processes() + devices = np.arange(world_size) + keras_mesh = DeviceMesh((world_size,), ("batch",), devices) + keras_layout = TensorLayout(("batch", None), keras_mesh) + + local_tensor = torch.randn((10, 20)) + + dtensor = distribution_lib.distribute_tensor(local_tensor, keras_layout) + assert isinstance(dtensor, torch.distributed.dtensor.DTensor) + assert dtensor.device_mesh.mesh.shape == (world_size,) + assert isinstance(dtensor.placements[0], dist.Shard) + + dvariable = distribution_lib.distribute_variable( + local_tensor, keras_layout + ) + assert isinstance(dvariable, torch.distributed.dtensor.DTensor) + + def test_distribute_data_input(self): + """Tests the `from_local` logic for distributing input data.""" + if not dist.is_initialized() or distribution_lib.num_processes() < 2: + pytest.skip( + "Input distribution test requires a multi-process environment." 
+ ) + + world_size = distribution_lib.num_processes() + devices = np.arange(world_size) + keras_mesh = DeviceMesh((world_size,), ("batch",), devices) + keras_layout = TensorLayout(("batch", None), keras_mesh) + + per_process_batch = torch.ones((8, 16)) + + global_batch = distribution_lib.distribute_data_input( + per_process_batch, keras_layout, batch_dim_name="batch" + ) + + assert isinstance(global_batch, torch.distributed.dtensor.DTensor) + assert global_batch.shape == (world_size * 8, 16) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py new file mode 100644 index 000000000000..470c00774660 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -0,0 +1,245 @@ +import os + +import pytest + +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" + +from keras import Input +from keras import Model +from keras import layers +from keras.src import backend +from keras.src import testing +from keras.src.distribution import distributed_backend +from keras.src.distribution.tensor_parallel.autoconfig import ( + analyze_dense_layer_directly, +) +from keras.src.distribution.tensor_parallel.autoconfig import ( + get_default_config_keras, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax") + or distributed_backend.get_device_info()["device_count"] <= 1, + reason="This test is for JAX/PyTorch backends and requires > 1 device.", +) +class TestAutoConfigKeras(testing.TestCase): + def setUp(self): + """Set up the test case and common variables.""" + super().setUp() + device_info = distributed_backend.get_device_info() + self.world_size = device_info["device_count"] + self.device_ids = [f"cpu:{i}" for i in range(self.world_size)] + + self.assertGreater( + self.world_size, 1, "Distribution tests require more than 1 device." 
+ ) + + def _assert_split_keras_equal(self, rule1, rule2): + """Helper to compare two SplitKeras objects by their attributes.""" + self.assertIsInstance(rule1, SplitKeras) + self.assertIsInstance(rule2, SplitKeras) + self.assertDictEqual(vars(rule1), vars(rule2)) + + def _assert_rules_equal(self, actual_rules, expected_rules): + """Helper to compare two dictionaries of sharding rules.""" + self.assertSetEqual( + set(actual_rules.keys()), set(expected_rules.keys()) + ) + for key in expected_rules: + actual_val = actual_rules[key] + expected_val = expected_rules[key] + if isinstance(expected_val, SplitKeras): + self._assert_split_keras_equal(actual_val, expected_val) + else: + self.assertEqual(actual_val, expected_val) + + def test_analyze_dense_layer(self): + """Tests the direct analysis and classification of Dense layers.""" + up_proj_layer = layers.Dense(32) + up_proj_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(up_proj_layer, None, ""), + "up_projection", + ) + + down_proj_layer = layers.Dense(16) + down_proj_layer.build(input_shape=(None, 32)) + self.assertEqual( + analyze_dense_layer_directly(down_proj_layer, None, ""), + "down_projection", + ) + + generic_layer = layers.Dense(20) + generic_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(generic_layer, None, ""), + "generic_dense", + ) + + def test_simple_mlp_sharding(self): + """Tests a simple MLP with up and down projection layers.""" + inputs = Input(shape=(64,)) + x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( + x + ) + model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^simple_mlp.up_projection_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^simple_mlp.up_projection_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + # Bias for down-projection is not sharded according to the new logic + } + expected_output_rules = { + r"^simple_mlp.up_projection_layer$": {0: "gather"}, + r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_generic_dense_sharding(self): + """Tests a generic Dense layer that isn't an up/down projection.""" + inputs = Input(shape=(64,)) + outputs = layers.Dense(80, name="generic_layer", use_bias=True)(inputs) + model = Model(inputs=inputs, outputs=outputs, name="generic_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^generic_model.generic_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^generic_model.generic_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + } + expected_output_rules = { + r"^generic_model.generic_layer$": {0: "gather -1"} + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_embedding_sharding(self): + """Tests an Embedding layer for vocabulary parallelism.""" + inputs = Input(shape=(10,), dtype="int32") + outputs = layers.Embedding( + input_dim=1000, output_dim=128, name="token_embedding" + )(inputs) + model = Model(inputs=inputs, 
outputs=outputs, name="embed_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + # FIX: Removed the incorrect backslash before ".token_embedding" + r"^embed_model.token_embedding\..*embeddings$": SplitKeras( + self.world_size, 1, "column" + ) + } + expected_output_rules = { + r"^embed_model.token_embedding$": {0: "no_comm"} + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_einsum_dense_sharding(self): + """Tests the special handling for EinsumDense layers.""" + inputs = Input(shape=(64,)) + x = layers.EinsumDense( + "bh,hd->bd", output_shape=128, name="query_proj" + )(inputs) + outputs = layers.EinsumDense( + "bd,dh->bh", output_shape=64, name="attention_output" + )(x) + model = Model(inputs=inputs, outputs=outputs, name="einsum_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^einsum_model.query_proj.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^einsum_model.attention_output.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^einsum_model.query_proj$": {0: "gather -1"}, + r"^einsum_model.attention_output$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_normalization_layers_ignored(self): + """Tests that normalization layers are correctly ignored.""" + inputs = Input(shape=(64,)) + x = layers.Dense(64, name="dense1", use_bias=True)(inputs) + x = layers.LayerNormalization(name="layernorm")(x) + outputs = layers.Dense(64, name="dense2", use_bias=True)(x) + model = Model(inputs=inputs, outputs=outputs, name="norm_model") + + config = get_default_config_keras(model, self.device_ids) + + for key in config.state_rules: + self.assertNotIn("layernorm", key) + for key in config.output_rules: + self.assertNotIn("layernorm", key) + + self.assertIn(r"^norm_model.dense1.kernel$", config.state_rules) + self.assertIn(r"^norm_model.dense2.kernel$", config.state_rules) + self.assertEqual(len(config.state_rules), 4) + self.assertEqual(len(config.output_rules), 2) + + def test_nested_model_sharding(self): + """Tests that the traversal logic correctly handles nested models.""" + inner_inputs = Input(shape=(32,)) + inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( + inner_inputs + ) + inner_model = Model( + inputs=inner_inputs, outputs=inner_outputs, name="inner_block" + ) + + outer_inputs = Input(shape=(32,)) + x = inner_model(outer_inputs) + outer_outputs = layers.Dense(32, name="outer_dense", use_bias=True)(x) + outer_model = Model( + inputs=outer_inputs, outputs=outer_outputs, name="outer_model" + ) + + config = get_default_config_keras(outer_model, self.device_ids) + + expected_state_rules = { + r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^outer_model.inner_block.inner_dense.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^outer_model.outer_dense.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + # Bias for down-projection is not sharded according to the new logic + } + expected_output_rules = { + r"^outer_model.inner_block.inner_dense$": {0: "gather"}, + r"^outer_model.outer_dense$": {0: "allreduce"}, + } + + self.maxDiff = None + self._assert_rules_equal(config.state_rules, expected_state_rules) + 
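The expected rules above spell out the convention `get_default_config_keras` is meant to follow: up and generic projection kernels are split column-wise, down projections row-wise, and matching output rules select between gather and all-reduce. A hedged sketch for inspecting the generated rules outside the test, assuming a two-entry CPU device list:

import keras
from keras.src.distribution.tensor_parallel.autoconfig import (
    get_default_config_keras,
)

inputs = keras.Input(shape=(16,))
outputs = keras.layers.Dense(64, name="proj", use_bias=True)(inputs)
model = keras.Model(inputs, outputs, name="demo")

config = get_default_config_keras(model, ["cpu:0", "cpu:1"])
for pattern, action in sorted(config.state_rules.items()):
    # Each entry maps a weight-path regex to a SplitKeras action.
    print(pattern, type(action).__name__)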
self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/communications_test.py b/keras/src/distribution/tensor_parallel/communications_test.py new file mode 100644 index 000000000000..1b6d95d37664 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/communications_test.py @@ -0,0 +1,84 @@ +import pytest + +import keras +from keras.src import backend +from keras.src import testing +from keras.src.backend import distributed_backend +from keras.src.distribution.tensor_parallel.communications import AllGatherKeras +from keras.src.distribution.tensor_parallel.communications import AllReduceKeras +from keras.src.distribution.tensor_parallel.communications import BroadcastKeras +from keras.src.distribution.tensor_parallel.communications import ( + TensorParallelCommunicator, +) + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +class TestCollectiveOps(testing.TestCase): + """ + Tests collective communication ops on a JAX distributed backend. + """ + + def setUp(self): + super().setUp() + device_info = distributed_backend.get_device_info() + self.world_size = device_info.get("device_count", 1) + + if not self.world_size: + self.world_size = 1 + + self.axis_name = "data" + + def test_all_reduce(self): + """Tests the all-reduce operation.""" + all_reduce_op = AllReduceKeras(world_size=self.world_size, op="sum") + local_tensor = keras.ops.array([1.0, 2.0, 3.0]) + + result = all_reduce_op(local_tensor, axis_name=self.axis_name) + + expected_output = keras.ops.multiply( + local_tensor, float(self.world_size) + ) + self.assertAllClose(result, expected_output) + + def test_all_gather(self): + """Tests the all-gather operation.""" + all_gather_op = AllGatherKeras(world_size=self.world_size, dim=0) + local_slice = keras.ops.arange(6, dtype="float32").reshape((2, 3)) + result = all_gather_op(local_slice, axis_name=self.axis_name) + + expected_output = keras.ops.concatenate( + [local_slice] * self.world_size, axis=0 + ) + self.assertAllClose(result, expected_output) + + def test_broadcast(self): + """Tests the broadcast operation.""" + broadcast_op = BroadcastKeras( + world_size=self.world_size, src_rank=0, rank=0 + ) + tensor_to_broadcast = keras.ops.array([5.0, 10.0, 15.0]) + result = broadcast_op(tensor_to_broadcast, axis_name=self.axis_name) + + self.assertAllClose(result, tensor_to_broadcast) + + def test_tensor_parallel_communicator_forward_column_parallel(self): + """Tests the communicator's all-gather for column-parallel forward.""" + communicator = TensorParallelCommunicator( + world_size=self.world_size, rank=0 + ) + + local_slice = keras.ops.array([[0.0, 1.0], [2.0, 3.0]], dtype="float32") + + result = communicator.forward_column_parallel( + partial_outputs=[local_slice], + dim=0, + axis_name=self.axis_name, + ) + + expected_output = keras.ops.concatenate( + [local_slice] * self.world_size, axis=0 + ) + self.assertAllClose(result, expected_output) diff --git a/keras/src/distribution/tensor_parallel/config_test.py b/keras/src/distribution/tensor_parallel/config_test.py new file mode 100644 index 000000000000..cbb26e40e6db --- /dev/null +++ b/keras/src/distribution/tensor_parallel/config_test.py @@ -0,0 +1,96 @@ +import pytest + +from keras.src import backend +from keras.src import testing +from keras.src.distribution.tensor_parallel.communications import AllGatherKeras +from keras.src.distribution.tensor_parallel.communications import AllReduceKeras +from 
keras.src.distribution.tensor_parallel.communications import BroadcastKeras +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.config import _create_ops_from_rules + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +class TestConfig(testing.TestCase): + """Test suite for the tensor parallel configuration.""" + + def test_create_ops_from_rules_helper(self): + """ + Tests the private _create_ops_from_rules helper function directly + to ensure it correctly parses various rule types. + """ + devices = ["/gpu:0", "/gpu:1"] + world_size = len(devices) + rules = { + "dense/kernel": {"forward": "sum", "backward": "mean"}, + "embedding/weight": { + "forward": "gather 0", + "backward": "gather -1", + }, + "attention/dense/bias": {"forward": "broadcast"}, + "passthrough": {"action": 123}, + "no_dict_action": "identity", + } + + processed_rules = _create_ops_from_rules(rules, world_size) + + sum_op = processed_rules["dense/kernel"]["forward"] + self.assertIsInstance(sum_op, AllReduceKeras) + self.assertEqual(sum_op.op, "sum") + self.assertEqual(sum_op.world_size, world_size) + + mean_op = processed_rules["dense/kernel"]["backward"] + self.assertIsInstance(mean_op, AllReduceKeras) + self.assertEqual(mean_op.op, "mean") + + gather_op_0 = processed_rules["embedding/weight"]["forward"] + self.assertIsInstance(gather_op_0, AllGatherKeras) + self.assertEqual(gather_op_0.dim, 0) + self.assertEqual(gather_op_0.world_size, world_size) + + gather_op_neg1 = processed_rules["embedding/weight"]["backward"] + self.assertIsInstance(gather_op_neg1, AllGatherKeras) + self.assertEqual(gather_op_neg1.dim, -1) + + broadcast_op = processed_rules["attention/dense/bias"]["forward"] + self.assertIsInstance(broadcast_op, BroadcastKeras) + self.assertEqual(broadcast_op.world_size, world_size) + + self.assertEqual(processed_rules["passthrough"]["action"], 123) + self.assertEqual(processed_rules["no_dict_action"], "identity") + + def test_config_keras_create_collective_ops(self): + """ + Tests the public create_collective_ops method of the ConfigKeras class. 
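The rule strings exercised in this file follow a small convention: "sum" and "mean" resolve to all-reduce ops, "gather <dim>" to all-gather ops along the given dimension, and "broadcast" to a broadcast op. A hedged sketch of resolving such rules through the public `create_collective_ops` method, assuming a two-device list:

from keras.src.distribution.tensor_parallel.config import ConfigKeras

config = ConfigKeras(
    state_rules={},
    output_rules={
        "layer_1_output": {"activation": "sum"},
        "layer_2_output": {"activation": "gather -1"},
    },
)
resolved = config.create_collective_ops(["/gpu:0", "/gpu:1"])
# "sum" becomes an AllReduceKeras instance and "gather -1" an AllGatherKeras
# with dim=-1; the original config is left untouched.
print(type(resolved.output_rules["layer_1_output"]["activation"]).__name__)
print(type(resolved.output_rules["layer_2_output"]["activation"]).__name__)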
+ """ + devices = ["/gpu:0", "/gpu:1"] + world_size = len(devices) + + state_rules = {"some_weight": "split"} + output_rules = { + "layer_1_output": {"activation": "sum"}, + "layer_2_output": {"activation": "gather -1"}, + } + + config = ConfigKeras(state_rules=state_rules, output_rules=output_rules) + new_config = config.create_collective_ops(devices) + + self.assertIsNot(new_config, config) + + self.assertEqual(new_config.state_rules, state_rules) + + self.assertIsInstance( + config.output_rules["layer_1_output"]["activation"], str + ) + + sum_op = new_config.output_rules["layer_1_output"]["activation"] + self.assertIsInstance(sum_op, AllReduceKeras) + self.assertEqual(sum_op.op, "sum") + self.assertEqual(sum_op.world_size, world_size) + + gather_op = new_config.output_rules["layer_2_output"]["activation"] + self.assertIsInstance(gather_op, AllGatherKeras) + self.assertEqual(gather_op.dim, -1) + self.assertEqual(gather_op.world_size, world_size) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py new file mode 100644 index 000000000000..38d80a5ec258 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -0,0 +1,180 @@ +import numpy as np +import pytest + +import keras +from keras import ops +from keras.src import backend +from keras.src import optimizers +from keras.src import testing +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, +) +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, +) + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +class CoordinatedOptimizerTest(testing.TestCase): + def _get_simple_model(self): + """Creates a simple, uncompiled Keras model.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(20, name="dense_1")(inputs) + outputs = keras.layers.Dense(5, name="dense_2")(x) + return keras.Model(inputs, outputs) + + def _get_mock_gradients_and_vars(self, model, world_size): + """Generates mock gradients and variables for N shards.""" + model.build(input_shape=(None, 10)) + variables = model.trainable_variables + grads_and_vars_per_shard = [] + for i in range(world_size): + multiplier = float(i + 1) + gradients = [ + ops.convert_to_tensor( + np.ones_like(v.numpy()) * multiplier, dtype="float32" + ) + for v in variables + ] + grads_and_vars_per_shard.append(list(zip(gradients, variables))) + return grads_and_vars_per_shard + + def test_initialization(self): + """Tests that the optimizer initializes with the correct defaults.""" + base_optimizer = optimizers.Adam() + coord = CoordinatedOptimizer(base_optimizer, world_size=4) + self.assertEqual(coord.base_optimizer, base_optimizer) + self.assertTrue(coord.shard_optimizer_states) + self.assertEqual(coord.sharded_states, {}) + + def test_apply_gradients_with_replicated_states(self): + """Tests that replicated gradients are averaged and applied once.""" + + class AdamWithCallCounter(optimizers.Adam): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.apply_gradients_call_count = 0 + self.received_grads = [] + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + self.apply_gradients_call_count += 1 + self.received_grads = [g for g, v in grads_and_vars] + super().apply_gradients(grads_and_vars, *args, **kwargs) + + world_size = 4 + model = self._get_simple_model() 
+ optimizer = AdamWithCallCounter() + model.build((None, 10)) + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord = CoordinatedOptimizer( + optimizer, + world_size, + shard_optimizer_states=False, + ) + coord.apply_gradients(mock_grads, []) + + self.assertEqual(optimizer.apply_gradients_call_count, 1) + grad_numpy = ops.convert_to_numpy(optimizer.received_grads[0]) + self.assertAllClose( + grad_numpy, + np.ones_like(grad_numpy) * 2.5, + ) + + def test_init_from_string(self): + optimizer = TensorParallelOptimizer("adam", world_size=4) + self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) + + def test_apply_gradients_delegation(self): + """Tests that apply_gradients correctly delegates.""" + world_size = 4 + base_opt = optimizers.Adam() + optimizer = TensorParallelOptimizer(base_opt, world_size) + model = self._get_simple_model() + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord_apply_tracker = {"called": False} + + def coord_apply_mock(*args, **kwargs): + coord_apply_tracker["called"] = True + + optimizer.coordinated_optimizer.apply_gradients = coord_apply_mock + + base_apply_tracker = {"called": False} + + def base_apply_mock(*args, **kwargs): + base_apply_tracker["called"] = True + + optimizer.base_optimizer.apply_gradients = base_apply_mock + + optimizer.apply_gradients(mock_grads, shard_models=[]) + self.assertTrue(coord_apply_tracker["called"]) + self.assertFalse(base_apply_tracker["called"]) + + coord_apply_tracker["called"] = False + unsharded_grads = mock_grads[0] + optimizer.apply_gradients(unsharded_grads) + self.assertTrue(base_apply_tracker["called"]) + self.assertFalse(coord_apply_tracker["called"]) + + def test_build_and_state_sharding(self): + """Tests that the build method correctly initializes sharded states.""" + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=4) + model = self._get_simple_model() + model.build(input_shape=(None, 10)) + + self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) + optimizer.build(model.trainable_variables) + self.assertTrue(optimizer.built) + + sharded_states = optimizer.coordinated_optimizer.sharded_states + self.assertIn("momentum", sharded_states) + self.assertIn("velocity", sharded_states) + self.assertIn("iterations", sharded_states) + + dense_1_kernel_path = model.get_layer("dense_1").kernel.path + self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) + self.assertEqual( + len(sharded_states["momentum"][dense_1_kernel_path]), 4 + ) + + def test_serialization(self): + world_size = 4 + base_opt = optimizers.Adam(learning_rate=0.1) + optimizer = TensorParallelOptimizer( + base_opt, world_size, distributed_backend=None + ) + + config = optimizer.get_config() + recreated = TensorParallelOptimizer.from_config(config) + + self.assertEqual(recreated.world_size, world_size) + self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) + self.assertIsNone(recreated.distributed_backend) + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) + + def test_sharding_with_prefixed_variable_names(self): + """Tests that state is correctly mapped with prefixed variable names.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(4, name="dense")(inputs) + outputs = keras.layers.Dense(2, name="dense_output")(x) + model = keras.Model(inputs, outputs) + model.build(input_shape=(None, 10)) + + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=2) + optimizer.build(model.trainable_variables) + + state_to_param = 
( + optimizer.coordinated_optimizer._state_variable_to_parameter + ) + self.assertGreater(len(state_to_param), 0) + + dense_output_kernel = model.get_layer("dense_output").kernel + optimizer_name = optimizer.base_optimizer.name + kernel_path = dense_output_kernel.path.replace("/", "_") + momentum_path = f"{optimizer_name}/{kernel_path}_momentum" + + self.assertIs(state_to_param[momentum_path], dense_output_kernel) diff --git a/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py b/keras/src/distribution/tensor_parallel/parameter_sharding_test.py similarity index 68% rename from keras/src/distribution/tensor_parallel/parameter_shardimg_test.py rename to keras/src/distribution/tensor_parallel/parameter_sharding_test.py index 681d6724c325..dc686436af97 100644 --- a/keras/src/distribution/tensor_parallel/parameter_shardimg_test.py +++ b/keras/src/distribution/tensor_parallel/parameter_sharding_test.py @@ -8,6 +8,7 @@ import keras from keras import distribution +from keras.src import backend from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.parameter_sharding import ( ShardedWeight, @@ -20,19 +21,11 @@ @pytest.mark.skipif( - keras.backend.backend() != "jax", - reason="This test is JAX-specific.", + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", ) def _create_simple_mlp(): - """Creates a simple, unsharded Keras MLP model for testing. - - This model serves as the baseline for sharding tests. It consists of - an input layer, a hidden dense layer with a ReLU activation, and an - output dense layer. - - Returns: - A `keras.Model` instance. - """ + """Creates a simple, unsharded Keras MLP model for testing.""" inputs = keras.Input(shape=(16,), name="input") x = keras.layers.Dense(32, use_bias=True, name="up_proj")(inputs) x = keras.layers.Activation("relu")(x) @@ -41,58 +34,39 @@ def _create_simple_mlp(): class ParameterShardingTest(TestCase): - """Test suite for parameter sharding functionality. - - This class tests the creation of sharded models, the correctness of - sharded weight shapes, and the numerical accuracy of the forward pass - of a sharded model compared to its original, unsharded counterpart. - """ - def setUp(self): - """Sets up the testing environment before each test case.""" super().setUp() + import logging + + logging.getLogger().setLevel(logging.ERROR) self.world_size = 2 all_devices = distribution.list_devices() self.devices = all_devices[: self.world_size] if len(self.devices) < self.world_size: self.skipTest( - f"""Not enough devices to run TP test. + f"""Not enough devices to run TP test. Found {len(self.devices)}, need {self.world_size}""" ) - # Create the original model and the sharding configuration. self.original_model = _create_simple_mlp() self.original_model.build(input_shape=(None, 16)) self.tp_config = ConfigKeras( state_rules={ - # Rule to split the first dense layer's kernel along the output - # dimension (column-wise). re.escape("simple_mlp.up_proj.kernel"): SplitKeras( self.world_size, dim=1 ), - # Rule to split the second dense layer's kernel along the input - # dimension (row-wise). re.escape("simple_mlp.down_proj.kernel"): SplitKeras( self.world_size, dim=0 ), }, output_rules={}, ) - # Generate dummy data for testing forward passes. 
self.input_data = np.random.rand(4, 16).astype("float32") self.labels = np.random.rand(4, 8).astype("float32") def test_model_sharding_creation_and_weight_counts(self): - """Tests if sharded models are created correctly. - - Verifies that: - 1. `make_parameter_sharded_model` returns a valid Keras model. - 2. The set of modified parameters correctly identifies sharded layers. - 3. The total number of weights in the sharded model matches the - original model, ensuring no weights are lost or added. - """ sharded_models = [] for rank in range(self.world_size): with keras.device(self.devices[rank]): @@ -107,19 +81,11 @@ def test_model_sharding_creation_and_weight_counts(self): self.assertIn("simple_mlp.up_proj.kernel", modified_params) self.assertIn("simple_mlp.down_proj.kernel", modified_params) sharded_models.append(sharded_model) - - # The sharded model should have the same number of weight objects. self.assertEqual( len(self.original_model.weights), len(sharded_models[0].weights) ) def test_sharded_weight_shapes(self): - """Validates the shapes of the weights after sharding. - - This test ensures that the dimensions specified in the sharding rules - are correctly split by the world size, while other dimensions remain - unchanged. - """ rank = 0 with keras.device(self.devices[rank]): sharded_model, _ = make_parameter_sharded_model( @@ -129,14 +95,11 @@ def test_sharded_weight_shapes(self): world_size=self.world_size, device_id=self.devices[rank], ) - original_weights_dict = {w.path: w for w in self.original_model.weights} sharded_weights_dict = { w.name if isinstance(w, ShardedWeight) else w.path: w for w in sharded_model.weights } - - # Check the shape of the column-split kernel. orig_up_kernel = original_weights_dict["up_proj/kernel"] shard_up_kernel = sharded_weights_dict["simple_mlp.up_proj.kernel"] self.assertEqual(shard_up_kernel.shape[0], orig_up_kernel.shape[0]) @@ -144,8 +107,6 @@ def test_sharded_weight_shapes(self): shard_up_kernel.shape[1], orig_up_kernel.shape[1] // self.world_size, ) - - # Check the shape of the row-split kernel. orig_down_kernel = original_weights_dict["down_proj/kernel"] shard_down_kernel = sharded_weights_dict["simple_mlp.down_proj.kernel"] self.assertEqual( @@ -155,22 +116,13 @@ def test_sharded_weight_shapes(self): self.assertEqual(shard_down_kernel.shape[1], orig_down_kernel.shape[1]) def test_forward_pass_correctness(self): - """Checks if the sharded model's output matches the original. - - This test performs a forward pass on both the original model and the - sharded models. It then reconstructs the output from the sharded - models and asserts that it is numerically close to the original - model's output. This serves as an end-to-end correctness check. 
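The forward-pass check relies on the usual column/row parallel identity: when the first kernel is split along its output columns and the second along its input rows, the per-rank partial outputs sum to the unsharded result, because the elementwise ReLU commutes with the column partition. A NumPy-only illustration of that identity (the extra averaging in the test itself is an implementation detail of the sharded model and is not reproduced here):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16)).astype("float32")
w1 = rng.standard_normal((16, 32)).astype("float32")
w2 = rng.standard_normal((32, 8)).astype("float32")

world_size = 2
cols = np.split(w1, world_size, axis=1)  # column-parallel up_proj
rows = np.split(w2, world_size, axis=0)  # row-parallel down_proj

partial = [np.maximum(x @ cols[r], 0.0) @ rows[r] for r in range(world_size)]
full = np.maximum(x @ w1, 0.0) @ w2
assert np.allclose(sum(partial), full, atol=1e-5)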
- """ expected_output = self.original_model(self.input_data) sharded_outputs = [] original_weights = self.original_model.get_weights() - for rank in range(self.world_size): with keras.device(self.devices[rank]): cloned_original = keras.models.clone_model(self.original_model) cloned_original.set_weights(original_weights) - sharded_model, _ = make_parameter_sharded_model( cloned_original, self.tp_config, @@ -180,7 +132,6 @@ def test_forward_pass_correctness(self): ) output = sharded_model(self.input_data) sharded_outputs.append(output) - reconstructed_output = ( keras.ops.sum(keras.ops.stack(sharded_outputs), axis=0) / self.world_size diff --git a/keras/src/distribution/tensor_parallel/state_action_keras_test.py b/keras/src/distribution/tensor_parallel/state_action_keras_test.py new file mode 100644 index 000000000000..a6947958a4aa --- /dev/null +++ b/keras/src/distribution/tensor_parallel/state_action_keras_test.py @@ -0,0 +1,109 @@ +import pytest + +import keras +from keras.src import backend +from keras.src import testing +from keras.src.distribution.tensor_parallel.state_action_keras import ( + GatherKeras, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +from keras.src.distribution.tensor_parallel.state_action_keras import SumKeras + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +class TestStateActions(testing.TestCase): + """Test suite for tensor distribution state actions.""" + + def test_split_keras_even_split(self): + """Tests SplitKeras with a tensor that divides evenly.""" + world_size = 4 + tensor = keras.ops.reshape( + keras.ops.arange(16, dtype="float32"), (4, 4) + ) + + action_row = SplitKeras(world_size=world_size, dim=0) + shards_row = [action_row(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards_row[0].shape, (1, 4)) + self.assertAllClose(shards_row[0], tensor[0:1, :]) + self.assertAllClose(shards_row[3], tensor[3:4, :]) + + reconstructed_row = action_row.undo(shards_row) + self.assertAllClose(reconstructed_row, tensor) + + action_col = SplitKeras(world_size=world_size, dim=1) + shards_col = [action_col(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards_col[0].shape, (4, 1)) + self.assertAllClose(shards_col[0], tensor[:, 0:1]) + self.assertAllClose(shards_col[2], tensor[:, 2:3]) + + reconstructed_col = action_col.undo(shards_col) + self.assertAllClose(reconstructed_col, tensor) + + def test_split_keras_uneven_split(self): + """Tests SplitKeras with a tensor that does not divide evenly.""" + world_size = 3 + tensor = keras.ops.reshape( + keras.ops.arange(40, dtype="float32"), (4, 10) + ) + + action = SplitKeras(world_size=world_size, dim=1) + shards = [action(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards[0].shape, (4, 4)) + self.assertEqual(shards[1].shape, (4, 3)) + self.assertEqual(shards[2].shape, (4, 3)) + + self.assertAllClose(shards[0], tensor[:, 0:4]) + self.assertAllClose(shards[1], tensor[:, 4:7]) + self.assertAllClose(shards[2], tensor[:, 7:10]) + + reconstructed = action.undo(shards) + self.assertAllClose(reconstructed, tensor) + + def test_split_keras_sharding_type_inference(self): + """Tests that `sharding_type` correctly infers the split dimension.""" + action_row = SplitKeras(world_size=2, dim=-1, sharding_type="row") + self.assertEqual(action_row.dim, 0) + + action_col = SplitKeras(world_size=2, dim=-1, sharding_type="column") + self.assertEqual(action_col.dim, 1) + + def 
test_gather_keras(self): + """Tests the GatherKeras action.""" + world_size = 4 + action = GatherKeras(world_size=world_size, dim=0) + tensor = keras.ops.array([[1, 2], [3, 4]], dtype="float32") + + processed_tensor = action(tensor, rank=0) + self.assertAllClose(processed_tensor, tensor) + + tensors_to_gather = [ + keras.ops.ones((2, 2)), + keras.ops.zeros((2, 2)), + keras.ops.ones((2, 2)), + ] + reconstructed = action.undo(tensors_to_gather) + expected = keras.ops.concatenate(tensors_to_gather, axis=0) + self.assertAllClose(reconstructed, expected) + + def test_sum_keras(self): + """Tests the SumKeras action.""" + world_size = 2 + action = SumKeras(world_size=world_size) + tensor = keras.ops.array([[1, 2], [3, 4]], dtype="float32") + + processed_tensor = action(tensor, rank=0) + self.assertAllClose(processed_tensor, tensor) + + tensors_to_sum = [ + keras.ops.full((2, 3), 5.0), + keras.ops.full((2, 3), 10.0), + ] + reconstructed = action.undo(tensors_to_sum) + expected = keras.ops.full((2, 3), 15.0) + self.assertAllClose(reconstructed, expected) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 11e4046c7b8a..ba4abbe1139a 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -39,7 +39,8 @@ from keras.src.backend.common.remat import get_current_remat_mode from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.backend.config import is_nnx_enabled -from keras.src.distribution import distribution_lib + +# from keras.src.distribution import distribution_lib from keras.src.dtype_policies import DTypePolicyMap from keras.src.layers import input_spec from keras.src.metrics.metric import Metric @@ -942,6 +943,8 @@ def maybe_convert(x): # Change the layout for the layer output if needed. # This is useful for relayout intermediate tensor in the model # to achieve the optimal performance. 
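Moving the `distribution_lib` import into the call site below is the standard way to break an import cycle: the module is resolved lazily at call time, after both packages have finished importing. A minimal sketch of the pattern, with an illustrative function name rather than the actual layer method:

def get_active_distribution():
    # Deferred import: resolved only when the function runs, so module
    # initialization order no longer matters.
    from keras.src.distribution import distribution_lib

    return distribution_lib.distribution()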
+ from keras.src.distribution import distribution_lib + distribution = distribution_lib.distribution() if distribution is not None: current_layer_path = current_path() From 495f95a3d50e6bb2bb3dc2cbf1dc1dceafcc577a Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 16 Oct 2025 00:42:29 +0530 Subject: [PATCH 6/8] PR1,2,3 --- keras/src/backend/jax/__init__.py | 1 + keras/src/backend/jax/distributed_backend.py | 95 +++ .../src/backend/torch/distributed_backend.py | 266 +++---- .../backend/torch/distributed_backend_test.py | 80 +-- .../tensor_parallel/autoconfig.py | 282 ++++++++ .../tensor_parallel/autoconfig_test.py | 245 ------- .../tensor_parallel/communications_test.py | 84 --- .../tensor_parallel/config_test.py | 96 --- .../tensor_parallel/coordinated_optimizer.py | 653 ++++++++++++++++++ .../coordinated_optimizer_test.py | 180 ----- .../tensor_parallel/parameter_sharding.py | 536 +++++--------- .../parameter_sharding_test.py | 19 +- .../state_action_keras_test.py | 109 --- .../tensor_parallel/tensor_layout.py | 154 +++++ 14 files changed, 1466 insertions(+), 1334 deletions(-) create mode 100644 keras/src/backend/jax/distributed_backend.py create mode 100644 keras/src/distribution/tensor_parallel/autoconfig.py delete mode 100644 keras/src/distribution/tensor_parallel/autoconfig_test.py delete mode 100644 keras/src/distribution/tensor_parallel/communications_test.py delete mode 100644 keras/src/distribution/tensor_parallel/config_test.py create mode 100644 keras/src/distribution/tensor_parallel/coordinated_optimizer.py delete mode 100644 keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py delete mode 100644 keras/src/distribution/tensor_parallel/state_action_keras_test.py create mode 100644 keras/src/distribution/tensor_parallel/tensor_layout.py diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 89ac0fa71c8c..0a275fb70cf1 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,5 +1,6 @@ from keras.src.backend.config import is_nnx_enabled from keras.src.backend.jax import core +from keras.src.backend.jax import distributed_backend from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import image from keras.src.backend.jax import linalg diff --git a/keras/src/backend/jax/distributed_backend.py b/keras/src/backend/jax/distributed_backend.py new file mode 100644 index 000000000000..3ed50d756250 --- /dev/null +++ b/keras/src/backend/jax/distributed_backend.py @@ -0,0 +1,95 @@ +import jax +import jax.lax as lax + + +def get_device_info(): + """Retrieves information about the available JAX devices. + + This function queries the JAX backend to identify the type and number + of available computational devices (e.g., CPU, GPU, TPU). + + Returns: + dict: A dictionary containing the backend name ('jax'), a list of + device string representations, and the total count of devices. + """ + available_devices = jax.devices() + return { + "backend": "jax", + "devices": [str(d) for d in available_devices], + "device_count": len(available_devices), + } + + +def is_multi_device_capable(): + """Checks if more than one JAX device is available for computation. + + Returns: + bool: True if the local JAX environment has more than one device, + False otherwise. + """ + return jax.local_device_count() > 1 + + +def get_communication_ops(): + """Provides a dictionary of JAX collective communication operations. 
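The collectives returned by `get_communication_ops` assume they run inside a mapped context that defines the axis name (default "model"). A small self-contained demo of the underlying `lax.psum` primitive under `jax.pmap`; it bypasses the Keras wrapper and works even on a single device:

import functools

import jax
import jax.lax as lax
import jax.numpy as jnp


@functools.partial(jax.pmap, axis_name="model")
def summed(x):
    # Sums the per-device values across the "model" axis.
    return lax.psum(x, axis_name="model")

n = jax.local_device_count()
out = summed(jnp.ones((n, 4)))  # every element equals n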
+ + Returns: + dict: A dictionary mapping operation names (e.g., 'all_reduce') to their + corresponding JAX implementation functions. + """ + + def all_reduce(x, op="sum", axis_name="model"): + """Reduces a tensor across a device mesh axis using a collective. + + This function assumes it is called within a `pjit` context that has a + device mesh with the specified `axis_name`. It performs a collective + reduction operation (like sum or mean) across all devices mapped to + that axis. + + Args: + x (jax.Array): The input JAX array (tensor) on the local device. + op (str, optional): The reduction operation to perform. Supported + values are 'sum' and 'mean'. Defaults to 'sum'. + axis_name (str, optional): The name of the mapped axis in the device + mesh over which to communicate. Defaults to 'model'. + + Returns: + jax.Array: The reduced JAX array, which is identical across all + devices participating in the reduction. + """ + if op == "sum": + return lax.psum(x, axis_name=axis_name) + elif op == "mean": + return lax.pmean(x, axis_name=axis_name) + else: + raise ValueError( + f"Unsupported reduction operation: {op}. " + "Supported options are 'sum' and 'mean'." + ) + + def all_gather(x, axis, axis_name="model"): + """Gathers and concatenates tensors from all devices across a mesh axis. + + This function assumes it is called within a `pjit` context. It takes + the local shard `x` from each device along the `axis_name` of the mesh + and concatenates them along the specified tensor `axis` to form a + single, larger tensor that is then replicated on all + participating devices. + + Args: + x (jax.Array): The input JAX array (tensor) shard on local device. + axis (int): The tensor axis along which to concatenate the gathered + shards. + axis_name (str, optional): The name of the mesh axis to gather + from. Defaults to 'model'. + + Returns: + jax.Array: The full, gathered JAX array, which is identical across + all devices participating in the gather. + """ + return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True) + + return { + "all_reduce": all_reduce, + "all_gather": all_gather, + } diff --git a/keras/src/backend/torch/distributed_backend.py b/keras/src/backend/torch/distributed_backend.py index 432760339102..af00b07f87fe 100644 --- a/keras/src/backend/torch/distributed_backend.py +++ b/keras/src/backend/torch/distributed_backend.py @@ -1,220 +1,142 @@ -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Literal +import os import torch import torch.distributed as dist -def compute_gradients( - loss: torch.Tensor, trainable_vars: List[torch.Tensor] -) -> List[torch.Tensor]: - """Computes gradients of the loss with respect to trainable variables. - - This function leverages PyTorch's `autograd.grad` for a stateless, - functional approach similar to `jax.grad`. - - Args: - loss (torch.Tensor): The loss value for which to compute gradients. - trainable_vars (List[torch.Tensor]): A list of variables (tensors with - `requires_grad=True`) to compute gradients with respect to. - - Returns: - List[torch.Tensor]: A list of gradients corresponding to the - trainable variables. - """ - return list(torch.autograd.grad(loss, trainable_vars)) - - -def apply_gradients( - gradients: List[torch.Tensor], - trainable_vars: List[torch.Tensor], - learning_rate: float = 0.001, -) -> List[torch.Tensor]: - """Applies gradients and returns the updated variables. 
- - Updates are performed in-place within a `torch.no_grad()` context - to prevent the update operation from being part of the computation graph. - """ - with torch.no_grad(): - updated_vars = [] - for grad, var in zip(gradients, trainable_vars): - if grad is not None: - var.sub_(learning_rate * grad) - updated_vars.append(var) - return updated_vars - - -def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]: - """Creates a configuration dictionary for a PyTorch optimizer. - - This function returns a dictionary containing the optimizer's configuration, - maintaining a consistent interface with the JAX backend. The user is - expected to instantiate the optimizer from this config. - - Args: - optimizer_class (str): The name of the optimizer to create (e.g., - `"adam"`, `"sgd"`). - **kwargs: Keyword arguments for the optimizer (e.g., `learning_rate`). - - Returns: - Dict[str, Any]: A dictionary representing the optimizer configuration. - """ - config = kwargs.copy() - config["name"] = optimizer_class.lower() - config.setdefault("learning_rate", 0.001) - return config - - -def get_device_info() -> Dict[str, Any]: +def get_device_info(): """Retrieves information about the available PyTorch devices. + This function queries PyTorch to identify the type and number of + available computational devices (e.g., CPU, GPU). + Returns: - Dict[str, Any]: A dictionary containing the backend name, a list of - available device strings, and the total device count. + dict: A dictionary containing the backend name ('torch'), a list of + device string representations, and the total count of devices. """ if torch.cuda.is_available(): device_count = torch.cuda.device_count() - devices = [torch.cuda.get_device_name(i) for i in range(device_count)] + devices = [ + f"cuda:{i} ({torch.cuda.get_device_name(i)})" + for i in range(device_count) + ] + backend = "torch (CUDA)" else: device_count = 1 devices = ["cpu"] + backend = "torch (CPU)" + return { - "backend": "pytorch", + "backend": backend, "devices": devices, "device_count": device_count, } -def is_multi_device_capable() -> bool: - """Checks if more than one CUDA device is available. +def is_multi_device_capable(): + """Checks if more than one device is available for distributed computation. Returns: - bool: `True` if PyTorch reports more than one CUDA device, `False` - otherwise. + bool: True if the PyTorch distributed environment is initialized and + has a world size greater than one, False otherwise. + """ + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() > 1 + elif torch.cuda.is_available(): + return torch.cuda.device_count() > 1 + return False + + +def setup_distributed_environment(): + """ + A helper function to initialize the distributed process group. + + This is a prerequisite for using the communication operations. + In a real application, this would be called at the start of the script. + It uses environment variables commonly set by launchers like torchrun. """ - return torch.cuda.device_count() > 1 + if dist.is_available() and not dist.is_initialized(): + required_env_vars = ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"] + if not all(v in os.environ for v in required_env_vars): + return False + + dist.init_process_group(backend="nccl") + return True + elif dist.is_initialized(): + return True + else: + return False -def get_communication_ops() -> Dict[str, Callable]: +def get_communication_ops(): """Provides a dictionary of PyTorch collective communication operations. 
- These operations rely on the `torch.distributed` package. They are - designed to work in a multi-process, multi-device environment. If the - distributed package is not initialized, they provide a sensible fallback - for single-device execution. + Note: The torch.distributed process group must be initialized before + calling these functions. Returns: - Dict[str, Callable]: A dictionary mapping operation names to their - PyTorch implementations. + dict: A dictionary mapping operation names (e.g., 'all_reduce') to their + corresponding PyTorch implementation functions. """ - def _is_distributed() -> bool: - """Checks if the default process group is initialized.""" - return dist.is_available() and dist.is_initialized() - - def all_reduce( - x: torch.Tensor, - op: Literal["sum", "mean"] = "sum", - axis_name: str = None, - ) -> torch.Tensor: - """Reduces a tensor across all devices.""" - if not _is_distributed(): - world_size = ( - torch.cuda.device_count() if torch.cuda.is_available() else 1 + def all_reduce(x, op="sum"): + """Reduces a tensor across all devices in the process group. + + This function performs a collective reduction operation + across all devices in the distributed group. + + Args: + x (torch.Tensor): The input tensor on the local device. + op (str, optional): The reduction operation to perform. Supported + values are 'sum' and 'mean'. Defaults to 'sum'. + + Returns: + torch.Tensor: The reduced tensor, which is identical across all + devices participating in the reduction. + """ + if not (dist.is_available() and dist.is_initialized()): + return x + + if op == "sum": + reduce_op = dist.ReduceOp.SUM + elif op == "mean": + reduce_op = dist.ReduceOp.AVG + else: + raise ValueError( + f"Unsupported reduction operation: {op}. " + "Supported options are 'sum' and 'mean'." ) - if world_size <= 1: - return x - if op == "sum": - return x * float(world_size) - elif op == "mean": - return x - else: - raise ValueError(f"Unsupported all_reduce op: {op}") - - reduce_op = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.AVG}.get( - op - ) - if reduce_op is None: - raise ValueError(f"Unsupported all_reduce op: {op}") result = x.clone() dist.all_reduce(result, op=reduce_op) return result - def all_gather( - x: torch.Tensor, axis: int = 0, axis_name: str = None - ) -> torch.Tensor: - """Gathers tensors from all devices and concatenates them.""" - if not _is_distributed(): - world_size = ( - torch.cuda.device_count() if torch.cuda.is_available() else 1 - ) - if world_size <= 1: - return x - return torch.cat([x] * world_size, dim=axis) + def all_gather(x, axis): + """Gathers and concatenates tensors from all devices. - world_size = dist.get_world_size() - tensor_list = [torch.empty_like(x) for _ in range(world_size)] - dist.all_gather(tensor_list, x) - return torch.cat(tensor_list, dim=axis) + This function takes the local tensor `x` from each device and + concatenates them along the specified tensor `axis` to form a single, + larger tensor that is then replicated on all participating devices. - def broadcast( - x: torch.Tensor, root: int = 0, axis_name: str = None - ) -> torch.Tensor: - """Broadcasts a tensor from a root device to all other devices.""" - if not _is_distributed(): - return x + Args: + x (torch.Tensor): The input tensor shard on the local device. + axis (int): The tensor axis along which to concatenate the gathered + shards. 
- dist.broadcast(x, src=root) - return x - - def scatter( - x: torch.Tensor, - root: int = 0, - axis: int = 0, - axis_name: str = None, - ) -> torch.Tensor: - """Scatters a tensor from a root device to all devices.""" - if not _is_distributed(): - world_size = ( - torch.cuda.device_count() if torch.cuda.is_available() else 1 - ) - if world_size <= 1: - return x - if x.shape[axis] % world_size != 0: - raise ValueError( - f"Tensor with shape {x.shape} cannot be scattered along " - f"axis {axis} across {world_size} devices." - ) - return torch.chunk(x, world_size, dim=axis)[0] + Returns: + torch.Tensor: The full, gathered tensor, which is identical across + all devices participating in the gather. + """ + if not (dist.is_available() and dist.is_initialized()): + return x world_size = dist.get_world_size() - rank = dist.get_rank() - - if x.shape[axis] % world_size != 0: - raise ValueError( - f"Tensor with shape {x.shape} cannot be scattered along " - f"axis {axis} across {world_size} devices." - ) - - if rank == root: - scatter_list = list(torch.chunk(x, world_size, dim=axis)) - else: - scatter_list = None - - chunk_shape = list(x.shape) - chunk_shape[axis] //= world_size - local_chunk = torch.empty(chunk_shape, dtype=x.dtype, device=x.device) + tensor_list = [torch.empty_like(x) for _ in range(world_size)] - dist.scatter(local_chunk, scatter_list, src=root) - return local_chunk + dist.all_gather(tensor_list, x) + return torch.cat(tensor_list, dim=axis) return { "all_reduce": all_reduce, "all_gather": all_gather, - "broadcast": broadcast, - "scatter": scatter, } diff --git a/keras/src/backend/torch/distributed_backend_test.py b/keras/src/backend/torch/distributed_backend_test.py index d6dce5977b81..cbf4766b1c9c 100644 --- a/keras/src/backend/torch/distributed_backend_test.py +++ b/keras/src/backend/torch/distributed_backend_test.py @@ -12,75 +12,10 @@ class TestPytorchDistributedFunctions: """Unit tests for the PyTorch distributed backend standalone functions.""" - def test_compute_gradients_computes_correctly(self): - """Test that compute_gradients returns correct gradients.""" - w = torch.tensor([2.0, 3.0], requires_grad=True) - b = torch.tensor(1.0, requires_grad=True) - x = torch.tensor([4.0, 5.0]) - y_true = torch.tensor(25.0) - - # loss = (w.x + b - y_true)^2 = ((2*4 + 3*5 + 1) - 25)^2 = (24-25)^2 = 1 - y_pred = torch.dot(w, x) + b - loss = (y_pred - y_true) ** 2 - - trainable_vars = [w, b] - gradients = distributed_backend.compute_gradients(loss, trainable_vars) - - # d_loss/d_w = 2*(y_pred - y_true)*x = 2*(-1)*[4, 5] = [-8, -10] - # d_loss/d_b = 2*(y_pred - y_true)*1 = 2*(-1)*1 = -2 - expected_grad_w = torch.tensor([-8.0, -10.0]) - expected_grad_b = torch.tensor(-2.0) - - assert len(gradients) == 2 - torch.testing.assert_close(gradients[0], expected_grad_w) - torch.testing.assert_close(gradients[1], expected_grad_b) - - def test_apply_gradients(self): - """Test the application of gradients to PyTorch tensors.""" - var1 = torch.tensor([1.0, 2.0], requires_grad=True) - var2 = torch.tensor(5.0, requires_grad=True) - trainable_vars = [var1, var2] - grad1 = torch.tensor([0.1, 0.2]) - grad2 = torch.tensor(0.5) - gradients = [grad1, grad2] - learning_rate = 0.1 - - original_var1 = var1.clone() - original_var2 = var2.clone() - - updated_vars = distributed_backend.apply_gradients( - gradients, trainable_vars, learning_rate - ) - - assert updated_vars[0] is var1 - assert updated_vars[1] is var2 - - expected_var1 = original_var1 - (grad1 * learning_rate) - expected_var2 = original_var2 - (grad2 * 
learning_rate) - torch.testing.assert_close(updated_vars[0], expected_var1) - torch.testing.assert_close(updated_vars[1], expected_var2) - - def test_create_optimizer(self): - """Test optimizer configuration creation.""" - adam_config = distributed_backend.create_optimizer( - "adam", learning_rate=0.01 - ) - assert isinstance(adam_config, dict) - assert adam_config["name"] == "adam" - assert adam_config["learning_rate"] == 0.01 - - sgd_config = distributed_backend.create_optimizer( - "sgd", learning_rate=0.1, momentum=0.9 - ) - assert isinstance(sgd_config, dict) - assert sgd_config["name"] == "sgd" - assert sgd_config["learning_rate"] == 0.1 - assert sgd_config["momentum"] == 0.9 - def test_get_device_info(self): """Test retrieving device information from the PyTorch backend.""" info = distributed_backend.get_device_info() - assert info["backend"] == "pytorch" + assert info["backend"] == "torch (CPU)" assert isinstance(info["devices"], list) assert isinstance(info["device_count"], int) assert info["device_count"] > 0 @@ -114,16 +49,3 @@ def test_communication_ops_simulation_logic(self): gathered = comm_ops["all_gather"](x_gather, axis=0) expected_gather = torch.cat([x_gather] * world_size, dim=0) torch.testing.assert_close(gathered, expected_gather) - - # Test broadcast - x_broadcast = torch.tensor([5.0, 6.0]) - broadcasted = comm_ops["broadcast"](x_broadcast) - torch.testing.assert_close(broadcasted, x_broadcast) - - # Test scatter - if world_size > 0: - scatter_data = torch.arange(world_size * 4, dtype=torch.float32) - x_scatter = scatter_data.reshape(world_size * 2, 2) - scattered = comm_ops["scatter"](x_scatter, axis=0) - expected_scatter = torch.chunk(x_scatter, world_size, dim=0)[0] - torch.testing.assert_close(scattered, expected_scatter) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py new file mode 100644 index 000000000000..708d6d603cc6 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -0,0 +1,282 @@ +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap +from keras.src.distribution.tensor_parallel.tensor_layout import Split + + +def analyze_dense_layer_directly(layer, module, prefix): + """Analyzes a Keras Dense layer to classify its sharding strategy. + + This function inspects the input and output dimensions of a Dense layer + to determine if it functions as an expansion layer ("up-projection"), a + contraction layer ("down-projection"), or neither ("generic_dense"). This + classification is a heuristic commonly used to apply tensor parallelism + in Transformer-based models, such as in an MLP block where an up-projection + is followed by a down-projection. + + Args: + layer: The Keras `layers.Dense` instance to analyze. + module: The parent module containing the layer (currently unused). + prefix (str): The name prefix for the layer in the model hierarchy + (currently unused). + + Returns: + str: A string classifying the layer as 'up_projection', + 'down_projection', or 'generic_dense'. 
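+
+    Example (an illustrative sketch; assumes the layers have been built so
+    their kernel shapes are known):
+
+    ```python
+    from keras import layers
+
+    up = layers.Dense(32)
+    up.build(input_shape=(None, 16))    # 16 -> 32 exceeds the 1.5x threshold
+    analyze_dense_layer_directly(up, None, "")    # "up_projection"
+
+    down = layers.Dense(16)
+    down.build(input_shape=(None, 32))  # 32 -> 16 exceeds the 1.5x threshold
+    analyze_dense_layer_directly(down, None, "")  # "down_projection"
+    ```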
+ """ + from keras.src import layers + + if not isinstance(layer, layers.Dense): + return "generic_dense" + + input_dim = None + output_dim = None + + if hasattr(layer, "kernel") and layer.kernel is not None: + kernel_shape = layer.kernel.shape + if len(kernel_shape) == 2: + input_dim = kernel_shape[0] + output_dim = kernel_shape[1] + + if input_dim is None or output_dim is None: + if hasattr(layer, "units"): + output_dim = layer.units + else: + return "generic_dense" + + if ( + hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): + input_dim = layer.input_shape[-1] + else: + return "generic_dense" + + if not input_dim or not output_dim: + return "generic_dense" + + expansion_threshold = 1.5 + is_expansion = output_dim > input_dim * expansion_threshold + is_contraction = input_dim > output_dim * expansion_threshold + + if is_expansion: + return "up_projection" + elif is_contraction: + return "down_projection" + else: + return "generic_dense" + + +def _find_and_shard_layers( + current_layer, + prefix, + module, + world_size, + state_rules, + output_rules, + processed_layers, +): + """Recursively traverses the model graph to apply sharding rules. + + This function walks through all nested layers of a given Keras model or + layer. For each encountered layer, it determines an appropriate tensor + parallelism strategy and populates the `state_rules` and `output_rules` + dictionaries with the corresponding sharding actions. It uses a set of + processed layer IDs to avoid redundant processing of shared layers. + + The sharding logic is as follows: + - `Dense` layers are sharded based on their classification (up/down proj). + - Up-projections are split along the column axis (output features). + - Down-projections are split along the row axis (input features). + - `EinsumDense` layers in attention blocks are sharded similarly. + - `Embedding` layers are sharded column-wise for vocabulary parallelism. + - Normalization layers are ignored (replicated on all devices). + + Args: + current_layer: The Keras layer currently being processed. + prefix (str): The hierarchical name prefix for the `current_layer`. + module: The top-level Keras model or layer being configured. + world_size (int): The total number of devices for sharding. + state_rules (Dict[str, Any]): A dictionary to be populated with rules + for sharding layer weights (state). Keys are regex patterns + matching weight names, values are `SplitKeras` actions. + output_rules (Dict[str, Any]): A dictionary to be populated with rules + for handling layer outputs. Keys are regex patterns matching layer + names, values describe the communication op (e.g., 'allreduce'). + processed_layers (Set[int]): A set of `id()`s of layers that have + already been processed to prevent cycles and redundant work. 
+ """ + from keras.src import layers + + if id(current_layer) in processed_layers: + return + processed_layers.add(id(current_layer)) + + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if isinstance(current_layer, layers.Dense): + mlp_type = analyze_dense_layer_directly( + current_layer, module, full_name + ) + + if mlp_type == "up_projection": + state_rules[f"^{full_name}.kernel$"] = Split( + world_size, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = Split( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "gather"} + + elif mlp_type == "down_projection": + state_rules[f"^{full_name}.kernel$"] = Split(world_size, 0, "row") + output_rules[f"^{full_name}$"] = {0: "allreduce"} + + else: + state_rules[f"^{full_name}.kernel$"] = Split( + world_size, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = Split( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "gather -1"} + return + + elif isinstance(current_layer, layers.EinsumDense): + if "attention_output" in full_name: + state_rules[f"^{full_name}.kernel$"] = Split(world_size, 0, "row") + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + pass + output_rules[f"^{full_name}$"] = {0: "allreduce"} + else: + state_rules[f"^{full_name}.kernel$"] = Split( + world_size, 1, "column" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = Split( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "gather -1"} + return + + elif isinstance(current_layer, (layers.Embedding,)): + if hasattr(current_layer, "token_embedding") or hasattr( + current_layer, "position_embedding" + ): + pass + else: + weight_name = None + if hasattr(current_layer, "embeddings"): + weight_name = "embeddings" + elif hasattr(current_layer, "position_embeddings"): + weight_name = "position_embeddings" + + if weight_name: + state_rules[f"^{full_name}\\..*{weight_name}$"] = Split( + world_size, 1, "column" + ) + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance( + current_layer, + ( + layers.LayerNormalization, + layers.BatchNormalization, + layers.GroupNormalization, + ), + ): + return + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + _find_and_shard_layers( + sub_layer, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and attr_name.endswith("__"): + continue + if hasattr(current_layer, attr_name): + attr = getattr(current_layer, attr_name) + + if isinstance(attr, layers.Layer) and attr is not current_layer: + _find_and_shard_layers( + attr, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + _find_and_shard_layers( + item, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + + +def get_default_config_keras(module, device_ids): + """Generates default tensor parallelism sharding configuration for a model. + + This function serves as entry point for automatically creating a sharding + plan for a given Keras model or layer. 
It initializes the rule dictionaries + and starts the recursive layer traversal to populate them based on a default + set of heuristics for common architectures like Transformers. + + Example: + ```python + model = MyTransformerModel() + device_ids = ["gpu:0", "gpu:1"] + sharding_config = get_default_config_keras(model, device_ids) + # sharding_config can now be used to distribute the model + ``` + + Args: + module: The Keras `Model` or `Layer` to generate a config for. + device_ids (Sequence[str]): A sequence of device IDs (e.g., + ["gpu:0", "gpu:1"]) across which the model will be sharded. + + Returns: + ConfigKeras: A configuration object containing the generated sharding + rules for model weights (`state_rules`) and layer outputs + (`output_rules`). + """ + world_size = len(device_ids) + state_rules = {} + output_rules = {} + processed_layers = set() + + _find_and_shard_layers( + current_layer=module, + prefix="", + module=module, + world_size=world_size, + state_rules=state_rules, + output_rules=output_rules, + processed_layers=processed_layers, + ) + + return LayoutMap(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py deleted file mode 100644 index 470c00774660..000000000000 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ /dev/null @@ -1,245 +0,0 @@ -import os - -import pytest - -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" - -from keras import Input -from keras import Model -from keras import layers -from keras.src import backend -from keras.src import testing -from keras.src.distribution import distributed_backend -from keras.src.distribution.tensor_parallel.autoconfig import ( - analyze_dense_layer_directly, -) -from keras.src.distribution.tensor_parallel.autoconfig import ( - get_default_config_keras, -) -from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras - - -@pytest.mark.skipif( - backend.backend() not in ("torch", "jax") - or distributed_backend.get_device_info()["device_count"] <= 1, - reason="This test is for JAX/PyTorch backends and requires > 1 device.", -) -class TestAutoConfigKeras(testing.TestCase): - def setUp(self): - """Set up the test case and common variables.""" - super().setUp() - device_info = distributed_backend.get_device_info() - self.world_size = device_info["device_count"] - self.device_ids = [f"cpu:{i}" for i in range(self.world_size)] - - self.assertGreater( - self.world_size, 1, "Distribution tests require more than 1 device." 
- ) - - def _assert_split_keras_equal(self, rule1, rule2): - """Helper to compare two SplitKeras objects by their attributes.""" - self.assertIsInstance(rule1, SplitKeras) - self.assertIsInstance(rule2, SplitKeras) - self.assertDictEqual(vars(rule1), vars(rule2)) - - def _assert_rules_equal(self, actual_rules, expected_rules): - """Helper to compare two dictionaries of sharding rules.""" - self.assertSetEqual( - set(actual_rules.keys()), set(expected_rules.keys()) - ) - for key in expected_rules: - actual_val = actual_rules[key] - expected_val = expected_rules[key] - if isinstance(expected_val, SplitKeras): - self._assert_split_keras_equal(actual_val, expected_val) - else: - self.assertEqual(actual_val, expected_val) - - def test_analyze_dense_layer(self): - """Tests the direct analysis and classification of Dense layers.""" - up_proj_layer = layers.Dense(32) - up_proj_layer.build(input_shape=(None, 16)) - self.assertEqual( - analyze_dense_layer_directly(up_proj_layer, None, ""), - "up_projection", - ) - - down_proj_layer = layers.Dense(16) - down_proj_layer.build(input_shape=(None, 32)) - self.assertEqual( - analyze_dense_layer_directly(down_proj_layer, None, ""), - "down_projection", - ) - - generic_layer = layers.Dense(20) - generic_layer.build(input_shape=(None, 16)) - self.assertEqual( - analyze_dense_layer_directly(generic_layer, None, ""), - "generic_dense", - ) - - def test_simple_mlp_sharding(self): - """Tests a simple MLP with up and down projection layers.""" - inputs = Input(shape=(64,)) - x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) - outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( - x - ) - model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") - - config = get_default_config_keras(model, self.device_ids) - - expected_state_rules = { - r"^simple_mlp.up_projection_layer.kernel$": SplitKeras( - self.world_size, 1, "column" - ), - r"^simple_mlp.up_projection_layer.bias$": SplitKeras( - self.world_size, 0, "column" - ), - r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( - self.world_size, 0, "row" - ), - # Bias for down-projection is not sharded according to the new logic - } - expected_output_rules = { - r"^simple_mlp.up_projection_layer$": {0: "gather"}, - r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, - } - - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) - - def test_generic_dense_sharding(self): - """Tests a generic Dense layer that isn't an up/down projection.""" - inputs = Input(shape=(64,)) - outputs = layers.Dense(80, name="generic_layer", use_bias=True)(inputs) - model = Model(inputs=inputs, outputs=outputs, name="generic_model") - - config = get_default_config_keras(model, self.device_ids) - - expected_state_rules = { - r"^generic_model.generic_layer.kernel$": SplitKeras( - self.world_size, 1, "column" - ), - r"^generic_model.generic_layer.bias$": SplitKeras( - self.world_size, 0, "column" - ), - } - expected_output_rules = { - r"^generic_model.generic_layer$": {0: "gather -1"} - } - - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) - - def test_embedding_sharding(self): - """Tests an Embedding layer for vocabulary parallelism.""" - inputs = Input(shape=(10,), dtype="int32") - outputs = layers.Embedding( - input_dim=1000, output_dim=128, name="token_embedding" - )(inputs) - model = Model(inputs=inputs, 
outputs=outputs, name="embed_model") - - config = get_default_config_keras(model, self.device_ids) - - expected_state_rules = { - # FIX: Removed the incorrect backslash before ".token_embedding" - r"^embed_model.token_embedding\..*embeddings$": SplitKeras( - self.world_size, 1, "column" - ) - } - expected_output_rules = { - r"^embed_model.token_embedding$": {0: "no_comm"} - } - - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) - - def test_einsum_dense_sharding(self): - """Tests the special handling for EinsumDense layers.""" - inputs = Input(shape=(64,)) - x = layers.EinsumDense( - "bh,hd->bd", output_shape=128, name="query_proj" - )(inputs) - outputs = layers.EinsumDense( - "bd,dh->bh", output_shape=64, name="attention_output" - )(x) - model = Model(inputs=inputs, outputs=outputs, name="einsum_model") - - config = get_default_config_keras(model, self.device_ids) - - expected_state_rules = { - r"^einsum_model.query_proj.kernel$": SplitKeras( - self.world_size, 1, "column" - ), - r"^einsum_model.attention_output.kernel$": SplitKeras( - self.world_size, 0, "row" - ), - } - expected_output_rules = { - r"^einsum_model.query_proj$": {0: "gather -1"}, - r"^einsum_model.attention_output$": {0: "allreduce"}, - } - - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) - - def test_normalization_layers_ignored(self): - """Tests that normalization layers are correctly ignored.""" - inputs = Input(shape=(64,)) - x = layers.Dense(64, name="dense1", use_bias=True)(inputs) - x = layers.LayerNormalization(name="layernorm")(x) - outputs = layers.Dense(64, name="dense2", use_bias=True)(x) - model = Model(inputs=inputs, outputs=outputs, name="norm_model") - - config = get_default_config_keras(model, self.device_ids) - - for key in config.state_rules: - self.assertNotIn("layernorm", key) - for key in config.output_rules: - self.assertNotIn("layernorm", key) - - self.assertIn(r"^norm_model.dense1.kernel$", config.state_rules) - self.assertIn(r"^norm_model.dense2.kernel$", config.state_rules) - self.assertEqual(len(config.state_rules), 4) - self.assertEqual(len(config.output_rules), 2) - - def test_nested_model_sharding(self): - """Tests that the traversal logic correctly handles nested models.""" - inner_inputs = Input(shape=(32,)) - inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( - inner_inputs - ) - inner_model = Model( - inputs=inner_inputs, outputs=inner_outputs, name="inner_block" - ) - - outer_inputs = Input(shape=(32,)) - x = inner_model(outer_inputs) - outer_outputs = layers.Dense(32, name="outer_dense", use_bias=True)(x) - outer_model = Model( - inputs=outer_inputs, outputs=outer_outputs, name="outer_model" - ) - - config = get_default_config_keras(outer_model, self.device_ids) - - expected_state_rules = { - r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( - self.world_size, 1, "column" - ), - r"^outer_model.inner_block.inner_dense.bias$": SplitKeras( - self.world_size, 0, "column" - ), - r"^outer_model.outer_dense.kernel$": SplitKeras( - self.world_size, 0, "row" - ), - # Bias for down-projection is not sharded according to the new logic - } - expected_output_rules = { - r"^outer_model.inner_block.inner_dense$": {0: "gather"}, - r"^outer_model.outer_dense$": {0: "allreduce"}, - } - - self.maxDiff = None - self._assert_rules_equal(config.state_rules, expected_state_rules) - 
self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/communications_test.py b/keras/src/distribution/tensor_parallel/communications_test.py deleted file mode 100644 index 1b6d95d37664..000000000000 --- a/keras/src/distribution/tensor_parallel/communications_test.py +++ /dev/null @@ -1,84 +0,0 @@ -import pytest - -import keras -from keras.src import backend -from keras.src import testing -from keras.src.backend import distributed_backend -from keras.src.distribution.tensor_parallel.communications import AllGatherKeras -from keras.src.distribution.tensor_parallel.communications import AllReduceKeras -from keras.src.distribution.tensor_parallel.communications import BroadcastKeras -from keras.src.distribution.tensor_parallel.communications import ( - TensorParallelCommunicator, -) - - -@pytest.mark.skipif( - backend.backend() not in ("torch", "jax"), - reason="This test is for JAX/PyTorch backends.", -) -class TestCollectiveOps(testing.TestCase): - """ - Tests collective communication ops on a JAX distributed backend. - """ - - def setUp(self): - super().setUp() - device_info = distributed_backend.get_device_info() - self.world_size = device_info.get("device_count", 1) - - if not self.world_size: - self.world_size = 1 - - self.axis_name = "data" - - def test_all_reduce(self): - """Tests the all-reduce operation.""" - all_reduce_op = AllReduceKeras(world_size=self.world_size, op="sum") - local_tensor = keras.ops.array([1.0, 2.0, 3.0]) - - result = all_reduce_op(local_tensor, axis_name=self.axis_name) - - expected_output = keras.ops.multiply( - local_tensor, float(self.world_size) - ) - self.assertAllClose(result, expected_output) - - def test_all_gather(self): - """Tests the all-gather operation.""" - all_gather_op = AllGatherKeras(world_size=self.world_size, dim=0) - local_slice = keras.ops.arange(6, dtype="float32").reshape((2, 3)) - result = all_gather_op(local_slice, axis_name=self.axis_name) - - expected_output = keras.ops.concatenate( - [local_slice] * self.world_size, axis=0 - ) - self.assertAllClose(result, expected_output) - - def test_broadcast(self): - """Tests the broadcast operation.""" - broadcast_op = BroadcastKeras( - world_size=self.world_size, src_rank=0, rank=0 - ) - tensor_to_broadcast = keras.ops.array([5.0, 10.0, 15.0]) - result = broadcast_op(tensor_to_broadcast, axis_name=self.axis_name) - - self.assertAllClose(result, tensor_to_broadcast) - - def test_tensor_parallel_communicator_forward_column_parallel(self): - """Tests the communicator's all-gather for column-parallel forward.""" - communicator = TensorParallelCommunicator( - world_size=self.world_size, rank=0 - ) - - local_slice = keras.ops.array([[0.0, 1.0], [2.0, 3.0]], dtype="float32") - - result = communicator.forward_column_parallel( - partial_outputs=[local_slice], - dim=0, - axis_name=self.axis_name, - ) - - expected_output = keras.ops.concatenate( - [local_slice] * self.world_size, axis=0 - ) - self.assertAllClose(result, expected_output) diff --git a/keras/src/distribution/tensor_parallel/config_test.py b/keras/src/distribution/tensor_parallel/config_test.py deleted file mode 100644 index cbb26e40e6db..000000000000 --- a/keras/src/distribution/tensor_parallel/config_test.py +++ /dev/null @@ -1,96 +0,0 @@ -import pytest - -from keras.src import backend -from keras.src import testing -from keras.src.distribution.tensor_parallel.communications import AllGatherKeras -from keras.src.distribution.tensor_parallel.communications import 
AllReduceKeras -from keras.src.distribution.tensor_parallel.communications import BroadcastKeras -from keras.src.distribution.tensor_parallel.config import ConfigKeras -from keras.src.distribution.tensor_parallel.config import _create_ops_from_rules - - -@pytest.mark.skipif( - backend.backend() not in ("torch", "jax"), - reason="This test is for JAX/PyTorch backends.", -) -class TestConfig(testing.TestCase): - """Test suite for the tensor parallel configuration.""" - - def test_create_ops_from_rules_helper(self): - """ - Tests the private _create_ops_from_rules helper function directly - to ensure it correctly parses various rule types. - """ - devices = ["/gpu:0", "/gpu:1"] - world_size = len(devices) - rules = { - "dense/kernel": {"forward": "sum", "backward": "mean"}, - "embedding/weight": { - "forward": "gather 0", - "backward": "gather -1", - }, - "attention/dense/bias": {"forward": "broadcast"}, - "passthrough": {"action": 123}, - "no_dict_action": "identity", - } - - processed_rules = _create_ops_from_rules(rules, world_size) - - sum_op = processed_rules["dense/kernel"]["forward"] - self.assertIsInstance(sum_op, AllReduceKeras) - self.assertEqual(sum_op.op, "sum") - self.assertEqual(sum_op.world_size, world_size) - - mean_op = processed_rules["dense/kernel"]["backward"] - self.assertIsInstance(mean_op, AllReduceKeras) - self.assertEqual(mean_op.op, "mean") - - gather_op_0 = processed_rules["embedding/weight"]["forward"] - self.assertIsInstance(gather_op_0, AllGatherKeras) - self.assertEqual(gather_op_0.dim, 0) - self.assertEqual(gather_op_0.world_size, world_size) - - gather_op_neg1 = processed_rules["embedding/weight"]["backward"] - self.assertIsInstance(gather_op_neg1, AllGatherKeras) - self.assertEqual(gather_op_neg1.dim, -1) - - broadcast_op = processed_rules["attention/dense/bias"]["forward"] - self.assertIsInstance(broadcast_op, BroadcastKeras) - self.assertEqual(broadcast_op.world_size, world_size) - - self.assertEqual(processed_rules["passthrough"]["action"], 123) - self.assertEqual(processed_rules["no_dict_action"], "identity") - - def test_config_keras_create_collective_ops(self): - """ - Tests the public create_collective_ops method of the ConfigKeras class. 
- """ - devices = ["/gpu:0", "/gpu:1"] - world_size = len(devices) - - state_rules = {"some_weight": "split"} - output_rules = { - "layer_1_output": {"activation": "sum"}, - "layer_2_output": {"activation": "gather -1"}, - } - - config = ConfigKeras(state_rules=state_rules, output_rules=output_rules) - new_config = config.create_collective_ops(devices) - - self.assertIsNot(new_config, config) - - self.assertEqual(new_config.state_rules, state_rules) - - self.assertIsInstance( - config.output_rules["layer_1_output"]["activation"], str - ) - - sum_op = new_config.output_rules["layer_1_output"]["activation"] - self.assertIsInstance(sum_op, AllReduceKeras) - self.assertEqual(sum_op.op, "sum") - self.assertEqual(sum_op.world_size, world_size) - - gather_op = new_config.output_rules["layer_2_output"]["activation"] - self.assertIsInstance(gather_op, AllGatherKeras) - self.assertEqual(gather_op.dim, -1) - self.assertEqual(gather_op.world_size, world_size) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py new file mode 100644 index 000000000000..62039e2e121f --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -0,0 +1,653 @@ +import re + +import numpy as np + +import keras +from keras.src import ops +from keras.src import optimizers +from keras.src.backend import distributed_backend + + +class CoordinatedOptimizer: + """Manages an optimizer's state for distributed training. + + This class is an internal coordinator that handles the complexities of + sharding optimizer states across multiple devices (shards) and + synchronizing gradients according to tensor parallelism rules. It is not + intended to be used directly by the end-user but is a core component of + the `TensorParallelOptimizer`. + + Args: + base_optimizer: The Keras optimizer instance + (e.g., `keras.optimizers.Adam`) whose state will be managed. + world_size: The total number of devices/processes in the distributed + setup. + distributed_backend: The distributed communication backend to use. + Defaults to "auto". + rank: The rank of the current process. Defaults to 0. + shard_optimizer_states: If `True`, the optimizer's state variables + (e.g., momentum, velocity) will be partitioned across `world_size` + devices. Defaults to `True`. + tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism, such as which gradients to + all-reduce. Defaults to `None`. + """ + + def __init__( + self, + base_optimizer, + world_size, + distributed_backend="auto", + rank=0, + shard_optimizer_states=True, + tensor_parallel_config=None, + ): + """Initializes the CoordinatedOptimizer.""" + self.base_optimizer = base_optimizer + self.world_size = world_size + self.shard_optimizer_states = shard_optimizer_states + self.tensor_parallel_config = tensor_parallel_config + self.sharded_states = {} + self._state_variable_to_parameter = {} + self._variables = None + self._variable_to_slot_name = {} + + def _initialize_sharded_states(self): + """ + Partitions the optimizer's state variables across shards by inspecting + the variables created by the base optimizer. 
+ """ + if not self.shard_optimizer_states or not self.base_optimizer.built: + return + + self.sharded_states = {} + self._state_variable_to_parameter = {} + self._variable_to_slot_name = {} + opt_name = self.base_optimizer.name + + normalized_params = sorted( + [(p.path.replace("/", "_"), p) for p in self._variables], + key=lambda x: len(x[0]), + reverse=True, + ) + + for state_var in self.base_optimizer.variables: + if state_var is self.base_optimizer.iterations: + continue + + path_parts = state_var.path.split("/") + if len(path_parts) != 2 or path_parts[0] != opt_name: + continue + + state_suffix = path_parts[1] + + found_param = None + slot_name = None + for norm_param_path, param in normalized_params: + if state_suffix.startswith(norm_param_path): + found_param = param + slot_suffix = state_suffix[len(norm_param_path) :] + slot_name = slot_suffix.strip("_") + break + + if found_param is not None and slot_name is not None: + self._state_variable_to_parameter[state_var.path] = found_param + self._variable_to_slot_name[state_var.path] = slot_name + + sharding_dim = 0 + if self.tensor_parallel_config: + norm_param_name = found_param.path.replace("/", ".") + for p, a in self.tensor_parallel_config.state_rules.items(): + if re.search(p, norm_param_name) and hasattr(a, "dim"): + sharding_dim = a.dim + break + + partitioned_state = self._partition_state( + state_var, dim=sharding_dim + ) + self.sharded_states.setdefault(slot_name, {})[ + found_param.path + ] = partitioned_state + + if self.base_optimizer.iterations is not None: + self.sharded_states["iterations"] = self._partition_state( + self.base_optimizer.iterations, dim=0 + ) + + def _partition_state(self, state_variable, dim): + """Splits a single state variable numpy array into chunks. + + If the variable cannot be split along the given dimension, it is + replicated across all shards. + + Args: + state_variable: The optimizer state variable. + dim: The dimension along which to partition the variable. + + Returns: + A list of NumPy arrays, where each array is a partition of the + original state variable for a specific shard. + """ + state_array = ops.convert_to_numpy(state_variable) + if state_array.ndim > dim and state_array.shape[dim] >= self.world_size: + return np.array_split(state_array, self.world_size, axis=dim) + else: + return [np.copy(state_array) for _ in range(self.world_size)] + + def apply_gradients(self, grads_and_vars, shard_models): + """ + Applies gradients to the model variables by first synchronizing them + and then applying them using either sharded or replicated optimizer + states. + + Args: + grads_and_vars: A list of (gradient, variable) lists from all + shards. + shard_models: A list of the sharded model instances. + """ + synchronized_gradients = self._synchronize_gradients(grads_and_vars) + + if self.shard_optimizer_states: + self._apply_gradients_with_sharded_states( + synchronized_gradients, shard_models + ) + else: + self._apply_gradients_with_replicated_states( + synchronized_gradients, shard_models + ) + + def _apply_gradients_with_replicated_states( + self, synchronized_gradients, shard_models + ): + """Averages gradients across all shards and applies them once. + + This method is used when optimizer state sharding is disabled. It + calculates the average of the gradients for each variable across all + shards and applies the averaged gradients using the single, replicated + optimizer state. + + Args: + synchronized_gradients: The gradients after synchronization. 
+ shard_models: The list of sharded models. + """ + num_vars = len(synchronized_gradients[0]) + averaged_grads_and_vars = [] + + for i in range(num_vars): + variable = synchronized_gradients[0][i][1] + grads_for_var = [ + shard_grads[i][0] + for shard_grads in synchronized_gradients + if shard_grads[i][0] is not None + ] + + if not grads_for_var: + continue + + if len(grads_for_var) > 1: + stacked_grads = ops.stack(grads_for_var, axis=0) + averaged_grad = ops.mean(stacked_grads, axis=0) + else: + averaged_grad = grads_for_var[0] + + averaged_grads_and_vars.append((averaged_grad, variable)) + + if averaged_grads_and_vars: + self.base_optimizer.apply_gradients(averaged_grads_and_vars) + + def _apply_gradients_with_sharded_states( + self, synchronized_gradients, shard_models + ): + """Applies gradients to each shard using its local optimizer state. + + Args: + synchronized_gradients: The gradients after synchronization. + shard_models: The list of sharded models. + """ + for shard_idx in range(self.world_size): + local_states = self._get_local_optimizer_states(shard_idx) + shard_optimizer = shard_models[shard_idx].optimizer + + self._update_optimizer_internal_state(shard_optimizer, local_states) + + shard_grads_and_vars = synchronized_gradients[shard_idx] + shard_optimizer.apply_gradients(shard_grads_and_vars) + + self._update_global_sharded_states(shard_optimizer, shard_idx) + + def _get_local_optimizer_states(self, shard_idx): + """Constructs the state dictionary for a single shard. + + Args: + shard_idx: The index of the shard for which to retrieve the state. + + Returns: + A dictionary containing the local optimizer state for the specified + shard. + """ + local_states = {} + for state_name, state_value in self.sharded_states.items(): + if isinstance(state_value, dict): + local_states[state_name] = {} + for param_name, param_states in state_value.items(): + local_states[state_name][param_name] = param_states[ + shard_idx + ] + else: + local_states[state_name] = state_value[shard_idx] + return local_states + + def _update_optimizer_internal_state(self, optimizer, local_states): + """Assigns local sharded state values to the optimizer's variables. + + Args: + optimizer: The optimizer instance for a specific shard. + local_states: The dictionary of local states for that shard. + """ + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + if "iterations" in local_states: + var.assign(local_states["iterations"]) + continue + + param = self._state_variable_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in local_states + and param.path in local_states[slot_name] + ): + local_param_state = local_states[slot_name][param.path] + if var.shape == local_param_state.shape: + var.assign(local_param_state) + + def _update_global_sharded_states(self, optimizer, shard_idx): + """Updates the main sharded_states dictionary after a gradient step. + + Args: + optimizer: The optimizer instance for a specific shard. + shard_idx: The index of the shard that was updated. 
+ """ + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + self.sharded_states["iterations"][shard_idx] = ( + ops.convert_to_numpy(var) + ) + continue + + param = self._state_variable_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in self.sharded_states + and param.path in self.sharded_states[slot_name] + ): + self.sharded_states[slot_name][param.path][shard_idx] = ( + ops.convert_to_numpy(var) + ) + + def _synchronize_gradients(self, gradients_and_vars): + """Synchronizes gradients across shards based on tensor parallel rules. + + Specifically, it performs an all-reduce operation on gradients of + weights that are split along a "column" dimension in tensor parallelism. + Other gradients are passed through unchanged. + + Args: + gradients_and_vars: The list of (gradient, variable) lists from + all shards. + + Returns: + The list of (gradient, variable) lists after synchronization. + """ + if not self.tensor_parallel_config: + return gradients_and_vars + + rules = self.tensor_parallel_config.state_rules.items() + column_parallel_patterns = { + pattern + for pattern, action in rules + if hasattr(action, "sharding_type") + and action.sharding_type == "column" + } + + if not column_parallel_patterns: + return gradients_and_vars + + num_weights = len(gradients_and_vars[0]) + for i in range(num_weights): + variable = gradients_and_vars[0][i][1] + var_name = getattr(variable, "path", getattr(variable, "name", "")) + + if any( + re.search(pattern, var_name) + for pattern in column_parallel_patterns + ): + grads_to_reduce = [ + g_and_v[i][0] + for g_and_v in gradients_and_vars + if g_and_v[i][0] is not None + ] + if grads_to_reduce: + synced_grad = self._allreduce_gradients(grads_to_reduce)[0] + for shard_idx in range(self.world_size): + gradients_and_vars[shard_idx][i] = ( + synced_grad, + variable, + ) + return gradients_and_vars + + def _allreduce_gradients(self, gradients): + """Performs a mean all-reduce operation on a list of gradients. + + If a distributed backend is available, it uses it. Otherwise, it + falls back to a local mean calculation. + + Args: + gradients: A list of gradients (one from each shard) to be averaged. + + Returns: + A list where each element is the mean of the input gradients. + """ + if not gradients: + return [] + + if distributed_backend.is_multi_device_capable(): + all_reduce_fn = distributed_backend.get_communication_ops()[ + "all_reduce" + ] + numpy_grad = ops.convert_to_numpy(gradients[0]) + synced_numpy = all_reduce_fn(numpy_grad, op="mean") + synced_tensor = ops.convert_to_tensor(synced_numpy) + return [synced_tensor for _ in range(self.world_size)] + + stacked_grads = keras.ops.stack( + [ops.convert_to_tensor(g) for g in gradients], axis=0 + ) + mean_grad = ops.mean(stacked_grads, axis=0) + return [mean_grad for _ in range(len(gradients))] + + def get_weights(self): + """Returns the weights of the base optimizer. + + Returns: + A list of NumPy arrays representing the optimizer's state variables. + """ + return [ + ops.convert_to_numpy(var) for var in self.base_optimizer.variables + ] + + def set_weights(self, weights): + """Sets the weights of the base optimizer. + + Args: + weights: A list of NumPy arrays to set as the optimizer's state. + """ + self.base_optimizer.set_weights(weights) + + def enable_optimizer_state_sharding(self, variables): + """Enables and initializes optimizer state sharding. 
+ + This method is called from `build()`, which is guarded from running + multiple times. We can assume this should always execute. + + Args: + variables: A list of model variables to be optimized. + """ + self.shard_optimizer_states = True + self._variables = variables + self._initialize_sharded_states() + + +class TensorParallelOptimizer(optimizers.Optimizer): + """A Keras Optimizer wrapper for tensor-parallel distributed training. + + This optimizer wraps a standard Keras optimizer (e.g., Adam, SGD) and + delegates the complex tasks of state management and gradient synchronization + to a `CoordinatedOptimizer` instance. It is designed to work with models + that have been sharded for tensor parallelism. + + When `apply_gradients` is called with a list of gradient lists (one for each + model shard), it uses the `CoordinatedOptimizer` to handle synchronization + and state sharding. Otherwise, it behaves like the base optimizer. + + Args: + base_optimizer: A Keras optimizer instance or a string identifier + (e.g., 'adam', 'sgd'). + world_size: The total number of devices/processes in the distributed + setup. + distributed_backend: The distributed communication backend to use. + Defaults to "auto". + tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism. Defaults to `None`. + + Example: + + ```python + import keras + + # Assume model variables and gradients from 4 shards exist. + # The structure is: list[list[tuple[gradient, variable]]] + trainable_vars = [keras.Variable(1.0), keras.Variable(2.0)] + sharded_grads_and_vars = [ + [(keras.ops.ones_like(v), v) for v in trainable_vars] + for _ in range(4) # 4 shards + ] + + # 1. Wrap a standard Keras optimizer. + base_optimizer = keras.optimizers.Adam() + optimizer = TensorParallelOptimizer(base_optimizer, world_size=4) + optimizer.build(trainable_vars) + + # 2. Apply the sharded gradients. + # The optimizer will handle synchronization (e.g., all-reduce) internally. + optimizer.apply_gradients(sharded_grads_and_vars) + ``` + """ + + def __init__( + self, + base_optimizer, + world_size, + distributed_backend="auto", + tensor_parallel_config=None, + ): + """Initializes the TensorParallelOptimizer.""" + if isinstance(base_optimizer, str): + base_optimizer_instance = optimizers.get(base_optimizer) + else: + base_optimizer_instance = base_optimizer + + learning_rate = base_optimizer_instance.learning_rate + if callable(learning_rate): + lr_value = float(ops.convert_to_numpy(learning_rate(0))) + else: + lr_value = float(ops.convert_to_numpy(learning_rate)) + + super().__init__( + learning_rate=lr_value, + name=f"TensorParallel_{base_optimizer_instance.name}", + ) + + self.base_optimizer = base_optimizer_instance + self.world_size = world_size + self.distributed_backend = distributed_backend + self.coordinated_optimizer = CoordinatedOptimizer( + self.base_optimizer, + world_size, + distributed_backend=distributed_backend, + tensor_parallel_config=tensor_parallel_config, + ) + + def apply_gradients(self, grads_and_vars, **kwargs): + """Applies gradients to the model variables. + + If `grads_and_vars` is a list of lists, it's assumed to be from + sharded models, and the `CoordinatedOptimizer` is used. Otherwise, + it calls the `base_optimizer`'s `apply_gradients` directly. + + Args: + grads_and_vars: A list of (gradient, variable) tuples, or a list + of such lists if running in a sharded context. + **kwargs: Additional arguments. `shard_models` can be passed to + provide the list of model shards. 
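+
+        Example (illustrative; `sharded_grads_and_vars`, `shards`, `grads`
+        and `model` are assumed to exist, as in the class-level example):
+
+        ```python
+        # Sharded: one (gradient, variable) list per model shard; gradient
+        # synchronization and state sharding are handled internally by the
+        # CoordinatedOptimizer.
+        optimizer.apply_gradients(
+            sharded_grads_and_vars, shard_models=shards
+        )
+
+        # Unsharded: a flat list is delegated to the base optimizer.
+        optimizer.apply_gradients(zip(grads, model.trainable_variables))
+        ```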
+ """ + is_sharded_grads = ( + isinstance(grads_and_vars, list) + and grads_and_vars + and isinstance(grads_and_vars[0], list) + ) + if is_sharded_grads: + shard_models = kwargs.get("shard_models", []) + self.coordinated_optimizer.apply_gradients( + grads_and_vars, shard_models + ) + else: + self.base_optimizer.apply_gradients(grads_and_vars) + + def get_config(self): + """Returns the configuration of the optimizer. + + Returns: + A dictionary containing the optimizer's configuration. + """ + from keras.src import saving + + config = super().get_config() + config.pop("learning_rate", None) + config.pop("name", None) + + config.update( + { + "base_optimizer": saving.serialize_keras_object( + self.base_optimizer + ), + "world_size": self.world_size, + "distributed_backend": self.distributed_backend, + } + ) + return config + + def update_step(self, gradient, variable, *args, **kwargs): + """Performs a single optimization step. + + Delegates the update step to the base optimizer if it has a custom + `update_step` implementation, otherwise falls back to the parent + optimizer's logic. + + Args: + gradient: The gradient tensor. + variable: The variable to be updated. + *args: Positional arguments passed to the update function. + **kwargs: Keyword arguments passed to the update function. + """ + if hasattr(self.base_optimizer, "update_step"): + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) + + return super().update_step(gradient, variable, *args, **kwargs) + + @classmethod + def from_config(cls, config): + """Creates an optimizer from its configuration. + + Args: + config: A Python dictionary, typically the output of `get_config`. + + Returns: + A `TensorParallelOptimizer` instance. + """ + from keras.src import saving + + base_optimizer_config = config.pop("base_optimizer") + base_optimizer = saving.deserialize_keras_object(base_optimizer_config) + + init_kwargs = { + "world_size": config.get("world_size"), + "distributed_backend": config.get("distributed_backend", "auto"), + "tensor_parallel_config": config.get("tensor_parallel_config"), + } + + return cls(base_optimizer=base_optimizer, **init_kwargs) + + def build(self, variables): + """Builds the optimizer and initializes sharded states. + + This method is called the first time the optimizer is used. It builds + the base optimizer and then triggers the `CoordinatedOptimizer` to + initialize its sharded states. + + Args: + variables: A list of model variables to be optimized. + """ + if self.built: + return + + self.base_optimizer.build(variables) + if variables: + iterations = self.base_optimizer.iterations + original_iterations_val = None + if iterations is not None: + original_iterations_val = ops.convert_to_numpy(iterations.value) + + zero_grads = [ops.zeros_like(v) for v in variables] + self.base_optimizer.apply_gradients(zip(zero_grads, variables)) + + if iterations is not None and original_iterations_val is not None: + iterations.assign(original_iterations_val) + + self.coordinated_optimizer.enable_optimizer_state_sharding(variables) + super().build(variables) + + def get_weights(self): + """Returns the weights of the base optimizer. + + Returns: + A list of NumPy arrays representing the optimizer's state variables. + """ + return self.coordinated_optimizer.get_weights() + + def set_weights(self, weights): + """Sets the weights of the base optimizer. + + Args: + weights: A list of NumPy arrays to set as the optimizer's state. 
+ """ + self.coordinated_optimizer.set_weights(weights) + + @property + def variables(self): + """Returns the list of variables from the base optimizer. + + Returns: + A list of state variables of the base optimizer. + """ + return self.base_optimizer.variables + + @property + def learning_rate(self): + """Provides access to the learning rate of the base optimizer.""" + return self.base_optimizer.learning_rate + + @learning_rate.setter + def learning_rate(self, value): + """Sets the learning rate of the base optimizer.""" + self.base_optimizer.learning_rate = value + + @property + def iterations(self): + """ + Returns the training iteration count directly from the base optimizer. + """ + return self.base_optimizer.iterations diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py deleted file mode 100644 index 38d80a5ec258..000000000000 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ /dev/null @@ -1,180 +0,0 @@ -import numpy as np -import pytest - -import keras -from keras import ops -from keras.src import backend -from keras.src import optimizers -from keras.src import testing -from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - CoordinatedOptimizer, -) -from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - TensorParallelOptimizer, -) - - -@pytest.mark.skipif( - backend.backend() not in ("torch", "jax"), - reason="This test is for JAX/PyTorch backends.", -) -class CoordinatedOptimizerTest(testing.TestCase): - def _get_simple_model(self): - """Creates a simple, uncompiled Keras model.""" - inputs = keras.Input(shape=(10,)) - x = keras.layers.Dense(20, name="dense_1")(inputs) - outputs = keras.layers.Dense(5, name="dense_2")(x) - return keras.Model(inputs, outputs) - - def _get_mock_gradients_and_vars(self, model, world_size): - """Generates mock gradients and variables for N shards.""" - model.build(input_shape=(None, 10)) - variables = model.trainable_variables - grads_and_vars_per_shard = [] - for i in range(world_size): - multiplier = float(i + 1) - gradients = [ - ops.convert_to_tensor( - np.ones_like(v.numpy()) * multiplier, dtype="float32" - ) - for v in variables - ] - grads_and_vars_per_shard.append(list(zip(gradients, variables))) - return grads_and_vars_per_shard - - def test_initialization(self): - """Tests that the optimizer initializes with the correct defaults.""" - base_optimizer = optimizers.Adam() - coord = CoordinatedOptimizer(base_optimizer, world_size=4) - self.assertEqual(coord.base_optimizer, base_optimizer) - self.assertTrue(coord.shard_optimizer_states) - self.assertEqual(coord.sharded_states, {}) - - def test_apply_gradients_with_replicated_states(self): - """Tests that replicated gradients are averaged and applied once.""" - - class AdamWithCallCounter(optimizers.Adam): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.apply_gradients_call_count = 0 - self.received_grads = [] - - def apply_gradients(self, grads_and_vars, *args, **kwargs): - self.apply_gradients_call_count += 1 - self.received_grads = [g for g, v in grads_and_vars] - super().apply_gradients(grads_and_vars, *args, **kwargs) - - world_size = 4 - model = self._get_simple_model() - optimizer = AdamWithCallCounter() - model.build((None, 10)) - mock_grads = self._get_mock_gradients_and_vars(model, world_size) - - coord = CoordinatedOptimizer( - optimizer, - world_size, - shard_optimizer_states=False, 
- ) - coord.apply_gradients(mock_grads, []) - - self.assertEqual(optimizer.apply_gradients_call_count, 1) - grad_numpy = ops.convert_to_numpy(optimizer.received_grads[0]) - self.assertAllClose( - grad_numpy, - np.ones_like(grad_numpy) * 2.5, - ) - - def test_init_from_string(self): - optimizer = TensorParallelOptimizer("adam", world_size=4) - self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) - - def test_apply_gradients_delegation(self): - """Tests that apply_gradients correctly delegates.""" - world_size = 4 - base_opt = optimizers.Adam() - optimizer = TensorParallelOptimizer(base_opt, world_size) - model = self._get_simple_model() - mock_grads = self._get_mock_gradients_and_vars(model, world_size) - - coord_apply_tracker = {"called": False} - - def coord_apply_mock(*args, **kwargs): - coord_apply_tracker["called"] = True - - optimizer.coordinated_optimizer.apply_gradients = coord_apply_mock - - base_apply_tracker = {"called": False} - - def base_apply_mock(*args, **kwargs): - base_apply_tracker["called"] = True - - optimizer.base_optimizer.apply_gradients = base_apply_mock - - optimizer.apply_gradients(mock_grads, shard_models=[]) - self.assertTrue(coord_apply_tracker["called"]) - self.assertFalse(base_apply_tracker["called"]) - - coord_apply_tracker["called"] = False - unsharded_grads = mock_grads[0] - optimizer.apply_gradients(unsharded_grads) - self.assertTrue(base_apply_tracker["called"]) - self.assertFalse(coord_apply_tracker["called"]) - - def test_build_and_state_sharding(self): - """Tests that the build method correctly initializes sharded states.""" - optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=4) - model = self._get_simple_model() - model.build(input_shape=(None, 10)) - - self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) - optimizer.build(model.trainable_variables) - self.assertTrue(optimizer.built) - - sharded_states = optimizer.coordinated_optimizer.sharded_states - self.assertIn("momentum", sharded_states) - self.assertIn("velocity", sharded_states) - self.assertIn("iterations", sharded_states) - - dense_1_kernel_path = model.get_layer("dense_1").kernel.path - self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) - self.assertEqual( - len(sharded_states["momentum"][dense_1_kernel_path]), 4 - ) - - def test_serialization(self): - world_size = 4 - base_opt = optimizers.Adam(learning_rate=0.1) - optimizer = TensorParallelOptimizer( - base_opt, world_size, distributed_backend=None - ) - - config = optimizer.get_config() - recreated = TensorParallelOptimizer.from_config(config) - - self.assertEqual(recreated.world_size, world_size) - self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) - self.assertIsNone(recreated.distributed_backend) - self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) - - def test_sharding_with_prefixed_variable_names(self): - """Tests that state is correctly mapped with prefixed variable names.""" - inputs = keras.Input(shape=(10,)) - x = keras.layers.Dense(4, name="dense")(inputs) - outputs = keras.layers.Dense(2, name="dense_output")(x) - model = keras.Model(inputs, outputs) - model.build(input_shape=(None, 10)) - - optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=2) - optimizer.build(model.trainable_variables) - - state_to_param = ( - optimizer.coordinated_optimizer._state_variable_to_parameter - ) - self.assertGreater(len(state_to_param), 0) - - dense_output_kernel = model.get_layer("dense_output").kernel - optimizer_name = 
optimizer.base_optimizer.name - kernel_path = dense_output_kernel.path.replace("/", "_") - momentum_path = f"{optimizer_name}/{kernel_path}_momentum" - - self.assertIs(state_to_param[momentum_path], dense_output_kernel) diff --git a/keras/src/distribution/tensor_parallel/parameter_sharding.py b/keras/src/distribution/tensor_parallel/parameter_sharding.py index 30a16e9c63fe..f3282968a414 100644 --- a/keras/src/distribution/tensor_parallel/parameter_sharding.py +++ b/keras/src/distribution/tensor_parallel/parameter_sharding.py @@ -1,35 +1,25 @@ import re -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -import numpy as np - -from keras.src.distribution.tensor_parallel.communications import ( - TensorParallelCommunicator, -) -from keras.src.distribution.tensor_parallel.config import ConfigKeras -from keras.src.distribution.tensor_parallel.state_action_keras import ( - StateActionKeras, -) +from keras.src.backend import distributed_backend class ShardedWeight: - """A wrapper class for a sharded Keras Variable. - - This class holds a shard of a model weight as a `keras.Variable` and - provides an interface similar to the original variable, allowing it to be - seamlessly integrated into the Keras ecosystem. + """A wrapper for a sharded Keras Variable to provide a consistent interface. - Args: - tensor_shard: The tensor slice (shard) of the weight. - name (str): The name for the underlying `keras.Variable`. - trainable (bool): Whether the variable is trainable. + This class wraps a tensor shard in a Keras Variable, making it compatible + with the Keras ecosystem. It exposes common variable properties like name, + shape, and trainable status. """ def __init__(self, tensor_shard, name, trainable=True): + """Initializes the ShardedWeight. + + Args: + tensor_shard: The tensor piece (shard) to be managed by this weight. + name (str): The name for the underlying Keras Variable. + trainable (bool, optional): Whether the variable is trainable. + Defaults to True. 
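+
+        Example (illustrative sketch; `shard` stands in for one slice of a
+        larger kernel produced by a split rule, and the shapes are
+        hypothetical):
+
+        ```python
+        import numpy as np
+
+        # Rank-0 slice of a (16, 32) kernel split column-wise across 2 devices.
+        shard = np.zeros((16, 16), dtype="float32")
+        weight = ShardedWeight(shard, name="up_proj.kernel")
+        print(weight.shape)      # (16, 16)
+        print(weight.trainable)  # True
+        ```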
+ """ import keras self._variable = keras.Variable( @@ -38,42 +28,47 @@ def __init__(self, tensor_shard, name, trainable=True): self.regularizer = None @property - def name(self) -> str: + def name(self): """Returns the name of the underlying variable.""" return self._variable.name @property - def trainable(self) -> bool: + def trainable(self): """Returns whether the variable is trainable.""" return self._variable.trainable @property - def shape(self) -> Tuple[int, ...]: + def shape(self): """Returns the shape of the variable.""" return self._variable.shape @property - def dtype(self) -> any: + def dtype(self): """Returns the dtype of the underlying variable.""" return self._variable.dtype @property def variable(self): - """Provides direct access to the underlying `keras.Variable`.""" + """Provides direct access to the underlying Keras Variable.""" return self._variable - def numpy(self) -> np.ndarray: + @property + def value(self): + """Returns the value of the underlying variable.""" + return self._variable.value + + def numpy(self): """Returns the value of the variable as a NumPy array.""" return self._variable.numpy() - def num_elements(self) -> int: + def num_elements(self): """Returns the total number of elements in the tensor.""" import keras return keras.ops.size(self._variable) - def __repr__(self) -> str: - """Provides a developer-friendly string representation.""" + def __repr__(self): + """Returns a string representation of the ShardedWeight.""" return ( f"" @@ -81,53 +76,47 @@ def __repr__(self) -> str: class ParameterShardingStrategy: - """Manages the sharding of model parameters for tensor parallelism. - - This strategy identifies weights in a Keras model based on configuration - rules, shards them, and stores the sharded weights and metadata. It's - designed to modify a model's parameters without altering its architecture. + """Implements parameter-level sharding for a Keras model. - Args: - world_size (int): The total number of devices (workers) in the - parallel computation group. - rank (int): The unique identifier for the current device (worker), - from 0 to `world_size - 1`. + This strategy shards a model's weights according to a provided configuration + without altering the model's architecture. It identifies weights + that match specific patterns, applies sharding actions to them, and stores + the mapping between original and sharded weights. """ - def __init__(self, world_size: int, rank: int): + def __init__(self, world_size, rank): + """Initializes the ParameterShardingStrategy. + + Args: + world_size (int): The total number of devices in distributed setup. + rank (int): The rank of the current device. + """ self.world_size = world_size self.rank = rank - self.sharded_weights = {} # Maps param name to its sharded tensor - self.original_weights = {} # Stores a copy of original weights - self.weight_mapping = {} # Maps param name to sharding info - self.sharded_weights_by_id = {} # Maps original weight ID to shard - - def shard_model_parameters( - self, - model, - config: ConfigKeras, - communicator: TensorParallelCommunicator, - device_id: Any, - ) -> Tuple[Any, set]: - """Shards model parameters and wraps the model for tensor parallelism. - - This method iterates through the model's parameters, applies sharding - rules defined in the config, and creates a `ParameterShardedModel` - which handles the forward pass with necessary communication primitives. 
+ self.sharded_weights = {} + self.original_weights = {} + self.weight_mapping = {} + self.sharded_weights_by_id = {} + + def shard_model_parameters(self, model, config, device_id): + """Shards model parameters based on a layout configuration. + + This method iterates through the rules in configuration, finds matching + parameters in the model, and applies the specified sharding action. It + then returns a `ParameterShardedModel` wrapper that uses these sharded + weights. Args: - model: The original Keras model to be sharded. - config (ConfigKeras): The configuration object containing sharding - rules (`state_rules` and `output_rules`). - communicator (TensorParallelCommunicator): The communicator for - handling cross-device data transfer (e.g., all-gather). - device_id (Any): The device identifier where the model will run. + model (keras.Model): The Keras model to be sharded. + config (LayoutMap): A configuration object specifying which weights + to shard and how. + device_id: The device identifier for the current process. Returns: - A tuple containing: - - ParameterShardedModel: The new model wrapped for tensor - parallelism. - - set: A set of names of the parameters that were sharded. + tuple: A tuple containing: + - ParameterShardedModel: The wrapped model with sharded + parameters. + - set: A set of names of the parameters that were modified. """ ParameterShardedModel = _define_parameter_sharded_model() @@ -135,13 +124,13 @@ def shard_model_parameters( modified_parameters = set() for pattern, action in config.state_rules.items(): - if isinstance(action, StateActionKeras): + if hasattr(action, "__call__"): matching_params = self._find_matching_parameters(model, pattern) for param_name, param in matching_params: - try: + if hasattr(param, "experimental_ref"): param_id = id(param.experimental_ref()) - except AttributeError: + else: param_id = id(param) if param_id in self.sharded_weights_by_id: @@ -177,7 +166,6 @@ def shard_model_parameters( sharded_model = ParameterShardedModel( original_model=model, sharding_strategy=self, - communicator=communicator, config=config, device_id=device_id, ) @@ -185,13 +173,10 @@ def shard_model_parameters( return sharded_model, modified_parameters def _store_original_weights(self, model): - """Recursively traverses the model and stores original weights.""" + """Recursively finds and stores the original weights of a model.""" from keras.src import layers - def find_weights_recursive( - current_layer: layers.Layer, prefix: str = "" - ): - """Helper to recursively find and store weights.""" + def find_weights_recursive(current_layer, prefix=""): name = current_layer.name full_name = f"{prefix}.{name}" if prefix else name @@ -208,10 +193,9 @@ def find_weights_recursive( for attr_name in dir(current_layer): if attr_name.startswith("__") and attr_name.endswith("__"): continue - try: - attr = getattr(current_layer, attr_name) - except Exception: - continue + + attr = getattr(current_layer, attr_name) + if isinstance(attr, layers.Layer) and attr is not current_layer: find_weights_recursive(attr, full_name) elif isinstance(attr, (list, tuple)): @@ -222,32 +206,23 @@ def find_weights_recursive( for layer in model.layers: find_weights_recursive(layer, prefix="") - def _find_matching_parameters( - self, model, pattern: str - ) -> List[Tuple[str, Any]]: - """Finds model parameters whose names match a given regex pattern. - - This method recursively searches through the model's layers and - sub-layers to find all weights, then filters them based on the pattern. 
+ def _find_matching_parameters(self, model, pattern): + """Finds model parameters that match a given regex pattern. Args: - model: The Keras model to search within. - pattern (str): A regular expression to match against parameter - names. + model (keras.Model): The model to search within. + pattern (str): The regex pattern to match against parameter names. Returns: - A list of tuples, where each tuple contains the parameter's full - name and the parameter object itself. + list: A list of tuples, where each tuple contains the full parameter + name and the corresponding weight object. """ from keras.src import layers matching_params = [] processed_layers = set() - def search_layer_recursive( - current_layer: layers.Layer, prefix: str = "" - ): - """Helper to recursively find matching parameters.""" + def search_layer_recursive(current_layer, prefix=""): if id(current_layer) in processed_layers: return processed_layers.add(id(current_layer)) @@ -273,154 +248,75 @@ def search_layer_recursive( if attr_name.startswith("__") and attr_name.endswith("__"): continue - try: - attr = getattr(current_layer, attr_name) - except Exception: - continue + attr = getattr(current_layer, attr_name) if isinstance(attr, layers.Layer) and attr is not current_layer: search_layer_recursive(attr, full_name) - elif isinstance(attr, (list, tuple)): for item in attr: if isinstance(item, layers.Layer): search_layer_recursive(item, full_name) search_layer_recursive(model, prefix="") - return matching_params - def get_sharded_weight(self, param_name: str) -> Optional[np.ndarray]: - """Retrieves the sharded weight for a given parameter name. - - Args: - param_name (str): The name of the parameter. - - Returns: - The sharded weight as a NumPy array if it exists, otherwise None. - """ - if param_name in self.sharded_weights: - return self.sharded_weights[param_name].numpy() - return None - - def get_weight_info(self, param_name: str) -> Optional[Dict]: - """Retrieves sharding information for a specific parameter. - - Args: - param_name (str): The name of the parameter. - - Returns: - A dictionary containing metadata about the sharding (e.g., - original shape, sharded shape, action) if it exists, - otherwise None. - """ - return self.weight_mapping.get(param_name) - def _define_parameter_sharded_model(): """Factory function to define and return the ParameterShardedModel class. - This approach encapsulates the class definition and avoids potential - circular dependencies, while also keeping the related logic organized. + This approach avoids circular import dependencies by defining the class + that inherits from `keras.src.models.Model` inside a function. Returns: - The `ParameterShardedModel` class. + The ParameterShardedModel class definition. """ from keras.src.models import Model class ParameterShardedModel(Model): - """A Keras Model wrapper for executing a parameter-sharded model. - - This model overrides the `call` and `train_step` methods to inject - the necessary communication operations (e.g., all-reduce, all-gather) - for tensor parallelism during the forward and backward passes. + """A wrapper model that manages sharded parameters for tensor + parallelism. - Args: - original_model (Model): The original, non-sharded Keras model. - sharding_strategy (ParameterShardingStrategy): The strategy - instance that holds the sharded weights and metadata. - communicator (TensorParallelCommunicator): The object responsible - for inter-device communication. 
- config (ConfigKeras): The configuration with sharding and - communication rules. - device_id (Any): The identifier of the device this model runs on. + This model wraps an existing Keras model, preserving its original + architecture. It overrides the `weights` property and the `call` method + to handle sharded weights and insert the necessary communication + collectives (e.g., AllReduce, AllGather) during the forward pass. """ def __init__( - self, - original_model: Model, - sharding_strategy: ParameterShardingStrategy, - communicator: TensorParallelCommunicator, - config: ConfigKeras, - device_id: Any, + self, original_model, sharding_strategy, config, device_id ): - super().__init__() + """Initializes the ParameterShardedModel. + Args: + original_model: The original, unsharded Keras model. + sharding_strategy: The strategy object + that contains the sharded weights and mappings. + config (LayoutMap): The sharding configuration. + device_id: The device identifier for the current process. + """ + super().__init__() self.original_model = original_model self.sharding_strategy = sharding_strategy self.config = config - self.communicator = communicator self._device = device_id - self._build_and_cache_weights() - if original_model.inputs: self.build(original_model.inputs[0].shape) @property def device(self): - """Returns the device identifier for this model instance.""" + """Returns the device ID associated with this model shard.""" return self._device - def train_step(self, data): - """Custom training step for the parameter-sharded model. - - This method performs a standard forward and backward pass but - adds a crucial gradient synchronization step (`all_reduce`) before - applying gradients. This ensures that each device updates its - local weight shards using gradients computed from all devices. - - Args: - data: A tuple of (x, y, sample_weight) as passed by `fit()`. - - Returns: - A dictionary mapping metric names to their current values. - """ - import tensorflow as tf - - import keras - - x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - - with tf.GradientTape() as tape: - y_pred = self(x, training=True) - loss = self.compute_loss( - x=x, y=y, y_pred=y_pred, sample_weight=sample_weight - ) - - trainable_vars = self.trainable_variables - gradients = tape.gradient(loss, trainable_vars) - - synced_gradients = self.communicator.all_reduce( - gradients, op="sum", axis_name="model" - ) - self.optimizer.apply_gradients( - zip(synced_gradients, trainable_vars) - ) - - self.compiled_metrics.update_state(y, y_pred, sample_weight) - - return {m.name: m.result() for m in self.metrics} - def _build_and_cache_weights(self): - """Constructs a unified list of weights for the model. + """Constructs and caches the definitive list of model weights. - This list includes the custom `ShardedWeight` objects for parameters - that were sharded, and the original `keras.Variable` objects for - those that were not. + This method combines newly created `ShardedWeight` objects with any + original weights that were not sharded (i.e., replicated weights). + This combined list is then cached to be returned by the `weights` + property, ensuring the optimizer sees all trainable parameters. 
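+
+            Example (usage sketch; `model` and `layout_map` are assumed to be
+            an existing Keras model and a sharding configuration, and the
+            device id is hypothetical):
+
+            ```python
+            sharded_model, modified = make_parameter_sharded_model(
+                model, layout_map, rank=0, world_size=2, device_id="gpu:0"
+            )
+            # The cached list mixes ShardedWeight wrappers with replicated
+            # (unsharded) variables from the original model.
+            for w in sharded_model.weights:
+                print(w.name, tuple(w.shape))
+            ```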
""" weights_list = [] - sharded_weight_ids = set( self.sharding_strategy.sharded_weights_by_id.keys() ) @@ -431,135 +327,131 @@ def _build_and_cache_weights(self): ) in self.sharding_strategy.sharded_weights.items(): weights_list.append(ShardedWeight(sharded_tensor, param_name)) - unsharded_count = 0 for weight in self.original_model.weights: - try: + if hasattr(weight, "experimental_ref"): weight_id = id(weight.experimental_ref()) - except AttributeError: + else: weight_id = id(weight) if weight_id not in sharded_weight_ids: weights_list.append(weight) - unsharded_count += 1 self._weights_list = weights_list @property def weights(self): - """Returns the combined list of sharded and non-sharded weights.""" + """Overrides the base property to return the cached list of weights. + + This list includes both the custom `ShardedWeight` objects and any + unsharded (replicated) weights from the original model. + """ return self._weights_list def call(self, inputs, training=None, mask=None): - """Defines the forward pass of the model. + """Executes the forward pass of the model with sharded parameters. - This method executes the layers of the original model sequentially. - After each layer's execution, it checks if an output communication - rule applies (e.g., for row-parallel or column-parallel layers) - and triggers the corresponding communication operation. + This method manually reconstructs the forward pass of original + model's computation graph. It propagates tensors from one layer to + the next, and after layer's computation, it checks if communication + collective needs to be applied to the output tensor based on the + sharding configuration. Args: inputs: Input tensor(s). - training (bool): Indicates if the model is in training mode. - mask: A mask or list of masks. + training (bool, optional): Indicates whether the model is in + training mode. Defaults to None. + mask: Mask tensor(s). Defaults to None. Returns: - The output tensor of the model. + The final output tensor(s) of the model. 
""" from keras.src import layers tensor_cache = {} - current_tensor = inputs + + if isinstance(inputs, dict): + for inp_tensor in self.original_model.inputs: + tensor_cache[id(inp_tensor)] = inputs[inp_tensor.name] + else: + tensor_cache[id(self.original_model.inputs[0])] = inputs for layer in self.original_model.layers: if isinstance(layer, layers.InputLayer): continue - if isinstance(layer, layers.Add): - try: - if "feedforward_output" in layer.name: - residual_source_name = layer.name.replace( - "feedforward_output", "self_attention_output" - ) - elif "self_attention_output" in layer.name: - residual_source_name = layer.name.replace( - "self_attention_output", "input_layer_norm" - ) - else: - residual_source_name = None - - if ( - residual_source_name - and residual_source_name in tensor_cache - ): - layer_inputs = [ - current_tensor, - tensor_cache[residual_source_name], - ] - else: - layer_inputs = [current_tensor, current_tensor] - except Exception: - layer_inputs = [current_tensor, current_tensor] - else: - layer_inputs = current_tensor + layer_inputs = [] + for node in layer._inbound_nodes: + for symbolic_input_tensor in node.input_tensors: + layer_inputs.append( + tensor_cache[id(symbolic_input_tensor)] + ) - if ( - "attention_output" in layer.name - or "feedforward_output" in layer.name - ): - tensor_cache[layer.name] = current_tensor + if len(layer_inputs) == 1: + layer_inputs = layer_inputs[0] current_tensor = layer(layer_inputs, training=training) + tensor_cache[id(layer.output)] = current_tensor layer_path = layer.path - output_rule = None for pattern, rule in self.config.output_rules.items(): if re.search(pattern, layer_path): output_rule = rule.get(0) break - if output_rule: current_tensor = self._apply_communication( current_tensor, layer.name, output_rule ) + tensor_cache[id(layer.output)] = current_tensor + + final_outputs = [] + for symbolic_output in self.original_model.outputs: + final_outputs.append(tensor_cache[id(symbolic_output)]) + + if len(final_outputs) == 1: + return final_outputs[0] + return final_outputs - return current_tensor + def _apply_communication(self, sharded_output, layer_name, rule_str): + """Applies a collective communication operation to a tensor. - def _apply_communication(self, sharded_output, layer_name, rule): - """Applies a communication primitive based on a rule. + This method uses the distributed backend to perform operations like + AllReduce (for summing partial results in row-parallel layouts) or + AllGather (for combining results in column-parallel layouts). Args: - sharded_output: The output tensor from a layer. - layer_name (str): The name of the layer. - rule: The communication rule from the config. + sharded_output: The tensor to apply the communication op to. + layer_name (str): The name of the layer producing the output. + rule_str (str): A string from config describing the operation + (e.g., 'allreduce sum', 'allgather -1'). Returns: The tensor after the communication operation has been applied. 
""" - op_name = str(rule).lower() + comm_ops = distributed_backend.get_communication_ops() - if "sum" in op_name or "allreduce" in op_name: - return self.communicator.forward_row_parallel( + if "sum" in rule_str or "allreduce" in rule_str: + return comm_ops["all_reduce"]( sharded_output, op="sum", axis_name="model" ) - - elif "gather" in op_name: - try: - dim = int(op_name.split(" ")[-1]) - except (ValueError, IndexError): + elif "gather" in rule_str: + parts = rule_str.split(" ") + last_part = parts[-1] + if len(parts) > 1 and ( + last_part.isdigit() + or (last_part.startswith("-") and last_part[1:].isdigit()) + ): + dim = int(last_part) + else: dim = -1 - return self.communicator.forward_column_parallel( - sharded_output, dim=dim, axis_name="model" + return comm_ops["all_gather"]( + sharded_output, axis=dim, axis_name="model" ) - - elif hasattr(rule, "__call__"): - return rule(sharded_output) - else: return sharded_output def get_config(self): - """Serializes the model's configuration.""" + """Returns the configuration of the original model.""" return self.original_model.get_config() @classmethod @@ -570,100 +462,26 @@ def from_config(cls, config, custom_objects=None): return ParameterShardedModel -def make_parameter_sharded_model( - module, config: ConfigKeras, rank: int, world_size: int, device_id: Any -) -> Tuple[Any, set]: +def make_parameter_sharded_model(module, config, rank, world_size, device_id): """Creates a parameter-sharded version of a Keras model. - This is a high-level factory function that orchestrates the creation of - the communicator, the sharding strategy, and the final sharded model. + This is the main entry point for applying parameter sharding. It initializes + the sharding strategy and uses it to transform the given model. Args: - module: The Keras model to be sharded. - config (ConfigKeras): Configuration object with sharding rules. - rank (int): The rank of the current process. - world_size (int): The total number of processes. - device_id (Any): The device on which the model will be placed. + module (keras.Model): The Keras model to shard. + config (LayoutMap): The configuration defining the sharding rules. + rank (int): The rank of the current device. + world_size (int): The total number of devices. + device_id: The identifier for the current device. Returns: - A tuple containing: - - The newly created `ParameterShardedModel`. - - A set of names for the parameters that were modified. + tuple: A tuple containing: + - ParameterShardedModel: The new, sharded model wrapper. + - set: A set of names of the parameters that were sharded. """ - communicator = TensorParallelCommunicator(world_size=world_size, rank=rank) sharding_strategy = ParameterShardingStrategy(world_size, rank) - sharded_model, modified_parameters = ( - sharding_strategy.shard_model_parameters( - module, config, communicator, device_id - ) + sharding_strategy.shard_model_parameters(module, config, device_id) ) - return sharded_model, modified_parameters - - -def apply_parameter_sharding_to_existing_model( - model, config: ConfigKeras, rank: int, world_size: int -): - """Applies parameter sharding directly to an existing model instance. - - This function modifies a model in-place. Instead of returning a new - wrapped model, it shards the weights and attaches the sharding strategy - to the original model object. This is useful when the model's execution - logic is handled externally. - - Args: - model: The Keras model to modify. - config (ConfigKeras): Configuration object with sharding rules. 
- rank (int): The rank of the current process. - world_size (int): The total number of processes. - - Returns: - The modified model with an attached `_tensor_parallel_sharding` - strategy attribute. - """ - - sharding_strategy = ParameterShardingStrategy(world_size, rank) - for pattern, action in config.state_rules.items(): - if isinstance(action, StateActionKeras): - matching_params = sharding_strategy._find_matching_parameters( - model, pattern - ) - - for param_name, param in matching_params: - try: - param_id = id(param.experimental_ref()) - except AttributeError: - param_id = id(param) - - if param_id in sharding_strategy.sharded_weights_by_id: - sharding_strategy.sharded_weights[param_name] = ( - sharding_strategy.sharded_weights_by_id[param_id] - ) - existing_param_name = next( - k - for k, v in sharding_strategy.sharded_weights.items() - if v - is sharding_strategy.sharded_weights_by_id[param_id] - ) - sharding_strategy.weight_mapping[param_name] = ( - sharding_strategy.weight_mapping[existing_param_name] - ) - continue - - sharded_param = action(param, rank) - - sharding_strategy.sharded_weights[param_name] = sharded_param - sharding_strategy.sharded_weights_by_id[param_id] = ( - sharded_param - ) - - sharding_strategy.weight_mapping[param_name] = { - "original_shape": param.shape, - "sharded_shape": sharded_param.shape, - "action": action, - } - - model._tensor_parallel_sharding = sharding_strategy - - return model diff --git a/keras/src/distribution/tensor_parallel/parameter_sharding_test.py b/keras/src/distribution/tensor_parallel/parameter_sharding_test.py index dc686436af97..c39507c77365 100644 --- a/keras/src/distribution/tensor_parallel/parameter_sharding_test.py +++ b/keras/src/distribution/tensor_parallel/parameter_sharding_test.py @@ -9,23 +9,18 @@ import keras from keras import distribution from keras.src import backend -from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.parameter_sharding import ( ShardedWeight, ) from keras.src.distribution.tensor_parallel.parameter_sharding import ( make_parameter_sharded_model, ) -from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap +from keras.src.distribution.tensor_parallel.tensor_layout import Split from keras.src.testing import TestCase -@pytest.mark.skipif( - backend.backend() not in ("torch", "jax"), - reason="This test is for JAX/PyTorch backends.", -) def _create_simple_mlp(): - """Creates a simple, unsharded Keras MLP model for testing.""" inputs = keras.Input(shape=(16,), name="input") x = keras.layers.Dense(32, use_bias=True, name="up_proj")(inputs) x = keras.layers.Activation("relu")(x) @@ -33,6 +28,10 @@ def _create_simple_mlp(): return keras.Model(inputs=inputs, outputs=outputs, name="simple_mlp") +@pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is for the JAX backend only.", +) class ParameterShardingTest(TestCase): def setUp(self): super().setUp() @@ -52,12 +51,12 @@ def setUp(self): self.original_model = _create_simple_mlp() self.original_model.build(input_shape=(None, 16)) - self.tp_config = ConfigKeras( + self.tp_config = LayoutMap( state_rules={ - re.escape("simple_mlp.up_proj.kernel"): SplitKeras( + re.escape("simple_mlp.up_proj.kernel"): Split( self.world_size, dim=1 ), - re.escape("simple_mlp.down_proj.kernel"): SplitKeras( + re.escape("simple_mlp.down_proj.kernel"): Split( self.world_size, dim=0 ), }, diff --git 
a/keras/src/distribution/tensor_parallel/state_action_keras_test.py b/keras/src/distribution/tensor_parallel/state_action_keras_test.py deleted file mode 100644 index a6947958a4aa..000000000000 --- a/keras/src/distribution/tensor_parallel/state_action_keras_test.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest - -import keras -from keras.src import backend -from keras.src import testing -from keras.src.distribution.tensor_parallel.state_action_keras import ( - GatherKeras, -) -from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras -from keras.src.distribution.tensor_parallel.state_action_keras import SumKeras - - -@pytest.mark.skipif( - backend.backend() not in ("torch", "jax"), - reason="This test is for JAX/PyTorch backends.", -) -class TestStateActions(testing.TestCase): - """Test suite for tensor distribution state actions.""" - - def test_split_keras_even_split(self): - """Tests SplitKeras with a tensor that divides evenly.""" - world_size = 4 - tensor = keras.ops.reshape( - keras.ops.arange(16, dtype="float32"), (4, 4) - ) - - action_row = SplitKeras(world_size=world_size, dim=0) - shards_row = [action_row(tensor, rank=i) for i in range(world_size)] - - self.assertEqual(shards_row[0].shape, (1, 4)) - self.assertAllClose(shards_row[0], tensor[0:1, :]) - self.assertAllClose(shards_row[3], tensor[3:4, :]) - - reconstructed_row = action_row.undo(shards_row) - self.assertAllClose(reconstructed_row, tensor) - - action_col = SplitKeras(world_size=world_size, dim=1) - shards_col = [action_col(tensor, rank=i) for i in range(world_size)] - - self.assertEqual(shards_col[0].shape, (4, 1)) - self.assertAllClose(shards_col[0], tensor[:, 0:1]) - self.assertAllClose(shards_col[2], tensor[:, 2:3]) - - reconstructed_col = action_col.undo(shards_col) - self.assertAllClose(reconstructed_col, tensor) - - def test_split_keras_uneven_split(self): - """Tests SplitKeras with a tensor that does not divide evenly.""" - world_size = 3 - tensor = keras.ops.reshape( - keras.ops.arange(40, dtype="float32"), (4, 10) - ) - - action = SplitKeras(world_size=world_size, dim=1) - shards = [action(tensor, rank=i) for i in range(world_size)] - - self.assertEqual(shards[0].shape, (4, 4)) - self.assertEqual(shards[1].shape, (4, 3)) - self.assertEqual(shards[2].shape, (4, 3)) - - self.assertAllClose(shards[0], tensor[:, 0:4]) - self.assertAllClose(shards[1], tensor[:, 4:7]) - self.assertAllClose(shards[2], tensor[:, 7:10]) - - reconstructed = action.undo(shards) - self.assertAllClose(reconstructed, tensor) - - def test_split_keras_sharding_type_inference(self): - """Tests that `sharding_type` correctly infers the split dimension.""" - action_row = SplitKeras(world_size=2, dim=-1, sharding_type="row") - self.assertEqual(action_row.dim, 0) - - action_col = SplitKeras(world_size=2, dim=-1, sharding_type="column") - self.assertEqual(action_col.dim, 1) - - def test_gather_keras(self): - """Tests the GatherKeras action.""" - world_size = 4 - action = GatherKeras(world_size=world_size, dim=0) - tensor = keras.ops.array([[1, 2], [3, 4]], dtype="float32") - - processed_tensor = action(tensor, rank=0) - self.assertAllClose(processed_tensor, tensor) - - tensors_to_gather = [ - keras.ops.ones((2, 2)), - keras.ops.zeros((2, 2)), - keras.ops.ones((2, 2)), - ] - reconstructed = action.undo(tensors_to_gather) - expected = keras.ops.concatenate(tensors_to_gather, axis=0) - self.assertAllClose(reconstructed, expected) - - def test_sum_keras(self): - """Tests the SumKeras action.""" - world_size = 2 - action = 
SumKeras(world_size=world_size) - tensor = keras.ops.array([[1, 2], [3, 4]], dtype="float32") - - processed_tensor = action(tensor, rank=0) - self.assertAllClose(processed_tensor, tensor) - - tensors_to_sum = [ - keras.ops.full((2, 3), 5.0), - keras.ops.full((2, 3), 10.0), - ] - reconstructed = action.undo(tensors_to_sum) - expected = keras.ops.full((2, 3), 15.0) - self.assertAllClose(reconstructed, expected) diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py new file mode 100644 index 000000000000..6841e4d01a36 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_layout.py @@ -0,0 +1,154 @@ +import keras + + +class LayoutAction: + def __call__(self, tensor, rank): + """Applies the distribution action to a tensor for a specific worker. + + Args: + tensor: The input tensor to be distributed. + rank: The integer rank of the current worker/device. + + Raises: + NotImplementedError: This is an abstract method and must be + implemented by subclasses. + + Returns: + A shard or transformation of the input tensor specific to the given + rank. + """ + raise NotImplementedError + + def undo(self, tensors): + """Reverses the distribution action, reconstructing the original tensor. + + Args: + tensors: A sequence of tensor shards, one from each worker. + + Raises: + NotImplementedError: This is an abstract method and must be + implemented by subclasses. + + Returns: + The reconstructed, single tensor. + """ + raise NotImplementedError + + +class _ConcatenateMixin: + """A mixin class providing a common `undo` method via concatenation. + + This class is intended to be used as a mixin for `LayoutAction` subclasses + that can be undone by simple concatenation along a specified axis. + """ + + def undo(self, tensors): + if self.dim == -1: + dim = keras.ops.ndim(tensors[0]) - 1 + else: + dim = self.dim + return keras.ops.concatenate(tensors, axis=dim) + + +class Split(_ConcatenateMixin, LayoutAction): + """Splits a tensor into shards along a specified dimension. + + This is an internal utility used by a higher-level distribution API. + It implements sharding by slicing a tensor along one of its axes. + It handles cases where the dimension size is not perfectly divisible by the + number of workers by distributing the remainder elements one by one to the + first few workers. + + The `undo` operation is provided by the `_ConcatenateMixin`. + """ + + def __init__(self, world_size, dim, sharding_type="auto"): + """Initializes the Split action. + + Args: + world_size: The total number of workers/shards. + dim: The dimension along which to split the tensor. If -1, the + last dimension is used. + sharding_type: If `dim` is -1, this can be 'row' (dim=0) or + 'column' (dim=1) to infer the split axis for 2D tensors. + Defaults to "auto". + """ + super().__init__() + self.world_size = world_size + self.dim = dim + self.sharding_type = sharding_type + + if dim == -1 and sharding_type != "auto": + if sharding_type == "row": + self.dim = 0 + elif sharding_type == "column": + self.dim = 1 + + def __call__(self, tensor, rank): + """Splits the tensor and returns the shard corresponding to the rank. + + This method calculates the correct slice of the tensor for a given + worker rank, handling uneven distributions gracefully. + + Args: + tensor: The full tensor to be sharded. + rank: The rank of the worker for which to get the shard. + + Returns: + A tensor shard corresponding to the given rank. 
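+
+        Example (illustrative sketch of an uneven split; the first worker
+        receives the remainder element):
+
+        ```python
+        import keras
+
+        x = keras.ops.reshape(keras.ops.arange(40, dtype="float32"), (4, 10))
+        split = Split(world_size=3, dim=1)
+        shards = [split(x, rank=r) for r in range(3)]
+        print([tuple(s.shape) for s in shards])  # [(4, 4), (4, 3), (4, 3)]
+        print(tuple(split.undo(shards).shape))   # (4, 10)
+        ```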
+ """ + if self.dim == -1: + dim = keras.ops.ndim(tensor) - 1 + else: + dim = self.dim + + total_size = tensor.shape[dim] + split_size = total_size // self.world_size + remainder = total_size % self.world_size + + start_idx = rank * split_size + min(rank, remainder) + end_idx = start_idx + split_size + (1 if rank < remainder else 0) + + slices = [slice(None)] * keras.ops.ndim(tensor) + slices[dim] = slice(start_idx, end_idx) + return tensor[tuple(slices)] + + +class LayoutMap: + """A mapping that defines layout rules for model states and outputs. + + This is an internal configuration object used to hold layout rules for + how model variables and layer outputs should be distributed across a set + of devices. It acts as a container for `LayoutAction` instances. + + Attributes: + state_rules: A dictionary mapping variable names or patterns to + `LayoutAction` instances. + output_rules: A dictionary mapping layer output names or + patterns to `LayoutAction` instances. + """ + + def __init__(self, state_rules, output_rules): + """Initializes the LayoutMap. + + Args: + state_rules: A dictionary of distribution rules for model states. + output_rules: A dictionary of distribution rules for model outputs. + """ + self.state_rules = state_rules + self.output_rules = output_rules + + def create_collective_ops(self, devices): + """Creates the necessary collective communication operations. + + This method is a placeholder for backend-specific logic that would + translate the layout rules into actual communication primitives + (e.g., all-gather, reduce-scatter). + + Args: + devices: A sequence of device identifiers. + + Returns: + The `LayoutMap` instance itself, allowing for method chaining. + """ + return self From 6b2efcc36ab207107431da5232931dff9da1b843 Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 16 Oct 2025 01:18:03 +0530 Subject: [PATCH 7/8] refactor --- keras/src/backend/__init__.py | 1 + keras/src/backend/torch/distributed_backend_test.py | 2 +- keras/src/backend/torch/distribution_lib_test.py | 8 ++------ 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index dc93944dfd47..ecae3a6c7631 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -42,6 +42,7 @@ elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable + elif backend() == "torch": from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable diff --git a/keras/src/backend/torch/distributed_backend_test.py b/keras/src/backend/torch/distributed_backend_test.py index cbf4766b1c9c..70ba5caab1ae 100644 --- a/keras/src/backend/torch/distributed_backend_test.py +++ b/keras/src/backend/torch/distributed_backend_test.py @@ -7,7 +7,7 @@ @pytest.mark.skipif( backend.backend() != "torch", - reason="Jax Backend specific test", + reason="Torch Backend specific test", ) class TestPytorchDistributedFunctions: """Unit tests for the PyTorch distributed backend standalone functions.""" diff --git a/keras/src/backend/torch/distribution_lib_test.py b/keras/src/backend/torch/distribution_lib_test.py index 2897b022a0d4..bf4c20403b51 100644 --- a/keras/src/backend/torch/distribution_lib_test.py +++ b/keras/src/backend/torch/distribution_lib_test.py @@ -11,10 +11,6 @@ from keras.src.distribution import TensorLayout -@pytest.mark.skipif( - backend.backend() != "torch", - reason="Backend specific test", -) def 
setup_torch_distributed(): """ A fixture to initialize the distributed process group if not already done. @@ -32,8 +28,8 @@ def setup_torch_distributed(): @pytest.mark.skipif( - not torch.distributed.is_available(), - reason="PyTorch distributed components are not available.", + backend.backend() != "torch", + reason="Backend specific test", ) class TestTorchDistributionLibLive: """ From f126d013abb33a7e7c4b937e5124f6f807ee2c6a Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 16 Oct 2025 01:57:27 +0530 Subject: [PATCH 8/8] removing contents from pr1 and 2 --- keras/src/backend/jax/distributed_backend.py | 95 --- .../tensor_parallel/autoconfig.py | 282 -------- .../tensor_parallel/coordinated_optimizer.py | 653 ------------------ .../tensor_parallel/tensor_layout.py | 154 ----- 4 files changed, 1184 deletions(-) delete mode 100644 keras/src/backend/jax/distributed_backend.py delete mode 100644 keras/src/distribution/tensor_parallel/autoconfig.py delete mode 100644 keras/src/distribution/tensor_parallel/coordinated_optimizer.py delete mode 100644 keras/src/distribution/tensor_parallel/tensor_layout.py diff --git a/keras/src/backend/jax/distributed_backend.py b/keras/src/backend/jax/distributed_backend.py deleted file mode 100644 index 3ed50d756250..000000000000 --- a/keras/src/backend/jax/distributed_backend.py +++ /dev/null @@ -1,95 +0,0 @@ -import jax -import jax.lax as lax - - -def get_device_info(): - """Retrieves information about the available JAX devices. - - This function queries the JAX backend to identify the type and number - of available computational devices (e.g., CPU, GPU, TPU). - - Returns: - dict: A dictionary containing the backend name ('jax'), a list of - device string representations, and the total count of devices. - """ - available_devices = jax.devices() - return { - "backend": "jax", - "devices": [str(d) for d in available_devices], - "device_count": len(available_devices), - } - - -def is_multi_device_capable(): - """Checks if more than one JAX device is available for computation. - - Returns: - bool: True if the local JAX environment has more than one device, - False otherwise. - """ - return jax.local_device_count() > 1 - - -def get_communication_ops(): - """Provides a dictionary of JAX collective communication operations. - - Returns: - dict: A dictionary mapping operation names (e.g., 'all_reduce') to their - corresponding JAX implementation functions. - """ - - def all_reduce(x, op="sum", axis_name="model"): - """Reduces a tensor across a device mesh axis using a collective. - - This function assumes it is called within a `pjit` context that has a - device mesh with the specified `axis_name`. It performs a collective - reduction operation (like sum or mean) across all devices mapped to - that axis. - - Args: - x (jax.Array): The input JAX array (tensor) on the local device. - op (str, optional): The reduction operation to perform. Supported - values are 'sum' and 'mean'. Defaults to 'sum'. - axis_name (str, optional): The name of the mapped axis in the device - mesh over which to communicate. Defaults to 'model'. - - Returns: - jax.Array: The reduced JAX array, which is identical across all - devices participating in the reduction. - """ - if op == "sum": - return lax.psum(x, axis_name=axis_name) - elif op == "mean": - return lax.pmean(x, axis_name=axis_name) - else: - raise ValueError( - f"Unsupported reduction operation: {op}. " - "Supported options are 'sum' and 'mean'." 
- ) - - def all_gather(x, axis, axis_name="model"): - """Gathers and concatenates tensors from all devices across a mesh axis. - - This function assumes it is called within a `pjit` context. It takes - the local shard `x` from each device along the `axis_name` of the mesh - and concatenates them along the specified tensor `axis` to form a - single, larger tensor that is then replicated on all - participating devices. - - Args: - x (jax.Array): The input JAX array (tensor) shard on local device. - axis (int): The tensor axis along which to concatenate the gathered - shards. - axis_name (str, optional): The name of the mesh axis to gather - from. Defaults to 'model'. - - Returns: - jax.Array: The full, gathered JAX array, which is identical across - all devices participating in the gather. - """ - return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True) - - return { - "all_reduce": all_reduce, - "all_gather": all_gather, - } diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py deleted file mode 100644 index 708d6d603cc6..000000000000 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ /dev/null @@ -1,282 +0,0 @@ -from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap -from keras.src.distribution.tensor_parallel.tensor_layout import Split - - -def analyze_dense_layer_directly(layer, module, prefix): - """Analyzes a Keras Dense layer to classify its sharding strategy. - - This function inspects the input and output dimensions of a Dense layer - to determine if it functions as an expansion layer ("up-projection"), a - contraction layer ("down-projection"), or neither ("generic_dense"). This - classification is a heuristic commonly used to apply tensor parallelism - in Transformer-based models, such as in an MLP block where an up-projection - is followed by a down-projection. - - Args: - layer: The Keras `layers.Dense` instance to analyze. - module: The parent module containing the layer (currently unused). - prefix (str): The name prefix for the layer in the model hierarchy - (currently unused). - - Returns: - str: A string classifying the layer as 'up_projection', - 'down_projection', or 'generic_dense'. - """ - from keras.src import layers - - if not isinstance(layer, layers.Dense): - return "generic_dense" - - input_dim = None - output_dim = None - - if hasattr(layer, "kernel") and layer.kernel is not None: - kernel_shape = layer.kernel.shape - if len(kernel_shape) == 2: - input_dim = kernel_shape[0] - output_dim = kernel_shape[1] - - if input_dim is None or output_dim is None: - if hasattr(layer, "units"): - output_dim = layer.units - else: - return "generic_dense" - - if ( - hasattr(layer, "input_shape") - and layer.input_shape - and len(layer.input_shape) > 1 - ): - input_dim = layer.input_shape[-1] - else: - return "generic_dense" - - if not input_dim or not output_dim: - return "generic_dense" - - expansion_threshold = 1.5 - is_expansion = output_dim > input_dim * expansion_threshold - is_contraction = input_dim > output_dim * expansion_threshold - - if is_expansion: - return "up_projection" - elif is_contraction: - return "down_projection" - else: - return "generic_dense" - - -def _find_and_shard_layers( - current_layer, - prefix, - module, - world_size, - state_rules, - output_rules, - processed_layers, -): - """Recursively traverses the model graph to apply sharding rules. - - This function walks through all nested layers of a given Keras model or - layer. 
For each encountered layer, it determines an appropriate tensor - parallelism strategy and populates the `state_rules` and `output_rules` - dictionaries with the corresponding sharding actions. It uses a set of - processed layer IDs to avoid redundant processing of shared layers. - - The sharding logic is as follows: - - `Dense` layers are sharded based on their classification (up/down proj). - - Up-projections are split along the column axis (output features). - - Down-projections are split along the row axis (input features). - - `EinsumDense` layers in attention blocks are sharded similarly. - - `Embedding` layers are sharded column-wise for vocabulary parallelism. - - Normalization layers are ignored (replicated on all devices). - - Args: - current_layer: The Keras layer currently being processed. - prefix (str): The hierarchical name prefix for the `current_layer`. - module: The top-level Keras model or layer being configured. - world_size (int): The total number of devices for sharding. - state_rules (Dict[str, Any]): A dictionary to be populated with rules - for sharding layer weights (state). Keys are regex patterns - matching weight names, values are `SplitKeras` actions. - output_rules (Dict[str, Any]): A dictionary to be populated with rules - for handling layer outputs. Keys are regex patterns matching layer - names, values describe the communication op (e.g., 'allreduce'). - processed_layers (Set[int]): A set of `id()`s of layers that have - already been processed to prevent cycles and redundant work. - """ - from keras.src import layers - - if id(current_layer) in processed_layers: - return - processed_layers.add(id(current_layer)) - - name = current_layer.name - full_name = f"{prefix}.{name}" if prefix else name - - if isinstance(current_layer, layers.Dense): - mlp_type = analyze_dense_layer_directly( - current_layer, module, full_name - ) - - if mlp_type == "up_projection": - state_rules[f"^{full_name}.kernel$"] = Split( - world_size, 1, "column" - ) - if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = Split( - world_size, 0, "column" - ) - output_rules[f"^{full_name}$"] = {0: "gather"} - - elif mlp_type == "down_projection": - state_rules[f"^{full_name}.kernel$"] = Split(world_size, 0, "row") - output_rules[f"^{full_name}$"] = {0: "allreduce"} - - else: - state_rules[f"^{full_name}.kernel$"] = Split( - world_size, 1, "column" - ) - if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = Split( - world_size, 0, "column" - ) - output_rules[f"^{full_name}$"] = {0: "gather -1"} - return - - elif isinstance(current_layer, layers.EinsumDense): - if "attention_output" in full_name: - state_rules[f"^{full_name}.kernel$"] = Split(world_size, 0, "row") - if ( - hasattr(current_layer, "bias") - and current_layer.bias is not None - ): - pass - output_rules[f"^{full_name}$"] = {0: "allreduce"} - else: - state_rules[f"^{full_name}.kernel$"] = Split( - world_size, 1, "column" - ) - if ( - hasattr(current_layer, "bias") - and current_layer.bias is not None - ): - state_rules[f"^{full_name}.bias$"] = Split( - world_size, 0, "column" - ) - output_rules[f"^{full_name}$"] = {0: "gather -1"} - return - - elif isinstance(current_layer, (layers.Embedding,)): - if hasattr(current_layer, "token_embedding") or hasattr( - current_layer, "position_embedding" - ): - pass - else: - weight_name = None - if hasattr(current_layer, "embeddings"): - weight_name = "embeddings" - elif hasattr(current_layer, "position_embeddings"): - weight_name = "position_embeddings" - - if 
weight_name: - state_rules[f"^{full_name}\\..*{weight_name}$"] = Split( - world_size, 1, "column" - ) - output_rules[f"^{full_name}$"] = {0: "no_comm"} - return - - elif isinstance( - current_layer, - ( - layers.LayerNormalization, - layers.BatchNormalization, - layers.GroupNormalization, - ), - ): - return - - if hasattr(current_layer, "layers") and current_layer.layers: - for sub_layer in current_layer.layers: - _find_and_shard_layers( - sub_layer, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, - ) - - for attr_name in dir(current_layer): - if attr_name.startswith("__") and attr_name.endswith("__"): - continue - if hasattr(current_layer, attr_name): - attr = getattr(current_layer, attr_name) - - if isinstance(attr, layers.Layer) and attr is not current_layer: - _find_and_shard_layers( - attr, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, - ) - elif isinstance(attr, (list, tuple)): - for item in attr: - if isinstance(item, layers.Layer): - _find_and_shard_layers( - item, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, - ) - - -def get_default_config_keras(module, device_ids): - """Generates default tensor parallelism sharding configuration for a model. - - This function serves as entry point for automatically creating a sharding - plan for a given Keras model or layer. It initializes the rule dictionaries - and starts the recursive layer traversal to populate them based on a default - set of heuristics for common architectures like Transformers. - - Example: - ```python - model = MyTransformerModel() - device_ids = ["gpu:0", "gpu:1"] - sharding_config = get_default_config_keras(model, device_ids) - # sharding_config can now be used to distribute the model - ``` - - Args: - module: The Keras `Model` or `Layer` to generate a config for. - device_ids (Sequence[str]): A sequence of device IDs (e.g., - ["gpu:0", "gpu:1"]) across which the model will be sharded. - - Returns: - ConfigKeras: A configuration object containing the generated sharding - rules for model weights (`state_rules`) and layer outputs - (`output_rules`). - """ - world_size = len(device_ids) - state_rules = {} - output_rules = {} - processed_layers = set() - - _find_and_shard_layers( - current_layer=module, - prefix="", - module=module, - world_size=world_size, - state_rules=state_rules, - output_rules=output_rules, - processed_layers=processed_layers, - ) - - return LayoutMap(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py deleted file mode 100644 index 62039e2e121f..000000000000 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ /dev/null @@ -1,653 +0,0 @@ -import re - -import numpy as np - -import keras -from keras.src import ops -from keras.src import optimizers -from keras.src.backend import distributed_backend - - -class CoordinatedOptimizer: - """Manages an optimizer's state for distributed training. - - This class is an internal coordinator that handles the complexities of - sharding optimizer states across multiple devices (shards) and - synchronizing gradients according to tensor parallelism rules. It is not - intended to be used directly by the end-user but is a core component of - the `TensorParallelOptimizer`. 
- - Args: - base_optimizer: The Keras optimizer instance - (e.g., `keras.optimizers.Adam`) whose state will be managed. - world_size: The total number of devices/processes in the distributed - setup. - distributed_backend: The distributed communication backend to use. - Defaults to "auto". - rank: The rank of the current process. Defaults to 0. - shard_optimizer_states: If `True`, the optimizer's state variables - (e.g., momentum, velocity) will be partitioned across `world_size` - devices. Defaults to `True`. - tensor_parallel_config: An optional configuration object that defines - rules for tensor parallelism, such as which gradients to - all-reduce. Defaults to `None`. - """ - - def __init__( - self, - base_optimizer, - world_size, - distributed_backend="auto", - rank=0, - shard_optimizer_states=True, - tensor_parallel_config=None, - ): - """Initializes the CoordinatedOptimizer.""" - self.base_optimizer = base_optimizer - self.world_size = world_size - self.shard_optimizer_states = shard_optimizer_states - self.tensor_parallel_config = tensor_parallel_config - self.sharded_states = {} - self._state_variable_to_parameter = {} - self._variables = None - self._variable_to_slot_name = {} - - def _initialize_sharded_states(self): - """ - Partitions the optimizer's state variables across shards by inspecting - the variables created by the base optimizer. - """ - if not self.shard_optimizer_states or not self.base_optimizer.built: - return - - self.sharded_states = {} - self._state_variable_to_parameter = {} - self._variable_to_slot_name = {} - opt_name = self.base_optimizer.name - - normalized_params = sorted( - [(p.path.replace("/", "_"), p) for p in self._variables], - key=lambda x: len(x[0]), - reverse=True, - ) - - for state_var in self.base_optimizer.variables: - if state_var is self.base_optimizer.iterations: - continue - - path_parts = state_var.path.split("/") - if len(path_parts) != 2 or path_parts[0] != opt_name: - continue - - state_suffix = path_parts[1] - - found_param = None - slot_name = None - for norm_param_path, param in normalized_params: - if state_suffix.startswith(norm_param_path): - found_param = param - slot_suffix = state_suffix[len(norm_param_path) :] - slot_name = slot_suffix.strip("_") - break - - if found_param is not None and slot_name is not None: - self._state_variable_to_parameter[state_var.path] = found_param - self._variable_to_slot_name[state_var.path] = slot_name - - sharding_dim = 0 - if self.tensor_parallel_config: - norm_param_name = found_param.path.replace("/", ".") - for p, a in self.tensor_parallel_config.state_rules.items(): - if re.search(p, norm_param_name) and hasattr(a, "dim"): - sharding_dim = a.dim - break - - partitioned_state = self._partition_state( - state_var, dim=sharding_dim - ) - self.sharded_states.setdefault(slot_name, {})[ - found_param.path - ] = partitioned_state - - if self.base_optimizer.iterations is not None: - self.sharded_states["iterations"] = self._partition_state( - self.base_optimizer.iterations, dim=0 - ) - - def _partition_state(self, state_variable, dim): - """Splits a single state variable numpy array into chunks. - - If the variable cannot be split along the given dimension, it is - replicated across all shards. - - Args: - state_variable: The optimizer state variable. - dim: The dimension along which to partition the variable. - - Returns: - A list of NumPy arrays, where each array is a partition of the - original state variable for a specific shard. 
- """ - state_array = ops.convert_to_numpy(state_variable) - if state_array.ndim > dim and state_array.shape[dim] >= self.world_size: - return np.array_split(state_array, self.world_size, axis=dim) - else: - return [np.copy(state_array) for _ in range(self.world_size)] - - def apply_gradients(self, grads_and_vars, shard_models): - """ - Applies gradients to the model variables by first synchronizing them - and then applying them using either sharded or replicated optimizer - states. - - Args: - grads_and_vars: A list of (gradient, variable) lists from all - shards. - shard_models: A list of the sharded model instances. - """ - synchronized_gradients = self._synchronize_gradients(grads_and_vars) - - if self.shard_optimizer_states: - self._apply_gradients_with_sharded_states( - synchronized_gradients, shard_models - ) - else: - self._apply_gradients_with_replicated_states( - synchronized_gradients, shard_models - ) - - def _apply_gradients_with_replicated_states( - self, synchronized_gradients, shard_models - ): - """Averages gradients across all shards and applies them once. - - This method is used when optimizer state sharding is disabled. It - calculates the average of the gradients for each variable across all - shards and applies the averaged gradients using the single, replicated - optimizer state. - - Args: - synchronized_gradients: The gradients after synchronization. - shard_models: The list of sharded models. - """ - num_vars = len(synchronized_gradients[0]) - averaged_grads_and_vars = [] - - for i in range(num_vars): - variable = synchronized_gradients[0][i][1] - grads_for_var = [ - shard_grads[i][0] - for shard_grads in synchronized_gradients - if shard_grads[i][0] is not None - ] - - if not grads_for_var: - continue - - if len(grads_for_var) > 1: - stacked_grads = ops.stack(grads_for_var, axis=0) - averaged_grad = ops.mean(stacked_grads, axis=0) - else: - averaged_grad = grads_for_var[0] - - averaged_grads_and_vars.append((averaged_grad, variable)) - - if averaged_grads_and_vars: - self.base_optimizer.apply_gradients(averaged_grads_and_vars) - - def _apply_gradients_with_sharded_states( - self, synchronized_gradients, shard_models - ): - """Applies gradients to each shard using its local optimizer state. - - Args: - synchronized_gradients: The gradients after synchronization. - shard_models: The list of sharded models. - """ - for shard_idx in range(self.world_size): - local_states = self._get_local_optimizer_states(shard_idx) - shard_optimizer = shard_models[shard_idx].optimizer - - self._update_optimizer_internal_state(shard_optimizer, local_states) - - shard_grads_and_vars = synchronized_gradients[shard_idx] - shard_optimizer.apply_gradients(shard_grads_and_vars) - - self._update_global_sharded_states(shard_optimizer, shard_idx) - - def _get_local_optimizer_states(self, shard_idx): - """Constructs the state dictionary for a single shard. - - Args: - shard_idx: The index of the shard for which to retrieve the state. - - Returns: - A dictionary containing the local optimizer state for the specified - shard. 
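A quick numerical illustration of the partition-or-replicate behaviour of `_partition_state`, using plain NumPy with made-up shapes:

```python
import numpy as np

world_size = 2

# A (4, 6) moment tensor split along axis 1 -> two (4, 3) shards.
state = np.zeros((4, 6))
shards = np.array_split(state, world_size, axis=1)
print([s.shape for s in shards])  # [(4, 3), (4, 3)]

# A scalar state such as the iteration counter cannot be split,
# so every shard receives its own copy.
iterations = np.array(7)
replicas = [np.copy(iterations) for _ in range(world_size)]
print(len(replicas), int(replicas[0]))  # 2 7
```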
- """ - local_states = {} - for state_name, state_value in self.sharded_states.items(): - if isinstance(state_value, dict): - local_states[state_name] = {} - for param_name, param_states in state_value.items(): - local_states[state_name][param_name] = param_states[ - shard_idx - ] - else: - local_states[state_name] = state_value[shard_idx] - return local_states - - def _update_optimizer_internal_state(self, optimizer, local_states): - """Assigns local sharded state values to the optimizer's variables. - - Args: - optimizer: The optimizer instance for a specific shard. - local_states: The dictionary of local states for that shard. - """ - if not optimizer.built: - return - - for var in optimizer.variables: - if var is optimizer.iterations: - if "iterations" in local_states: - var.assign(local_states["iterations"]) - continue - - param = self._state_variable_to_parameter.get(var.path, None) - slot_name = self._variable_to_slot_name.get(var.path) - - if ( - param - and slot_name - and slot_name in local_states - and param.path in local_states[slot_name] - ): - local_param_state = local_states[slot_name][param.path] - if var.shape == local_param_state.shape: - var.assign(local_param_state) - - def _update_global_sharded_states(self, optimizer, shard_idx): - """Updates the main sharded_states dictionary after a gradient step. - - Args: - optimizer: The optimizer instance for a specific shard. - shard_idx: The index of the shard that was updated. - """ - if not optimizer.built: - return - - for var in optimizer.variables: - if var is optimizer.iterations: - self.sharded_states["iterations"][shard_idx] = ( - ops.convert_to_numpy(var) - ) - continue - - param = self._state_variable_to_parameter.get(var.path, None) - slot_name = self._variable_to_slot_name.get(var.path) - - if ( - param - and slot_name - and slot_name in self.sharded_states - and param.path in self.sharded_states[slot_name] - ): - self.sharded_states[slot_name][param.path][shard_idx] = ( - ops.convert_to_numpy(var) - ) - - def _synchronize_gradients(self, gradients_and_vars): - """Synchronizes gradients across shards based on tensor parallel rules. - - Specifically, it performs an all-reduce operation on gradients of - weights that are split along a "column" dimension in tensor parallelism. - Other gradients are passed through unchanged. - - Args: - gradients_and_vars: The list of (gradient, variable) lists from - all shards. - - Returns: - The list of (gradient, variable) lists after synchronization. 
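The nested `sharded_states` layout that the two methods above read from and write back to can be pictured as follows; the parameter path and shapes are hypothetical:

```python
import numpy as np

# slot name -> parameter path -> one array per shard (iterations is flat).
sharded_states = {
    "momentum": {"dense/kernel": [np.zeros((4, 3)), np.zeros((4, 3))]},
    "iterations": [np.array(0), np.array(1)],
}

def local_view(shard_idx):
    """Select the shard_idx entry from every slot, nested or flat."""
    local = {}
    for name, value in sharded_states.items():
        if isinstance(value, dict):
            local[name] = {p: chunks[shard_idx] for p, chunks in value.items()}
        else:
            local[name] = value[shard_idx]
    return local

print(local_view(0)["momentum"]["dense/kernel"].shape)  # (4, 3)
```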
- """ - if not self.tensor_parallel_config: - return gradients_and_vars - - rules = self.tensor_parallel_config.state_rules.items() - column_parallel_patterns = { - pattern - for pattern, action in rules - if hasattr(action, "sharding_type") - and action.sharding_type == "column" - } - - if not column_parallel_patterns: - return gradients_and_vars - - num_weights = len(gradients_and_vars[0]) - for i in range(num_weights): - variable = gradients_and_vars[0][i][1] - var_name = getattr(variable, "path", getattr(variable, "name", "")) - - if any( - re.search(pattern, var_name) - for pattern in column_parallel_patterns - ): - grads_to_reduce = [ - g_and_v[i][0] - for g_and_v in gradients_and_vars - if g_and_v[i][0] is not None - ] - if grads_to_reduce: - synced_grad = self._allreduce_gradients(grads_to_reduce)[0] - for shard_idx in range(self.world_size): - gradients_and_vars[shard_idx][i] = ( - synced_grad, - variable, - ) - return gradients_and_vars - - def _allreduce_gradients(self, gradients): - """Performs a mean all-reduce operation on a list of gradients. - - If a distributed backend is available, it uses it. Otherwise, it - falls back to a local mean calculation. - - Args: - gradients: A list of gradients (one from each shard) to be averaged. - - Returns: - A list where each element is the mean of the input gradients. - """ - if not gradients: - return [] - - if distributed_backend.is_multi_device_capable(): - all_reduce_fn = distributed_backend.get_communication_ops()[ - "all_reduce" - ] - numpy_grad = ops.convert_to_numpy(gradients[0]) - synced_numpy = all_reduce_fn(numpy_grad, op="mean") - synced_tensor = ops.convert_to_tensor(synced_numpy) - return [synced_tensor for _ in range(self.world_size)] - - stacked_grads = keras.ops.stack( - [ops.convert_to_tensor(g) for g in gradients], axis=0 - ) - mean_grad = ops.mean(stacked_grads, axis=0) - return [mean_grad for _ in range(len(gradients))] - - def get_weights(self): - """Returns the weights of the base optimizer. - - Returns: - A list of NumPy arrays representing the optimizer's state variables. - """ - return [ - ops.convert_to_numpy(var) for var in self.base_optimizer.variables - ] - - def set_weights(self, weights): - """Sets the weights of the base optimizer. - - Args: - weights: A list of NumPy arrays to set as the optimizer's state. - """ - self.base_optimizer.set_weights(weights) - - def enable_optimizer_state_sharding(self, variables): - """Enables and initializes optimizer state sharding. - - This method is called from `build()`, which is guarded from running - multiple times. We can assume this should always execute. - - Args: - variables: A list of model variables to be optimized. - """ - self.shard_optimizer_states = True - self._variables = variables - self._initialize_sharded_states() - - -class TensorParallelOptimizer(optimizers.Optimizer): - """A Keras Optimizer wrapper for tensor-parallel distributed training. - - This optimizer wraps a standard Keras optimizer (e.g., Adam, SGD) and - delegates the complex tasks of state management and gradient synchronization - to a `CoordinatedOptimizer` instance. It is designed to work with models - that have been sharded for tensor parallelism. - - When `apply_gradients` is called with a list of gradient lists (one for each - model shard), it uses the `CoordinatedOptimizer` to handle synchronization - and state sharding. Otherwise, it behaves like the base optimizer. - - Args: - base_optimizer: A Keras optimizer instance or a string identifier - (e.g., 'adam', 'sgd'). 
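When no multi-device backend is available, `_allreduce_gradients` above falls back to a local stack-and-mean. A minimal sketch with fabricated per-shard gradients:

```python
import numpy as np
from keras import ops

# Fabricated gradients for one column-parallel weight, one per shard.
grads = [ops.convert_to_tensor(np.full((2, 2), float(i))) for i in range(4)]

# Local fallback: average the shard gradients and hand the same
# synchronized tensor back to every shard.
mean_grad = ops.mean(ops.stack(grads, axis=0), axis=0)
synced = [mean_grad for _ in grads]
print(ops.convert_to_numpy(synced[0]))  # every entry is 1.5
```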
- world_size: The total number of devices/processes in the distributed - setup. - distributed_backend: The distributed communication backend to use. - Defaults to "auto". - tensor_parallel_config: An optional configuration object that defines - rules for tensor parallelism. Defaults to `None`. - - Example: - - ```python - import keras - - # Assume model variables and gradients from 4 shards exist. - # The structure is: list[list[tuple[gradient, variable]]] - trainable_vars = [keras.Variable(1.0), keras.Variable(2.0)] - sharded_grads_and_vars = [ - [(keras.ops.ones_like(v), v) for v in trainable_vars] - for _ in range(4) # 4 shards - ] - - # 1. Wrap a standard Keras optimizer. - base_optimizer = keras.optimizers.Adam() - optimizer = TensorParallelOptimizer(base_optimizer, world_size=4) - optimizer.build(trainable_vars) - - # 2. Apply the sharded gradients. - # The optimizer will handle synchronization (e.g., all-reduce) internally. - optimizer.apply_gradients(sharded_grads_and_vars) - ``` - """ - - def __init__( - self, - base_optimizer, - world_size, - distributed_backend="auto", - tensor_parallel_config=None, - ): - """Initializes the TensorParallelOptimizer.""" - if isinstance(base_optimizer, str): - base_optimizer_instance = optimizers.get(base_optimizer) - else: - base_optimizer_instance = base_optimizer - - learning_rate = base_optimizer_instance.learning_rate - if callable(learning_rate): - lr_value = float(ops.convert_to_numpy(learning_rate(0))) - else: - lr_value = float(ops.convert_to_numpy(learning_rate)) - - super().__init__( - learning_rate=lr_value, - name=f"TensorParallel_{base_optimizer_instance.name}", - ) - - self.base_optimizer = base_optimizer_instance - self.world_size = world_size - self.distributed_backend = distributed_backend - self.coordinated_optimizer = CoordinatedOptimizer( - self.base_optimizer, - world_size, - distributed_backend=distributed_backend, - tensor_parallel_config=tensor_parallel_config, - ) - - def apply_gradients(self, grads_and_vars, **kwargs): - """Applies gradients to the model variables. - - If `grads_and_vars` is a list of lists, it's assumed to be from - sharded models, and the `CoordinatedOptimizer` is used. Otherwise, - it calls the `base_optimizer`'s `apply_gradients` directly. - - Args: - grads_and_vars: A list of (gradient, variable) tuples, or a list - of such lists if running in a sharded context. - **kwargs: Additional arguments. `shard_models` can be passed to - provide the list of model shards. - """ - is_sharded_grads = ( - isinstance(grads_and_vars, list) - and grads_and_vars - and isinstance(grads_and_vars[0], list) - ) - if is_sharded_grads: - shard_models = kwargs.get("shard_models", []) - self.coordinated_optimizer.apply_gradients( - grads_and_vars, shard_models - ) - else: - self.base_optimizer.apply_gradients(grads_and_vars) - - def get_config(self): - """Returns the configuration of the optimizer. - - Returns: - A dictionary containing the optimizer's configuration. - """ - from keras.src import saving - - config = super().get_config() - config.pop("learning_rate", None) - config.pop("name", None) - - config.update( - { - "base_optimizer": saving.serialize_keras_object( - self.base_optimizer - ), - "world_size": self.world_size, - "distributed_backend": self.distributed_backend, - } - ) - return config - - def update_step(self, gradient, variable, *args, **kwargs): - """Performs a single optimization step. 
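The dispatch in `apply_gradients` hinges on whether it received one gradient list per shard or a single flat list. A small sketch of that check, with placeholder gradients:

```python
# Placeholder structures, for illustration only.
flat = [("g0", "v0"), ("g1", "v1")]                          # single-model case
sharded = [[("g0", "v0"), ("g1", "v1")] for _ in range(4)]   # one list per shard

def is_sharded(grads_and_vars):
    """True when we received one (gradient, variable) list per model shard."""
    return (
        isinstance(grads_and_vars, list)
        and bool(grads_and_vars)
        and isinstance(grads_and_vars[0], list)
    )

print(is_sharded(flat), is_sharded(sharded))  # False True
```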
- - Delegates the update step to the base optimizer if it has a custom - `update_step` implementation, otherwise falls back to the parent - optimizer's logic. - - Args: - gradient: The gradient tensor. - variable: The variable to be updated. - *args: Positional arguments passed to the update function. - **kwargs: Keyword arguments passed to the update function. - """ - if hasattr(self.base_optimizer, "update_step"): - return self.base_optimizer.update_step( - gradient, variable, *args, **kwargs - ) - - return super().update_step(gradient, variable, *args, **kwargs) - - @classmethod - def from_config(cls, config): - """Creates an optimizer from its configuration. - - Args: - config: A Python dictionary, typically the output of `get_config`. - - Returns: - A `TensorParallelOptimizer` instance. - """ - from keras.src import saving - - base_optimizer_config = config.pop("base_optimizer") - base_optimizer = saving.deserialize_keras_object(base_optimizer_config) - - init_kwargs = { - "world_size": config.get("world_size"), - "distributed_backend": config.get("distributed_backend", "auto"), - "tensor_parallel_config": config.get("tensor_parallel_config"), - } - - return cls(base_optimizer=base_optimizer, **init_kwargs) - - def build(self, variables): - """Builds the optimizer and initializes sharded states. - - This method is called the first time the optimizer is used. It builds - the base optimizer and then triggers the `CoordinatedOptimizer` to - initialize its sharded states. - - Args: - variables: A list of model variables to be optimized. - """ - if self.built: - return - - self.base_optimizer.build(variables) - if variables: - iterations = self.base_optimizer.iterations - original_iterations_val = None - if iterations is not None: - original_iterations_val = ops.convert_to_numpy(iterations.value) - - zero_grads = [ops.zeros_like(v) for v in variables] - self.base_optimizer.apply_gradients(zip(zero_grads, variables)) - - if iterations is not None and original_iterations_val is not None: - iterations.assign(original_iterations_val) - - self.coordinated_optimizer.enable_optimizer_state_sharding(variables) - super().build(variables) - - def get_weights(self): - """Returns the weights of the base optimizer. - - Returns: - A list of NumPy arrays representing the optimizer's state variables. - """ - return self.coordinated_optimizer.get_weights() - - def set_weights(self, weights): - """Sets the weights of the base optimizer. - - Args: - weights: A list of NumPy arrays to set as the optimizer's state. - """ - self.coordinated_optimizer.set_weights(weights) - - @property - def variables(self): - """Returns the list of variables from the base optimizer. - - Returns: - A list of state variables of the base optimizer. - """ - return self.base_optimizer.variables - - @property - def learning_rate(self): - """Provides access to the learning rate of the base optimizer.""" - return self.base_optimizer.learning_rate - - @learning_rate.setter - def learning_rate(self, value): - """Sets the learning rate of the base optimizer.""" - self.base_optimizer.learning_rate = value - - @property - def iterations(self): - """ - Returns the training iteration count directly from the base optimizer. 
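The zero-gradient warm-up in `build()` can be reproduced in isolation with a stock Adam optimizer; this sketch only assumes two toy variables and checks that the parameters are left untouched:

```python
import keras
from keras import ops

variables = [keras.Variable([1.0, 2.0]), keras.Variable([[3.0], [4.0]])]
optimizer = keras.optimizers.Adam()
optimizer.build(variables)

# Applying all-zero gradients exercises the update path so that every
# optimizer state variable is realized, without moving the parameters.
zero_grads = [ops.zeros_like(v) for v in variables]
optimizer.apply_gradients(zip(zero_grads, variables))

print(len(optimizer.variables) > len(variables))  # True: state variables exist
print(ops.convert_to_numpy(variables[0]))         # unchanged: [1. 2.]
```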
- """ - return self.base_optimizer.iterations diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py deleted file mode 100644 index 6841e4d01a36..000000000000 --- a/keras/src/distribution/tensor_parallel/tensor_layout.py +++ /dev/null @@ -1,154 +0,0 @@ -import keras - - -class LayoutAction: - def __call__(self, tensor, rank): - """Applies the distribution action to a tensor for a specific worker. - - Args: - tensor: The input tensor to be distributed. - rank: The integer rank of the current worker/device. - - Raises: - NotImplementedError: This is an abstract method and must be - implemented by subclasses. - - Returns: - A shard or transformation of the input tensor specific to the given - rank. - """ - raise NotImplementedError - - def undo(self, tensors): - """Reverses the distribution action, reconstructing the original tensor. - - Args: - tensors: A sequence of tensor shards, one from each worker. - - Raises: - NotImplementedError: This is an abstract method and must be - implemented by subclasses. - - Returns: - The reconstructed, single tensor. - """ - raise NotImplementedError - - -class _ConcatenateMixin: - """A mixin class providing a common `undo` method via concatenation. - - This class is intended to be used as a mixin for `LayoutAction` subclasses - that can be undone by simple concatenation along a specified axis. - """ - - def undo(self, tensors): - if self.dim == -1: - dim = keras.ops.ndim(tensors[0]) - 1 - else: - dim = self.dim - return keras.ops.concatenate(tensors, axis=dim) - - -class Split(_ConcatenateMixin, LayoutAction): - """Splits a tensor into shards along a specified dimension. - - This is an internal utility used by a higher-level distribution API. - It implements sharding by slicing a tensor along one of its axes. - It handles cases where the dimension size is not perfectly divisible by the - number of workers by distributing the remainder elements one by one to the - first few workers. - - The `undo` operation is provided by the `_ConcatenateMixin`. - """ - - def __init__(self, world_size, dim, sharding_type="auto"): - """Initializes the Split action. - - Args: - world_size: The total number of workers/shards. - dim: The dimension along which to split the tensor. If -1, the - last dimension is used. - sharding_type: If `dim` is -1, this can be 'row' (dim=0) or - 'column' (dim=1) to infer the split axis for 2D tensors. - Defaults to "auto". - """ - super().__init__() - self.world_size = world_size - self.dim = dim - self.sharding_type = sharding_type - - if dim == -1 and sharding_type != "auto": - if sharding_type == "row": - self.dim = 0 - elif sharding_type == "column": - self.dim = 1 - - def __call__(self, tensor, rank): - """Splits the tensor and returns the shard corresponding to the rank. - - This method calculates the correct slice of the tensor for a given - worker rank, handling uneven distributions gracefully. - - Args: - tensor: The full tensor to be sharded. - rank: The rank of the worker for which to get the shard. - - Returns: - A tensor shard corresponding to the given rank. 
- """ - if self.dim == -1: - dim = keras.ops.ndim(tensor) - 1 - else: - dim = self.dim - - total_size = tensor.shape[dim] - split_size = total_size // self.world_size - remainder = total_size % self.world_size - - start_idx = rank * split_size + min(rank, remainder) - end_idx = start_idx + split_size + (1 if rank < remainder else 0) - - slices = [slice(None)] * keras.ops.ndim(tensor) - slices[dim] = slice(start_idx, end_idx) - return tensor[tuple(slices)] - - -class LayoutMap: - """A mapping that defines layout rules for model states and outputs. - - This is an internal configuration object used to hold layout rules for - how model variables and layer outputs should be distributed across a set - of devices. It acts as a container for `LayoutAction` instances. - - Attributes: - state_rules: A dictionary mapping variable names or patterns to - `LayoutAction` instances. - output_rules: A dictionary mapping layer output names or - patterns to `LayoutAction` instances. - """ - - def __init__(self, state_rules, output_rules): - """Initializes the LayoutMap. - - Args: - state_rules: A dictionary of distribution rules for model states. - output_rules: A dictionary of distribution rules for model outputs. - """ - self.state_rules = state_rules - self.output_rules = output_rules - - def create_collective_ops(self, devices): - """Creates the necessary collective communication operations. - - This method is a placeholder for backend-specific logic that would - translate the layout rules into actual communication primitives - (e.g., all-gather, reduce-scatter). - - Args: - devices: A sequence of device identifiers. - - Returns: - The `LayoutMap` instance itself, allowing for method chaining. - """ - return self
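To close, a round-trip sketch of `Split` followed by the concatenation `undo`, using the same index arithmetic as `__call__` on a made-up (10, 1) tensor split four ways:

```python
import numpy as np
from keras import ops

world_size = 4
tensor = ops.convert_to_tensor(np.arange(10.0).reshape(10, 1))
dim = 0
total_size = tensor.shape[dim]

split_size = total_size // world_size   # 2
remainder = total_size % world_size     # 2 -> ranks 0 and 1 get one extra row

shards = []
for rank in range(world_size):
    start = rank * split_size + min(rank, remainder)
    end = start + split_size + (1 if rank < remainder else 0)
    shards.append(tensor[start:end])

print([int(s.shape[0]) for s in shards])  # [3, 3, 2, 2]

# The mixin's undo is a plain concatenation along the split dimension.
restored = ops.concatenate(shards, axis=dim)
print(bool(ops.all(ops.equal(restored, tensor))))  # True
```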