Skip to content

Commit 949ec1d

Browse files
Address review feedback: improve error messages and add PRNG key handling comments
1 parent 6173cdf commit 949ec1d

File tree

1 file changed

+55
-15
lines changed

1 file changed

+55
-15
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -286,24 +286,53 @@ def _distribute_initializer(
286286
Raises:
287287
ValueError: If init_func or seed is None.
288288
If init_func.func is not a supported random function.
289+
Supported jax.random func: normal, truncated_normal, uniform
289290
TypeError: If init_func is not a functools.partial object.
290291
"""
291292
import warnings
292293
from functools import partial
293294

294-
# Validate all required arguments
295-
if seed is None:
296-
raise ValueError("seed cannot be None. Use keras.random.SeedGenerator.")
295+
# Create SeedGenerator to ensure backend variable exists
296+
# For future state tracking for distributed keys, add
297+
# attributes for base/split keys and number of devices sharded.
298+
if isinstance(seed, jax.Array):
299+
seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
300+
elif isinstance(seed, int):
301+
seed_gen = seed_generator.SeedGenerator(seed=seed)
302+
elif isinstance(seed, seed_generator.SeedGenerator):
303+
seed_gen = seed
304+
else:
305+
raise ValueError(
306+
f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}"
307+
)
297308

298-
if init_func is None:
309+
# Extract the state value as JAX array
310+
jax_seed = seed_gen.state.value
311+
312+
# Convert to JAX PRNG key format (swap counter and seed value)
313+
jax_compatible_seed = jax.numpy.array(
314+
[jax_seed[1], jax_seed[0]], dtype=jax.numpy.uint32
315+
)
316+
317+
# Validate all required arguments
318+
if init_func is None or init_func.func.__name__ not in [
319+
"normal",
320+
"truncated_normal",
321+
"uniform",
322+
]:
299323
raise ValueError(
300-
"init_func cannot be None. Shape and dtype info are required."
324+
"init_func cannot be None or "
325+
"Unsupported initializer: {init_func.func.__name__}."
326+
"only JAX-compatible random initializers are supported. "
327+
"Supported jax.random funcs: normal, truncated_normal, uniform"
301328
)
302329

303330
# Ensure init_func is a partial
304331
if not isinstance(init_func, partial):
305332
raise TypeError(
306333
f"init_func must be functools.partial object, got {type(init_func)}"
334+
"init_func is a jax.random.* function with shape and "
335+
"dtype bound via partial"
307336
)
308337

309338
# Shard based on tensor layout
@@ -315,12 +344,28 @@ def _distribute_initializer(
315344
else:
316345
sharding = _to_backend_layout(layout)
317346

318-
# The init_func has static arguments baked in as per initializer.
319-
compiled_init = jax.jit(
320-
lambda seed: init_func(seed), out_shardings=sharding
321-
)
347+
# JAX PRNG key handling within JIT:
348+
# The key is passed directly to jax.random.* functions which are
349+
# JIT-compatible and functional. JAX automatically ensures different
350+
# random values per shard when out_shardings is specified.
351+
try:
352+
compiled_init = jax.jit(
353+
lambda jax_compatible_seed: init_func(jax_compatible_seed),
354+
out_shardings=sharding,
355+
)
356+
sample = compiled_init(jax_compatible_seed)
357+
except RuntimeError as e:
358+
warnings.warn(
359+
f"Sharding failed due to: {e}, falling back to single device"
360+
)
361+
compiled_init = jax.jit(
362+
lambda jax_compatible_seed: init_func(jax_compatible_seed),
363+
out_shardings=None,
364+
)
365+
sample = compiled_init(jax_compatible_seed)
322366

323-
sample = compiled_init(seed)
367+
# Store the SeedGenerator for state tracking
368+
seed = seed_gen.next()
324369

325370
# Apply mean/stddev only for distributions where it makes sense
326371
if init_func.func in (jax.random.normal, jax.random.truncated_normal):
@@ -332,8 +377,3 @@ def _distribute_initializer(
332377
"mean and stddev are ignored for uniform distribution"
333378
)
334379
return sample
335-
else:
336-
raise ValueError(
337-
f"Unsupported initializer: {init_func.func.__name__}. "
338-
f"Supported: normal, truncated_normal, uniform"
339-
)

0 commit comments

Comments
 (0)