diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py
index 33a2cc3c5160..6ef3c9059d0c 100644
--- a/keras/src/backend/common/variables.py
+++ b/keras/src/backend/common/variables.py
@@ -414,10 +414,7 @@ def _initialize(self, value):
         raise NotImplementedError
 
     def _initialize_with_initializer(self, initializer):
-        value = self._convert_to_tensor(
-            initializer(self._shape, dtype=self._dtype)
-        )
-        self._initialize(value)
+        raise NotImplementedError
 
     def _convert_to_tensor(self, value, dtype=None):
         raise NotImplementedError
diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py
index 7dc5a98fb8d5..48a18f426518 100644
--- a/keras/src/backend/jax/core.py
+++ b/keras/src/backend/jax/core.py
@@ -30,9 +30,7 @@ def __init__(self, *args, layout=None, **kwargs):
         self._layout = layout
         super().__init__(*args, **kwargs)
 
-    def _initialize(self, value):
-        # Note that variable.shape is needed by distribution_lib
-        self._shape = self._validate_shape(value.shape)
+    def set_tensor_layout(self):
         # We can't import the keras/distribution/distribution_lib
         # due to circular dependency.
         distribution = global_state.get_global_attribute("distribution")
@@ -44,8 +42,38 @@ def _initialize(self, value):
                 self._layout = tensor_layout.backend_layout
             else:
                 self._layout = tensor_layout
+
+    def _initialize(self, value):
+        # Note that variable.shape is needed by distribution_lib
+        self._shape = self._validate_shape(value.shape)
+        self.set_tensor_layout()
         self._direct_assign(value)
 
+    def check_distributed_init(self, initializer, init_layout):
+        # Layout-aware initialization requires a layout and an
+        # initializer whose __call__ accepts a `layout` argument.
+        import inspect
+
+        if init_layout is None:
+            return False
+        sig = inspect.signature(initializer.__call__)
+        if "layout" not in sig.parameters:
+            return False
+        # Only distribute when the PartitionSpec has at least one
+        # non-None entry, i.e. the tensor is actually partitioned.
+        partition_spec = getattr(init_layout, "spec", None) or ()
+        return any(dim is not None for dim in partition_spec)
+
+    def _initialize_with_initializer(self, initializer):
+        init_layout = get_initialization_layout(self.path)
+        # Use layout-aware initialization for distributed embeddings
+        if self.check_distributed_init(initializer, init_layout):
+            value = self._convert_to_tensor(
+                initializer(self._shape, dtype=self._dtype, layout=init_layout)
+            )
+        else:
+            value = self._convert_to_tensor(
+                initializer(self._shape, dtype=self._dtype)
+            )
+        self._initialize(value)
+
     def _direct_assign(self, value):
         if self._layout is not None:
             value = distribution_lib.distribute_variable(value, self._layout)
@@ -112,6 +140,12 @@ def __init__(
             # The real value is now set in self._value, sync it to raw_value
             object.__setattr__(self, "raw_value", self._value)
 
+    def _initialize_with_initializer(self, initializer):
+        value = self._convert_to_tensor(
+            initializer(self._shape, dtype=self._dtype)
+        )
+        self._initialize(value)
+
     @property
     def _value(self):
         if hasattr(self, "raw_value"):
@@ -235,6 +269,25 @@ def value(self):
 
 Variable = NnxVariable
 
 
+def get_initialization_layout(path):
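+    """Return the backend sharding configured for a variable `path`.
+
+    Illustrative usage (assumes a `LayoutMap`-based `ModelParallel`
+    distribution has been set; the names below are examples only):
+
+        layout_map = keras.distribution.LayoutMap(device_mesh)
+        layout_map["token_embedding/embeddings"] = (None, "model")
+        keras.distribution.set_distribution(
+            keras.distribution.ModelParallel(layout_map=layout_map)
+        )
+
+    Variables whose path matches an entry resolve to a
+    `jax.sharding.NamedSharding`; anything else returns `None`.
+    """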
+    distribution = global_state.get_global_attribute("distribution")
+    if distribution is None:
+        return None
+    layout_map = getattr(distribution, "_layout_map", None)
+    if layout_map is None:
+        return None
+    layout_obj = layout_map.get(path)
+    if layout_obj is None:
+        return None
+    from keras.src.distribution import TensorLayout
+
+    if isinstance(layout_obj, TensorLayout):
+        layout_obj = layout_obj.backend_layout
+    if isinstance(layout_obj, jax.sharding.NamedSharding):
+        return layout_obj
+    return None
+
+
 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
         raise ValueError("`ragged=True` is not supported with jax backend")
diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py
index 1407c008910e..921b3b08f133 100644
--- a/keras/src/backend/jax/distribution_lib.py
+++ b/keras/src/backend/jax/distribution_lib.py
@@ -9,6 +9,110 @@
 from keras.src.utils import rng_utils
 
 
+def _distribute_initializer(
+    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
+):
+    """Distribution-aware initializer for the JAX backend.
+
+    Creates a JAX random array and distributes it according to the
+    given layout, so each device materializes only its own shard.
+
+    Args:
+        init_func: A `functools.partial`-wrapped `jax.random.*`
+            function that takes the seed as its only remaining
+            argument and returns a `jax.Array`. `shape` and `dtype`
+            must already be bound via `partial`.
+        mean: Mean of the distribution (applied to `normal` and
+            `truncated_normal` only).
+        stddev: Standard deviation of the distribution.
+        seed: JAX-compatible seed array. If `None`, a seed is drawn
+            from the global seed generator.
+        layout: `TensorLayout` or `jax.sharding.NamedSharding` for the
+            distributed tensor.
+
+    Returns:
+        A distributed `jax.Array`.
+
+    Raises:
+        ValueError: If `init_func` is `None` or wraps an unsupported
+            random function. Supported `jax.random` functions:
+            `normal`, `truncated_normal`, `uniform`.
+        TypeError: If `init_func` is not a `functools.partial` object,
+            or `layout` is of an unsupported type.
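+
+    Example (illustrative; assumes `layout` is a
+    `jax.sharding.NamedSharding` built elsewhere):
+
+        from functools import partial
+
+        init_func = partial(
+            jax.random.normal, shape=(8, 4), dtype="float32"
+        )
+        value = _distribute_initializer(
+            init_func=init_func, stddev=0.05, layout=layout
+        )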
+    """
+    import warnings
+    from functools import partial
+
+    # Draw a seed from the seed generator if `seed` is not a JAX
+    # array. Seed generation must happen before jit compilation.
+    if seed is None or not isinstance(seed, jax.Array):
+        seed = seed_generator.draw_seed(None)[0]
+        seed = jax.random.key(seed)
+
+    # Validate init_func: it must be a functools.partial (checked
+    # first, so accessing `.func` below cannot fail) wrapping a
+    # supported jax.random function.
+    if init_func is None:
+        raise ValueError("init_func cannot be None.")
+    if not isinstance(init_func, partial):
+        raise TypeError(
+            f"init_func must be a functools.partial object, got "
+            f"{type(init_func)}. Wrap a jax.random.* function with "
+            "shape and dtype bound via partial."
+        )
+    if init_func.func.__name__ not in (
+        "normal",
+        "truncated_normal",
+        "uniform",
+    ):
+        raise ValueError(
+            f"Unsupported initializer: {init_func.func.__name__}. "
+            "Only JAX-compatible random initializers are supported. "
+            "Supported jax.random functions: normal, truncated_normal, "
+            "uniform."
+        )
+
+    # Shard based on the tensor layout
+    if layout is None:
+        warnings.warn(
+            "layout is None, sharding will default to a single device"
+        )
+        sharding = None
+    else:
+        if not isinstance(layout, jax.sharding.NamedSharding):
+            from keras.src.distribution import TensorLayout
+
+            if isinstance(layout, TensorLayout):
+                layout = _to_backend_layout(layout)
+            else:
+                raise TypeError(
+                    f"layout must be a Keras TensorLayout or "
+                    f"jax.sharding.NamedSharding, got {type(layout)}"
+                )
+        sharding = layout
+
+    # JAX PRNG key handling within JIT:
+    # The key is passed directly to jax.random.* functions, which are
+    # JIT-compatible and functional. JAX ensures different random
+    # values per shard when out_shardings is specified.
+    try:
+        compiled_init = jax.jit(
+            lambda seed: init_func(seed),
+            out_shardings=sharding,
+        )
+        sample = compiled_init(seed)
+    except RuntimeError as e:
+        warnings.warn(
+            f"Sharding at initialization failed due to: {e}, "
+            f"falling back to single device"
+        )
+        compiled_init = jax.jit(
+            lambda seed: init_func(seed),
+            out_shardings=None,
+        )
+        sample = compiled_init(seed)
+
+    # Apply mean/stddev only for distributions where it makes sense
+    if init_func.func in (jax.random.normal, jax.random.truncated_normal):
+        return sample * stddev + mean
+    elif init_func.func == jax.random.uniform:
+        return sample
+
+
 def list_devices(device_type=None):
     """Return all the available devices based on the device type.
diff --git a/keras/src/backend/jax/random.py b/keras/src/backend/jax/random.py
index 79901696339f..f025ddc62dd3 100644
--- a/keras/src/backend/jax/random.py
+++ b/keras/src/backend/jax/random.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import jax
 
 from keras.src.backend.config import floatx
@@ -8,24 +10,91 @@
 
 def jax_draw_seed(seed):
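+    # Normalize the three accepted seed forms: a scalar array becomes
+    # a typed PRNG key, a uint32[2] array passes through as a
+    # legacy-style key, and anything else is routed through
+    # `draw_seed`.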
" + "Supported jax.random funcs: normal, truncated_normal, uniform" + ) + + # Ensure init_func is a partial + if not isinstance(init_func, partial): + raise TypeError( + f"init_func must be functools.partial object, got {type(init_func)}" + "init_func is a jax.random.* function with shape and " + "dtype bound via partial" + ) + + # Shard based on tensor layout + if layout is None: + warnings.warn( + f"The layout is {layout}, sharding will default to single device" + ) + sharding = None + else: + sharding = _to_backend_layout(layout) + + # JAX PRNG key handling within JIT: + # The key is passed directly to jax.random.* functions which are + # JIT-compatible and functional. JAX automatically ensures different + # random values per shard when out_shardings is specified. + try: + compiled_init = jax.jit( + lambda jax_compatible_seed: init_func(jax_compatible_seed), + out_shardings=sharding, + ) + sample = compiled_init(jax_compatible_seed) + except RuntimeError as e: + warnings.warn( + f"Sharding failed due to: {e}, falling back to single device" + ) + compiled_init = jax.jit( + lambda jax_compatible_seed: init_func(jax_compatible_seed), + out_shardings=None, + ) + sample = compiled_init(jax_compatible_seed) + + # Store the SeedGenerator for state tracking + seed = seed_gen.next() + + # Apply mean/stddev only for distributions where it makes sense + if init_func.func in (jax.random.normal, jax.random.truncated_normal): + return sample * stddev + mean + elif init_func.func == jax.random.uniform: + # Uniform doesn't use mean/stddev - warn + if mean != 0.0 or stddev != 1.0: + warnings.warn( + "mean and stddev are ignored for uniform distribution" + ) + return sample diff --git a/keras/src/backend/jax/random.py b/keras/src/backend/jax/random.py index 79901696339f..f025ddc62dd3 100644 --- a/keras/src/backend/jax/random.py +++ b/keras/src/backend/jax/random.py @@ -1,3 +1,5 @@ +from functools import partial + import jax from keras.src.backend.config import floatx @@ -8,24 +10,91 @@ def jax_draw_seed(seed): if isinstance(seed, jax.Array): + if seed.ndim == 0: + return jax.random.key(seed) + elif seed.ndim == 1 and seed.shape == (2,): + return seed + else: + seed = draw_seed(seed) return seed + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None): + dtype = dtype or floatx() + seed = jax_draw_seed(seed) + if layout is not None: + from keras.src.backend import distribution_lib + + init_func = partial( + jax.random.normal, + shape=shape, + dtype=dtype, + ) + return distribution_lib._distribute_initializer( + init_func=init_func, + mean=mean, + stddev=stddev, + seed=seed, + layout=layout, + ) else: - return draw_seed(seed) + sample = jax.random.normal(seed, shape=shape, dtype=dtype) + return sample * stddev + mean -def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): +def truncated_normal( + shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None +): dtype = dtype or floatx() seed = jax_draw_seed(seed) - sample = jax.random.normal(seed, shape=shape, dtype=dtype) - return sample * stddev + mean + if layout is not None: + from keras.src.backend import distribution_lib + + init_func = partial( + jax.random.truncated_normal, + shape=shape, + dtype=dtype, + lower=-2.0, + upper=2.0, + ) + return distribution_lib._distribute_initializer( + init_func=init_func, + mean=mean, + stddev=stddev, + seed=seed, + layout=layout, + ) + else: + sample = jax.random.truncated_normal( + seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype + ) + return sample * stddev + mean -def 
+        from keras.src.backend import distribution_lib
+
+        init_func = partial(
+            jax.random.truncated_normal,
+            shape=shape,
+            dtype=dtype,
+            lower=-2.0,
+            upper=2.0,
+        )
+        return distribution_lib._distribute_initializer(
+            init_func=init_func,
+            mean=mean,
+            stddev=stddev,
+            seed=seed,
+            layout=layout,
+        )
+    else:
+        sample = jax.random.truncated_normal(
+            seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype
+        )
+        return sample * stddev + mean
 
 
-def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
+def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = jax_draw_seed(seed)
-    return jax.random.uniform(
-        seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
-    )
+    if layout is not None:
+        from keras.src.backend import distribution_lib
+
+        init_func = partial(
+            jax.random.uniform,
+            shape=shape,
+            dtype=dtype,
+            minval=minval,
+            maxval=maxval,
+        )
+        return distribution_lib._distribute_initializer(
+            init_func=init_func,
+            seed=seed,
+            layout=layout,
+        )
+    else:
+        return jax.random.uniform(
+            seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
+        )
 
 
 def categorical(logits, num_samples, dtype="int32", seed=None):
@@ -46,15 +115,6 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     )
 
 
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
-    dtype = dtype or floatx()
-    seed = jax_draw_seed(seed)
-    sample = jax.random.truncated_normal(
-        seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype
-    )
-    return sample * stddev + mean
-
-
 def _get_concrete_noise_shape(inputs, noise_shape):
     if noise_shape is None:
         return inputs.shape
diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py
index 16b2303e5e43..da0b5b4d153c 100644
--- a/keras/src/backend/numpy/core.py
+++ b/keras/src/backend/numpy/core.py
@@ -23,6 +23,12 @@ class Variable(KerasVariable):
     def _initialize(self, value):
         self._value = value
 
+    def _initialize_with_initializer(self, initializer):
+        value = self._convert_to_tensor(
+            initializer(self._shape, dtype=self._dtype)
+        )
+        self._initialize(value)
+
     def _direct_assign(self, value):
         self._value = np.array(value, dtype=self._dtype)
 
diff --git a/keras/src/backend/numpy/random.py b/keras/src/backend/numpy/random.py
index f8fd65aa38ba..28aa5c64f243 100644
--- a/keras/src/backend/numpy/random.py
+++ b/keras/src/backend/numpy/random.py
@@ -7,14 +7,14 @@
 from keras.src.random.seed_generator import make_default_seed
 
 
-def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
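+    # `layout` is accepted for API compatibility with the JAX backend;
+    # the numpy backend is single-device, so it is ignored here.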
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed)
     return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype)
 
 
-def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
+def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed)
@@ -40,7 +40,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     return output
 
 
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def truncated_normal(
+    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
+):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed)
diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py
index 93f9f5819c8b..a52eed6d9d38 100644
--- a/keras/src/backend/openvino/core.py
+++ b/keras/src/backend/openvino/core.py
@@ -572,6 +572,12 @@ def _initialize(self, value):
         )
         self._value = OpenVINOKerasTensor(value_const.output(0))
 
+    def _initialize_with_initializer(self, initializer):
+        value = self._convert_to_tensor(
+            initializer(self._shape, dtype=self._dtype)
+        )
+        self._initialize(value)
+
     def _direct_assign(self, value):
         self._value = value
 
diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py
index 38de21294677..9d40a93181b5 100644
--- a/keras/src/backend/openvino/random.py
+++ b/keras/src/backend/openvino/random.py
@@ -12,7 +12,7 @@
 from keras.src.random.seed_generator import make_default_seed
 
 
-def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed.data)
@@ -20,7 +20,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
     return OpenVINOKerasTensor(ov_opset.constant(normal_const).output(0))
 
 
-def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
+def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed_val = draw_seed(seed)
     if isinstance(seed_val, OpenVINOKerasTensor):
@@ -96,7 +96,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     )
 
 
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def truncated_normal(
+    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
+):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed.data)
diff --git a/keras/src/backend/tensorflow/random.py b/keras/src/backend/tensorflow/random.py
index e807b0de9aab..4b935faf8027 100644
--- a/keras/src/backend/tensorflow/random.py
+++ b/keras/src/backend/tensorflow/random.py
@@ -20,7 +20,7 @@ def _cast_seed(seed):
     return seed
 
 
-def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = _cast_seed(draw_seed(seed))
     return tf.random.stateless_normal(
@@ -28,7 +28,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
     )
 
 
-def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
+def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = _cast_seed(draw_seed(seed))
     return tf.random.stateless_uniform(
@@ -61,7 +61,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     return tf.cast(output, dtype)
 
 
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def truncated_normal(
+    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
+):
     dtype = dtype or floatx()
     seed = _cast_seed(draw_seed(seed))
     return tf.random.stateless_truncated_normal(
diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py
index 877dc6909ea1..a10c11dab959 100644
--- a/keras/src/backend/torch/core.py
+++ b/keras/src/backend/torch/core.py
@@ -109,6 +109,12 @@ def _initialize(self, value):
             requires_grad=self.trainable,
         ).to(get_device())
 
+    def _initialize_with_initializer(self, initializer):
+        value = self._convert_to_tensor(
+            initializer(self._shape, dtype=self._dtype)
+        )
+        self._initialize(value)
+
     def _direct_assign(self, value):
         with torch.no_grad():
             self.value.copy_(value)
diff --git a/keras/src/backend/torch/random.py b/keras/src/backend/torch/random.py
index e080731952e6..1413c1e795b3 100644
--- a/keras/src/backend/torch/random.py
+++ b/keras/src/backend/torch/random.py
@@ -25,7 +25,7 @@ def torch_seed_generator(seed):
     return generator
 
 
-def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
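+    # `layout` is accepted for cross-backend API compatibility and
+    # ignored on the torch backend.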
     dtype = dtype or floatx()
     dtype = to_torch_dtype(dtype)
     # Do not use generator during symbolic execution.
@@ -64,7 +64,7 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
     ).type(dtype)
 
 
-def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
+def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     dtype = to_torch_dtype(dtype)
     requested_shape = shape
@@ -108,7 +108,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     )
 
 
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def truncated_normal(
+    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
+):
     dtype = to_torch_dtype(dtype)
     # Take a larger standard normal dist, discard values outside 2 * stddev
     # Offset by mean and stddev
diff --git a/keras/src/initializers/random_initializers.py b/keras/src/initializers/random_initializers.py
index ad1123e2a18f..b2ee295bd2fa 100644
--- a/keras/src/initializers/random_initializers.py
+++ b/keras/src/initializers/random_initializers.py
@@ -2,19 +2,23 @@
 
 from keras.src import ops
 from keras.src.api_export import keras_export
+from keras.src.backend import backend
 from keras.src.backend import random
 from keras.src.initializers.initializer import Initializer
 from keras.src.saving import serialization_lib
+from keras.src.utils import jax_utils
 
 
 class RandomInitializer(Initializer):
     def __init__(self, seed=None):
         self._init_seed = seed
-        if seed is None:
+        if seed is None and backend() == "jax":
+            seed = jax_utils.get_jax_random_seed()
+        elif seed is None:
             seed = random.make_default_seed()
         elif isinstance(seed, dict):
             seed = serialization_lib.deserialize_keras_object(seed)
         elif not isinstance(seed, (int, random.SeedGenerator)):
             raise ValueError(
                 "`seed` argument should be an instance of "
                 "`keras.random.SeedGenerator()` or an integer. "
@@ -68,13 +72,14 @@ def __init__(self, mean=0.0, stddev=0.05, seed=None):
         self.stddev = stddev
         super().__init__(seed=seed)
 
-    def __call__(self, shape, dtype=None):
+    def __call__(self, shape, dtype=None, layout=None):
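+        # `layout` is forwarded to the backend random op; only the JAX
+        # backend currently acts on it, other backends ignore it.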
" @@ -68,13 +72,14 @@ def __init__(self, mean=0.0, stddev=0.05, seed=None): self.stddev = stddev super().__init__(seed=seed) - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, layout=None): return random.normal( shape=shape, mean=self.mean, stddev=self.stddev, seed=self.seed, dtype=dtype, + layout=layout, ) def get_config(self): @@ -127,13 +132,14 @@ def __init__(self, mean=0.0, stddev=0.05, seed=None): self.stddev = stddev super().__init__(seed=seed) - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, layout=None): return random.truncated_normal( shape=shape, mean=self.mean, stddev=self.stddev, seed=self.seed, dtype=dtype, + layout=layout, ) def get_config(self): @@ -183,13 +189,14 @@ def __init__(self, minval=-0.05, maxval=0.05, seed=None): self.maxval = maxval super().__init__(seed=seed) - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, layout=None): return random.uniform( shape=shape, minval=self.minval, maxval=self.maxval, seed=self.seed, dtype=dtype, + layout=layout, ) def get_config(self): @@ -282,7 +289,7 @@ def __init__( self.distribution = distribution super().__init__(seed=seed) - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, layout=None): scale = self.scale fan_in, fan_out = compute_fans(shape) if self.mode == "fan_in": @@ -291,20 +298,36 @@ def __call__(self, shape, dtype=None): scale /= max(1.0, fan_out) else: scale /= max(1.0, (fan_in + fan_out) / 2.0) + if self.distribution == "truncated_normal": stddev = math.sqrt(scale) / 0.87962566103423978 return random.truncated_normal( - shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed + shape, + mean=0.0, + stddev=stddev, + dtype=dtype, + seed=self.seed, + layout=layout, ) elif self.distribution == "untruncated_normal": stddev = math.sqrt(scale) return random.normal( - shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed + shape, + mean=0.0, + stddev=stddev, + dtype=dtype, + seed=self.seed, + layout=layout, ) else: limit = math.sqrt(3.0 * scale) return random.uniform( - shape, minval=-limit, maxval=limit, dtype=dtype, seed=self.seed + shape, + minval=-limit, + maxval=limit, + dtype=dtype, + seed=self.seed, + layout=layout, ) def get_config(self): diff --git a/keras/src/random/seed_generator.py b/keras/src/random/seed_generator.py index dd2adbc13bbe..d5f380e51add 100644 --- a/keras/src/random/seed_generator.py +++ b/keras/src/random/seed_generator.py @@ -118,21 +118,10 @@ def from_config(cls, config): def global_seed_generator(): if jax_utils.is_in_jax_tracing_scope(): - raise ValueError( - "[JAX RNG] When tracing a JAX function, " - "you should only use seeded random ops, e.g. " - "you should create a `SeedGenerator` instance, attach it " - "to your layer/model, and pass the instance as the `seed` " - "argument when calling random ops. Unseeded random ops " - "would get incorrectly traced by JAX and would become constant " - "after tracing. 
+    def __init__(self):
+        self._shape = (2,)
+        self._dtype = "uint32"
+
+    def next(self, ordered=False):
+        # Import jax locally so this module stays importable on
+        # non-JAX backends.
+        from jax import random
+
+        # Return a dummy key for tracing
+        return random.key(0)
diff --git a/keras/src/utils/rng_utils.py b/keras/src/utils/rng_utils.py
index dd45021d1c25..4d62264c2650 100644
--- a/keras/src/utils/rng_utils.py
+++ b/keras/src/utils/rng_utils.py
@@ -5,6 +5,7 @@
 from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
+from keras.src.random import seed_generator
 from keras.src.utils.module_utils import tensorflow as tf
 
 GLOBAL_RANDOM_SEED = "global_random_seed"
@@ -60,6 +61,10 @@ def set_random_seed(seed):
         import torch
 
         torch.manual_seed(seed)
+    if backend.backend() == "jax":
+        # We create a global seed generator using the global random seed
+        gen = seed_generator.SeedGenerator(seed)
+        global_state.set_global_attribute("global_seed_generator", gen)
 
 
 def get_random_seed():