@@ -260,3 +260,80 @@ def _to_backend_layout(tensor_layout):
     partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
     jax_mesh = tensor_layout.device_mesh.backend_mesh
     return jax.sharding.NamedSharding(jax_mesh, partition_spec)
+
+
+def _distribute_initializer(
+    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
+):
+    """
+    Distribution-aware token embedding initializer for the JAX backend.
+
+    Creates a JAX random array and distributes it according to the
+    current token embedding layout.
+
+    Args:
+        init_func: A `functools.partial` object that takes the seed as its
+            only argument and returns a `jax.Array`. Shape and dtype must
+            already be bound via `partial`.
+        mean: Mean of the distribution (applied to normal/truncated_normal).
+        stddev: Standard deviation of the distribution.
+        seed: Random seed for initialization.
+        layout: `TensorLayout` for the distributed tensor.
+
+    Returns:
+        A distributed `jax.Array`.
+
+    Raises:
+        ValueError: If `init_func` or `seed` is None, or if `init_func.func`
+            is not a supported random function.
+        TypeError: If `init_func` is not a `functools.partial` object.
+    """
+    import warnings
+    from functools import partial
+
+    # Validate all required arguments.
+    if seed is None:
+        raise ValueError("seed cannot be None. Use keras.random.SeedGenerator.")
+
+    if init_func is None:
+        raise ValueError(
+            "init_func cannot be None. Shape and dtype info are required."
+        )
+
+    # Ensure init_func is a partial so shape and dtype are already bound.
+    if not isinstance(init_func, partial):
+        raise TypeError(
+            f"init_func must be a functools.partial object, got {type(init_func)}"
+        )
+
+    # Shard based on the tensor layout.
+    if layout is None:
+        warnings.warn(
+            "layout is None; sharding will default to a single device"
+        )
+        sharding = None
+    else:
+        sharding = _to_backend_layout(layout)
+
+    # init_func has shape/dtype baked in; out_shardings shards the result.
+    compiled_init = jax.jit(
+        lambda seed: init_func(seed), out_shardings=sharding
+    )
+
+    sample = compiled_init(seed)
+
+    # Apply mean/stddev only for distributions where they make sense.
+    if init_func.func in (jax.random.normal, jax.random.truncated_normal):
+        return sample * stddev + mean
+    elif init_func.func == jax.random.uniform:
+        # Uniform does not use mean/stddev, so warn if they were set.
+        if mean != 0.0 or stddev != 1.0:
+            warnings.warn(
+                "mean and stddev are ignored for the uniform distribution"
+            )
+        return sample
+    else:
+        raise ValueError(
+            f"Unsupported initializer: {init_func.func.__name__}. "
+            f"Supported: normal, truncated_normal, uniform"
+        )
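
A minimal usage sketch of the helper above, assuming the public Keras 3 distribution API (keras.distribution.DeviceMesh / TensorLayout), a one-dimensional "model" mesh whose size divides the sharded embedding axis, and a raw JAX PRNG key passed as the seed. Shapes, axis names, and values are illustrative, not part of this change:

from functools import partial

import jax
import jax.numpy as jnp
import keras

# Build a 1D device mesh over all available devices and shard the
# embedding table's feature axis across it (names are illustrative).
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=("model",), devices=devices
)
layout = keras.distribution.TensorLayout(
    axes=(None, "model"), device_mesh=mesh
)

# Bind shape/dtype so the initializer only needs the seed.
init_func = partial(jax.random.normal, shape=(32000, 4096), dtype=jnp.float32)

embeddings = _distribute_initializer(
    init_func=init_func,
    mean=0.0,
    stddev=0.02,
    seed=jax.random.key(0),  # a raw JAX PRNG key is assumed here
    layout=layout,
)

Because the partial carries the static shape and dtype, only the PRNG key flows through jax.jit, and out_shardings materializes the table directly in its sharded layout instead of building it on a single device first.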