Skip to content

Commit 6173cdf

Browse files
Fix OOM Issue
1 parent bfde12b commit 6173cdf

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,80 @@ def _to_backend_layout(tensor_layout):
260260
partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
261261
jax_mesh = tensor_layout.device_mesh.backend_mesh
262262
return jax.sharding.NamedSharding(jax_mesh, partition_spec)
263+
264+
265+
def _distribute_initializer(
    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
):
    """Create a random `jax.Array` that is sharded as it is generated.

    The random function is wrapped in `jax.jit` with `out_shardings` set
    from `layout`, so the output is produced directly with the requested
    sharding rather than being built first and re-distributed afterwards.

    Args:
        init_func: A `functools.partial` object wrapping one of
            `jax.random.normal`, `jax.random.truncated_normal` or
            `jax.random.uniform`. It is called with `seed` as its only
            positional argument, so every other argument (shape, dtype,
            bounds, ...) must already be bound on the partial.
        mean: Offset added to the sample. Only applied for `normal` and
            `truncated_normal`.
        stddev: Factor the sample is multiplied by. Only applied for
            `normal` and `truncated_normal`.
        seed: Value passed unchanged to `init_func`. For the supported
            `jax.random` functions this must be a JAX PRNG key.
        layout: `TensorLayout` describing the desired sharding. When
            `None`, a warning is emitted and no output sharding is
            requested from `jax.jit`.

    Returns:
        The sampled `jax.Array` (scaled and shifted for the normal
        variants).

    Raises:
        ValueError: If `seed` or `init_func` is `None`, or if
            `init_func.func` is not one of the supported random functions.
        TypeError: If `init_func` is not a `functools.partial` object.
    """
    import warnings
    from functools import partial

    # Validate all required arguments.
    if seed is None:
        raise ValueError("seed cannot be None. Use keras.random.SeedGenerator.")

    if init_func is None:
        raise ValueError(
            "init_func cannot be None. Shape and dtype info are required."
        )

    # `init_func.func` is inspected below, which only exists on partials.
    if not isinstance(init_func, partial):
        raise TypeError(
            f"init_func must be functools.partial object, got {type(init_func)}"
        )

    # Reject unsupported random functions *before* compiling and running
    # anything: sampling may allocate a very large array, and there is no
    # point paying for it only to raise afterwards.
    random_fn = init_func.func
    is_normal_like = random_fn in (
        jax.random.normal,
        jax.random.truncated_normal,
    )
    if not is_normal_like and random_fn != jax.random.uniform:
        # Not every callable carries `__name__` (e.g. a nested partial);
        # fall back to `repr` so the documented ValueError is still raised.
        fn_name = getattr(random_fn, "__name__", repr(random_fn))
        raise ValueError(
            f"Unsupported initializer: {fn_name}. "
            "Supported: normal, truncated_normal, uniform"
        )

    # Shard based on tensor layout.
    if layout is None:
        warnings.warn(
            f"The layout is {layout}, sharding will default to single device"
        )
        sharding = None
    else:
        sharding = _to_backend_layout(layout)

    # The init_func has its static arguments baked in via the partial, so
    # the key is the only traced input of the jitted function.
    compiled_init = jax.jit(
        lambda key: init_func(key), out_shardings=sharding
    )
    sample = compiled_init(seed)

    # Apply mean/stddev only for distributions where it makes sense.
    if is_normal_like:
        return sample * stddev + mean

    # Uniform doesn't use mean/stddev - warn if the caller supplied them.
    if mean != 0.0 or stddev != 1.0:
        warnings.warn("mean and stddev are ignored for uniform distribution")
    return sample

0 commit comments

Comments
 (0)