@@ -286,24 +286,53 @@ def _distribute_initializer(
286286 Raises:
287287 ValueError: If init_func or seed is None.
288288 If init_func.func is not a supported random function.
289+ Supported jax.random func: normal, truncated_normal, uniform
289290 TypeError: If init_func is not a functools.partial object.
290291 """
291292 import warnings
292293 from functools import partial
293294
294- # Validate all required arguments
295- if seed is None :
296- raise ValueError ("seed cannot be None. Use keras.random.SeedGenerator." )
295+ # Create SeedGenerator to ensure backend variable exists
296+ # For future state tracking for distributed keys, add
297+ # attributes for base/split keys and number of devices sharded.
298+ if isinstance (seed , jax .Array ):
299+ seed_gen = seed_generator .SeedGenerator (seed = int (seed [0 ]))
300+ elif isinstance (seed , int ):
301+ seed_gen = seed_generator .SeedGenerator (seed = seed )
302+ elif isinstance (seed , seed_generator .SeedGenerator ):
303+ seed_gen = seed
304+ else :
305+ raise ValueError (
306+ f"seed must be int, JAX array, or SeedGenerator, got { type (seed )} "
307+ )
297308
298- if init_func is None :
309+ # Extract the state value as JAX array
310+ jax_seed = seed_gen .state .value
311+
312+ # Convert to JAX PRNG key format (swap counter and seed value)
313+ jax_compatible_seed = jax .numpy .array (
314+ [jax_seed [1 ], jax_seed [0 ]], dtype = jax .numpy .uint32
315+ )
316+
317+ # Validate all required arguments
318+ if init_func is None or init_func .func .__name__ not in [
319+ "normal" ,
320+ "truncated_normal" ,
321+ "uniform" ,
322+ ]:
299323 raise ValueError (
300- "init_func cannot be None. Shape and dtype info are required."
324+ "init_func cannot be None or "
325+ "Unsupported initializer: {init_func.func.__name__}."
326+ "only JAX-compatible random initializers are supported. "
327+ "Supported jax.random funcs: normal, truncated_normal, uniform"
301328 )
302329
303330 # Ensure init_func is a partial
304331 if not isinstance (init_func , partial ):
305332 raise TypeError (
306333 f"init_func must be functools.partial object, got { type (init_func )} "
334+ "init_func is a jax.random.* function with shape and "
335+ "dtype bound via partial"
307336 )
308337
309338 # Shard based on tensor layout
@@ -315,12 +344,28 @@ def _distribute_initializer(
315344 else :
316345 sharding = _to_backend_layout (layout )
317346
318- # The init_func has static arguments baked in as per initializer.
319- compiled_init = jax .jit (
320- lambda seed : init_func (seed ), out_shardings = sharding
321- )
347+ # JAX PRNG key handling within JIT:
348+ # The key is passed directly to jax.random.* functions which are
349+ # JIT-compatible and functional. JAX automatically ensures different
350+ # random values per shard when out_shardings is specified.
351+ try :
352+ compiled_init = jax .jit (
353+ lambda jax_compatible_seed : init_func (jax_compatible_seed ),
354+ out_shardings = sharding ,
355+ )
356+ sample = compiled_init (jax_compatible_seed )
357+ except RuntimeError as e :
358+ warnings .warn (
359+ f"Sharding failed due to: { e } , falling back to single device"
360+ )
361+ compiled_init = jax .jit (
362+ lambda jax_compatible_seed : init_func (jax_compatible_seed ),
363+ out_shardings = None ,
364+ )
365+ sample = compiled_init (jax_compatible_seed )
322366
323- sample = compiled_init (seed )
367+ # Store the SeedGenerator for state tracking
368+ seed = seed_gen .next ()
324369
325370 # Apply mean/stddev only for distributions where it makes sense
326371 if init_func .func in (jax .random .normal , jax .random .truncated_normal ):
@@ -332,8 +377,3 @@ def _distribute_initializer(
332377 "mean and stddev are ignored for uniform distribution"
333378 )
334379 return sample
335- else :
336- raise ValueError (
337- f"Unsupported initializer: { init_func .func .__name__ } . "
338- f"Supported: normal, truncated_normal, uniform"
339- )
0 commit comments