|
9 | 9 | from keras.src.utils import rng_utils |
10 | 10 |
|
11 | 11 |
|
| 12 | +def _distribute_initializer( |
| 13 | + init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None |
| 14 | +): |
| 15 | + """ |
| 16 | + Distribution-aware initializer for JAX backend. |
| 17 | + This function will create a Jax random array and |
| 18 | + distribute it according to the current layout. |
| 19 | + Args: |
| 20 | + init_func: A functools.partial-wrapped object that takes the seed |
| 21 | + as argument and returns a jax.Array. Must have shape and dtype |
| 22 | + already bound via partial. |
| 23 | + mean: Mean of distribution (applied to normal/truncated_normal). |
| 24 | + stddev: Standard deviation of the distribution. |
| 25 | + seed: JAX compatible seed array, if None use the Seed generator. |
| 26 | + layout: TensorLayout for the distributed tensor. |
| 27 | + Returns: |
| 28 | + A distributed jax array. |
| 29 | + Raises: |
| 30 | + ValueError: If init_func or seed is None. |
| 31 | + If init_func.func is not a supported random function. |
| 32 | + Supported jax.random func: normal, truncated_normal, uniform |
| 33 | + TypeError: If init_func is not a functools.partial object |
| 34 | + or seed is not a Jax array. |
| 35 | +
|
| 36 | + """ |
| 37 | + import warnings |
| 38 | + from functools import partial |
| 39 | + |
| 40 | + # Draw seed from the seed generator if seed is not a Jax Array |
| 41 | + if seed is None or not isinstance(seed, jax.Array): |
| 42 | + jax_compatible_seed = seed_generator.draw_seed(None) |
| 43 | + # Convert to JAX PRNG key format (swap counter and seed value) |
| 44 | + seed = jax_compatible_seed[::-1] |
| 45 | + |
| 46 | + # Validate all required arguments |
| 47 | + if init_func is None or init_func.func.__name__ not in [ |
| 48 | + "normal", |
| 49 | + "truncated_normal", |
| 50 | + "uniform", |
| 51 | + ]: |
| 52 | + raise ValueError( |
| 53 | + "init_func cannot be None or " |
| 54 | + "Unsupported initializer: {init_func.func.__name__}." |
| 55 | + "only JAX-compatible random initializers are supported. " |
| 56 | + "Supported jax.random funcs: normal, truncated_normal, uniform" |
| 57 | + ) |
| 58 | + |
| 59 | + # Ensure init_func is a partial |
| 60 | + if not isinstance(init_func, partial): |
| 61 | + raise TypeError( |
| 62 | + f"init_func must be functools.partial object, got {type(init_func)}" |
| 63 | + "init_func is a jax.random.* function with shape and " |
| 64 | + "dtype bound via partial" |
| 65 | + ) |
| 66 | + |
| 67 | + # Shard based on tensor layout |
| 68 | + if layout is None: |
| 69 | + warnings.warn( |
| 70 | + f"The layout is {layout}, sharding will default to single device" |
| 71 | + ) |
| 72 | + |
| 73 | + sharding = None |
| 74 | + else: |
| 75 | + if not isinstance(layout, jax.sharding.NamedSharding): |
| 76 | + from keras.src.distribution import TensorLayout |
| 77 | + |
| 78 | + if isinstance(layout, TensorLayout): |
| 79 | + layout = _to_backend_layout(layout) |
| 80 | + else: |
| 81 | + raise TypeError( |
| 82 | + f"layout must be Keras TensorLayout or " |
| 83 | + f"jax.sharding.NamedSharding, got {type(layout)}" |
| 84 | + ) |
| 85 | + sharding = layout |
| 86 | + |
| 87 | + # JAX PRNG key handling within JIT: |
| 88 | + # The key is passed directly to jax.random.* functions which are |
| 89 | + # JIT-compatible and functional. JAX automatically ensures different |
| 90 | + # random values per shard when out_shardings is specified. |
| 91 | + try: |
| 92 | + compiled_init = jax.jit( |
| 93 | + lambda seed: init_func(seed), |
| 94 | + out_shardings=sharding, |
| 95 | + ) |
| 96 | + sample = compiled_init(seed) |
| 97 | + |
| 98 | + except RuntimeError as e: |
| 99 | + warnings.warn( |
| 100 | + f"Sharding at initialization failed due to: {e}, " |
| 101 | + f"falling back to single device" |
| 102 | + ) |
| 103 | + compiled_init = jax.jit( |
| 104 | + lambda seed: init_func(seed), |
| 105 | + out_shardings=None, |
| 106 | + ) |
| 107 | + sample = compiled_init(seed) |
| 108 | + |
| 109 | + # Apply mean/stddev only for distributions where it makes sense |
| 110 | + if init_func.func in (jax.random.normal, jax.random.truncated_normal): |
| 111 | + return sample * stddev + mean |
| 112 | + elif init_func.func == jax.random.uniform: |
| 113 | + return sample |
| 114 | + |
| 115 | + |
12 | 116 | def list_devices(device_type=None): |
13 | 117 | """Return all the available devices based on the device type. |
14 | 118 |
|
|
0 commit comments