Skip to content

Commit 66a47a1

Browse files
Updates post feedback
1 parent 949ec1d commit 66a47a1

File tree

10 files changed

+272
-50
lines changed

10 files changed

+272
-50
lines changed

keras/src/backend/common/variables.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,7 @@ def _initialize(self, value):
414414
raise NotImplementedError
415415

416416
def _initialize_with_initializer(self, initializer):
417-
value = self._convert_to_tensor(
418-
initializer(self._shape, dtype=self._dtype)
419-
)
420-
self._initialize(value)
417+
raise NotImplementedError
421418

422419
def _convert_to_tensor(self, value, dtype=None):
423420
raise NotImplementedError

keras/src/backend/jax/core.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,51 @@ def __init__(self, *args, layout=None, **kwargs):
3030
self._layout = layout
3131
super().__init__(*args, **kwargs)
3232

33-
def _initialize(self, value):
34-
# Note that variable.shape is needed by distribution_lib
35-
self._shape = self._validate_shape(value.shape)
33+
def set_tensor_layout(self):
3634
# We can't import the keras/distribution/distribution_lib
3735
# due to circular dependency.
38-
distribution = global_state.get_global_attribute("distribution")
39-
if self._layout is None and distribution is not None:
40-
tensor_layout = distribution.get_variable_layout(self)
41-
from keras.src.distribution import TensorLayout
36+
if self._layout is None:
37+
distribution = global_state.get_global_attribute("distribution")
38+
if distribution is not None:
39+
tensor_layout = distribution.get_variable_layout(self)
40+
from keras.src.distribution import TensorLayout
41+
42+
if isinstance(tensor_layout, TensorLayout):
43+
self._layout = tensor_layout.backend_layout
44+
else:
45+
self._layout = tensor_layout
4246

43-
if isinstance(tensor_layout, TensorLayout):
44-
self._layout = tensor_layout.backend_layout
45-
else:
46-
self._layout = tensor_layout
47+
def _initialize(self, value):
48+
# Note that variable.shape is needed by distribution_lib
49+
self._shape = self._validate_shape(value.shape)
50+
self.set_tensor_layout()
4751
self._direct_assign(value)
4852

53+
def check_distributed_init(self, initializer):
54+
# Check if 'layout' parameter is supported in the initializer call
55+
import inspect
56+
57+
sig = inspect.signature(initializer.__call__)
58+
layout_supported = "layout" in sig.parameters
59+
# Check if PartitionSpec has any non-None values
60+
spec = getattr(self._layout, "spec", None)
61+
partition_spec = spec if spec is not None else ()
62+
is_partitioned = any(dim is not None for dim in partition_spec)
63+
return layout_supported and is_partitioned
64+
65+
def _initialize_with_initializer(self, initializer):
66+
self.set_tensor_layout()
67+
# Use layout-aware initialization for distributed embeddings
68+
if self.check_distributed_init(initializer):
69+
value = self._convert_to_tensor(
70+
initializer(self._shape, dtype=self._dtype, layout=self._layout)
71+
)
72+
else:
73+
value = self._convert_to_tensor(
74+
initializer(self._shape, dtype=self._dtype)
75+
)
76+
self._initialize(value)
77+
4978
def _direct_assign(self, value):
5079
if self._layout is not None:
5180
value = distribution_lib.distribute_variable(value, self._layout)

keras/src/backend/jax/distribution_lib.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,110 @@
99
from keras.src.utils import rng_utils
1010

1111

12+
def _distribute_initializer(
13+
init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
14+
):
15+
"""
16+
Distribution-aware initializer for JAX backend.
17+
This function will create a Jax random array and
18+
distribute it according to the current layout.
19+
Args:
20+
init_func: A functools.partial-wrapped object that takes the seed
21+
as argument and returns a jax.Array. Must have shape and dtype
22+
already bound via partial.
23+
mean: Mean of distribution (applied to normal/truncated_normal).
24+
stddev: Standard deviation of the distribution.
25+
seed: JAX compatible seed array, if None use the Seed generator.
26+
layout: TensorLayout for the distributed tensor.
27+
Returns:
28+
A distributed jax array.
29+
Raises:
30+
ValueError: If init_func or seed is None.
31+
If init_func.func is not a supported random function.
32+
Supported jax.random func: normal, truncated_normal, uniform
33+
TypeError: If init_func is not a functools.partial object
34+
or seed is not a Jax array.
35+
36+
"""
37+
import warnings
38+
from functools import partial
39+
40+
# Draw seed from the seed generator if seed is not a Jax Array
41+
if seed is None or not isinstance(seed, jax.Array):
42+
jax_compatible_seed = seed_generator.draw_seed(None)
43+
# Convert to JAX PRNG key format (swap counter and seed value)
44+
seed = jax_compatible_seed[::-1]
45+
46+
# Validate all required arguments
47+
if init_func is None or init_func.func.__name__ not in [
48+
"normal",
49+
"truncated_normal",
50+
"uniform",
51+
]:
52+
raise ValueError(
53+
"init_func cannot be None or "
54+
"Unsupported initializer: {init_func.func.__name__}."
55+
"only JAX-compatible random initializers are supported. "
56+
"Supported jax.random funcs: normal, truncated_normal, uniform"
57+
)
58+
59+
# Ensure init_func is a partial
60+
if not isinstance(init_func, partial):
61+
raise TypeError(
62+
f"init_func must be functools.partial object, got {type(init_func)}"
63+
"init_func is a jax.random.* function with shape and "
64+
"dtype bound via partial"
65+
)
66+
67+
# Shard based on tensor layout
68+
if layout is None:
69+
warnings.warn(
70+
f"The layout is {layout}, sharding will default to single device"
71+
)
72+
73+
sharding = None
74+
else:
75+
if not isinstance(layout, jax.sharding.NamedSharding):
76+
from keras.src.distribution import TensorLayout
77+
78+
if isinstance(layout, TensorLayout):
79+
layout = _to_backend_layout(layout)
80+
else:
81+
raise TypeError(
82+
f"layout must be Keras TensorLayout or "
83+
f"jax.sharding.NamedSharding, got {type(layout)}"
84+
)
85+
sharding = layout
86+
87+
# JAX PRNG key handling within JIT:
88+
# The key is passed directly to jax.random.* functions which are
89+
# JIT-compatible and functional. JAX automatically ensures different
90+
# random values per shard when out_shardings is specified.
91+
try:
92+
compiled_init = jax.jit(
93+
lambda seed: init_func(seed),
94+
out_shardings=sharding,
95+
)
96+
sample = compiled_init(seed)
97+
98+
except RuntimeError as e:
99+
warnings.warn(
100+
f"Sharding at initialization failed due to: {e}, "
101+
f"falling back to single device"
102+
)
103+
compiled_init = jax.jit(
104+
lambda seed: init_func(seed),
105+
out_shardings=None,
106+
)
107+
sample = compiled_init(seed)
108+
109+
# Apply mean/stddev only for distributions where it makes sense
110+
if init_func.func in (jax.random.normal, jax.random.truncated_normal):
111+
return sample * stddev + mean
112+
elif init_func.func == jax.random.uniform:
113+
return sample
114+
115+
12116
def list_devices(device_type=None):
13117
"""Return all the available devices based on the device type.
14118

keras/src/backend/jax/random.py

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import jax
24

35
from keras.src.backend.config import floatx
@@ -7,25 +9,61 @@
79

810

911
def jax_draw_seed(seed):
12+
# Convert to JAX PRNG key format (swap counter and seed value)
1013
if isinstance(seed, jax.Array):
11-
return seed
14+
return seed[::-1]
1215
else:
13-
return draw_seed(seed)
16+
seed_array = draw_seed(seed)
17+
return seed_array[::-1]
1418

1519

16-
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
20+
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
1721
dtype = dtype or floatx()
1822
seed = jax_draw_seed(seed)
19-
sample = jax.random.normal(seed, shape=shape, dtype=dtype)
20-
return sample * stddev + mean
23+
if layout is not None:
24+
from keras.src.backend import distribution_lib
25+
26+
init_func = partial(
27+
jax.random.normal,
28+
shape=shape,
29+
dtype=dtype,
30+
)
31+
return distribution_lib._distribute_initializer(
32+
init_func=init_func,
33+
mean=mean,
34+
stddev=stddev,
35+
seed=seed,
36+
layout=layout,
37+
)
38+
else:
39+
sample = jax.random.normal(seed, shape=shape, dtype=dtype)
40+
return sample * stddev + mean
2141

2242

23-
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
43+
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
2444
dtype = dtype or floatx()
2545
seed = jax_draw_seed(seed)
26-
return jax.random.uniform(
27-
seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
28-
)
46+
if layout is not None:
47+
from keras.src.backend import distribution_lib
48+
49+
init_func = partial(
50+
jax.random.uniform,
51+
shape=shape,
52+
dtype=dtype,
53+
minval=minval,
54+
maxval=maxval,
55+
)
56+
return distribution_lib._distribute_initializer(
57+
init_func=init_func,
58+
mean=None,
59+
stddev=None,
60+
seed=seed,
61+
layout=layout,
62+
)
63+
else:
64+
return jax.random.uniform(
65+
seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
66+
)
2967

3068

3169
def categorical(logits, num_samples, dtype="int32", seed=None):
@@ -46,13 +84,33 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
4684
)
4785

4886

49-
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
87+
def truncated_normal(
88+
shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
89+
):
5090
dtype = dtype or floatx()
5191
seed = jax_draw_seed(seed)
52-
sample = jax.random.truncated_normal(
53-
seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype
54-
)
55-
return sample * stddev + mean
92+
if layout is not None:
93+
from keras.src.backend import distribution_lib
94+
95+
init_func = partial(
96+
jax.random.truncated_normal,
97+
shape=shape,
98+
dtype=dtype,
99+
lower=-2.0,
100+
upper=2.0,
101+
)
102+
return distribution_lib._distribute_initializer(
103+
init_func=init_func,
104+
mean=mean,
105+
stddev=stddev,
106+
seed=seed,
107+
layout=layout,
108+
)
109+
else:
110+
sample = jax.random.truncated_normal(
111+
seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype
112+
)
113+
return sample * stddev + mean
56114

57115

58116
def _get_concrete_noise_shape(inputs, noise_shape):

keras/src/backend/numpy/random.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
from keras.src.random.seed_generator import make_default_seed
88

99

10-
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
10+
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
1111
dtype = dtype or floatx()
1212
seed = draw_seed(seed)
1313
rng = np.random.default_rng(seed)
1414
return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype)
1515

1616

17-
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
17+
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
1818
dtype = dtype or floatx()
1919
seed = draw_seed(seed)
2020
rng = np.random.default_rng(seed)
@@ -40,7 +40,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
4040
return output
4141

4242

43-
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
43+
def truncated_normal(
44+
shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
45+
):
4446
dtype = dtype or floatx()
4547
seed = draw_seed(seed)
4648
rng = np.random.default_rng(seed)

keras/src/backend/openvino/random.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
from keras.src.random.seed_generator import make_default_seed
1313

1414

15-
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
15+
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
1616
dtype = dtype or floatx()
1717
seed = draw_seed(seed)
1818
rng = np.random.default_rng(seed.data)
1919
normal_const = rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype)
2020
return OpenVINOKerasTensor(ov_opset.constant(normal_const).output(0))
2121

2222

23-
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
23+
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
2424
dtype = dtype or floatx()
2525
seed_val = draw_seed(seed)
2626
if isinstance(seed_val, OpenVINOKerasTensor):
@@ -96,7 +96,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
9696
)
9797

9898

99-
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
99+
def truncated_normal(
100+
shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
101+
):
100102
dtype = dtype or floatx()
101103
seed = draw_seed(seed)
102104
rng = np.random.default_rng(seed.data)

keras/src/backend/tensorflow/random.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ def _cast_seed(seed):
2020
return seed
2121

2222

23-
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
23+
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
2424
dtype = dtype or floatx()
2525
seed = _cast_seed(draw_seed(seed))
2626
return tf.random.stateless_normal(
2727
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
2828
)
2929

3030

31-
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
31+
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
3232
dtype = dtype or floatx()
3333
seed = _cast_seed(draw_seed(seed))
3434
return tf.random.stateless_uniform(
@@ -61,7 +61,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
6161
return tf.cast(output, dtype)
6262

6363

64-
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
64+
def truncated_normal(
65+
shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
66+
):
6567
dtype = dtype or floatx()
6668
seed = _cast_seed(draw_seed(seed))
6769
return tf.random.stateless_truncated_normal(

0 commit comments

Comments
 (0)