diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index 9e6c928e3ee..504627f6b52 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -1145,7 +1145,10 @@ def compute_output_spec(self, *args, **kwargs):
                 call_spec=call_spec,
                 class_name=self.__class__.__name__,
             )
-            output_shape = self.compute_output_shape(**shapes_dict)
+            try:
+                output_shape = self.compute_output_shape(**shapes_dict)
+            except NotImplementedError:
+                return super().compute_output_spec(*args, **kwargs)
 
             if (
                 isinstance(output_shape, list)
diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py
index a02af992778..b117667ff81 100644
--- a/keras/src/utils/jax_layer.py
+++ b/keras/src/utils/jax_layer.py
@@ -1,8 +1,13 @@
 import inspect
+import collections
+import functools
+import itertools
 
 import numpy as np
+import string
 
 from keras.src import backend
+from keras.src import random
 from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.backend.common.variables import is_float_dtype
@@ -11,7 +16,18 @@ from keras.src.saving import serialization_lib
 from keras.src.utils import jax_utils
 from keras.src.utils import tracking
+from keras.src import ops
 from keras.src.utils.module_utils import jax
+from keras.src.utils.module_utils import tensorflow as tf
+
+
+def standardize_pytree_collections(pytree):
+    if isinstance(pytree, collections.abc.Mapping):
+        return {k: standardize_pytree_collections(v) for k, v in pytree.items()}
+    elif isinstance(pytree, collections.abc.Sequence):
+        return [standardize_pytree_collections(v) for v in pytree]
+    else:
+        return pytree
 
 
 @keras_export("keras.layers.JaxLayer")
@@ -196,6 +212,9 @@ def my_haiku_module_fn(inputs, training):
         init_fn: the function to call to initialize the model. See description
             above for the list of arguments it takes and the outputs it
             returns. If `None`, then `params` and/or `state` must be provided.
+        compute_output_shape_fn: Function that takes the input shape
+            (a tuple or nested structure of tuples) and returns the output
+            shape (a tuple or nested structure of tuples).
         params: A `PyTree` containing all the model trainable parameters. This
             allows passing trained parameters or controlling the
             initialization. If both `params` and `state` are `None`, `init_fn`
             is called at
@@ -214,15 +233,16 @@ def __init__(
         self,
         call_fn,
         init_fn=None,
+        compute_output_shape_fn=None,
         params=None,
         state=None,
         seed=None,
         **kwargs,
     ):
-        if backend.backend() != "jax":
+        if backend.backend() not in ["jax", "tensorflow"]:
             raise ValueError(
-                "JaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
+                "JaxLayer is only supported with the JAX or TensorFlow "
+                f"backend. Current backend: {backend.backend()}"
             )
 
         if init_fn is None and params is None and state is None:
@@ -233,7 +253,10 @@ def __init__(
         super().__init__(**kwargs)
         self.call_fn = call_fn
         self.init_fn = init_fn
-        self.seed_generator = backend.random.SeedGenerator(seed)
+        self.compute_output_shape_fn = compute_output_shape_fn
+        if seed is None:
+            seed = random.seed_generator.make_default_seed()
+        self.jax_rng = jax.random.PRNGKey(seed)
         self.tracked_params = self._create_variables(params, trainable=True)
         self.tracked_state = self._create_variables(state, trainable=False)
         if self.params is not None or self.state is not None:
@@ -252,6 +275,10 @@ def __init__(
             init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
         )
 
+        # Attributes for jax2tf functions
+        self.jax2tf_training_false_fn = None
+        self.jax2tf_training_true_fn = None
+
     def _validate_signature(self, fn, fn_name, allowed, required):
         fn_parameters = inspect.signature(fn).parameters
         for parameter_name in required:
@@ -272,6 +299,79 @@ def _validate_signature(self, fn, fn_name, allowed, required):
 
         return parameter_names
 
+    def _get_jax2tf_input_shape(self, input_shape):
+        """Convert the input shape to a format suitable for `jax2tf`.
+
+        `jax2tf` expects a letter for each unknown dimension, which allows
+        correlated dimensions. Since correlated dimensions are not supported
+        by Keras, we simply use 'a', 'b', 'c'... for each unknown dimension.
+        However, we use 'batch' for dimension 0 when it is not defined, to
+        correlate the batch size across inputs.
+
+        Example (spaces added for readability):
+        ```
+        input_shape:  (None,  4, None, None, 5)
+        result:      "(batch, 4, a,    b,    5)"
+        ```
+
+        Args:
+            input_shape: a single shape or a structure of shapes for the
+                inputs.
+
+        Returns:
+            the shape or shapes structure in the `jax2tf` format as strings.
+        """
+        dim_names = itertools.chain(
+            string.ascii_lowercase,  # a, b, ... z
+            itertools.starmap(  # aa, ab, ... az, ba, bb, ... zz
+                lambda a, b: a + b,
+                itertools.product(string.ascii_lowercase, repeat=2),
+            ),
+        )
+
+        def get_single_jax2tf_shape(shape):
+            jax2tf_shape = []
+
+            for index, dim in enumerate(shape):
+                if dim is not None:
+                    jax2tf_shape.append(str(dim))
+                elif index == 0:
+                    jax2tf_shape.append("batch")
+                else:
+                    jax2tf_shape.append(next(dim_names))
+
+            return "(" + ", ".join(jax2tf_shape) + ")"
+
+        res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
+        return res
+
+    def _jax2tf_convert(self, fn, polymorphic_shapes):
+        from jax.experimental import jax2tf
+
+        converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
+        # Autograph won't work with the output of jax2tf.
+        converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
+        return converted_fn
+
+    def _partial_with_positional(self, fn, index, value):
+        """Return a new partial with one positional argument set to a value.
+
+        This is needed because `jax2tf` only supports positional arguments and
+        `functools.partial` only supports setting positional arguments starting
+        from the left. Our use case is the `training` argument, which is
+        typically the rightmost argument.
+
+        Args:
+            fn: the function to wrap.
+            index: the index of the positional argument to set to `value`.
+            value: the value for the positional argument at `index`.
+ """ + + @functools.wraps(fn) + def wrapper(*args): + args = args[0:index] + (value,) + args[index:] + return fn(*args) + + return wrapper + @tracking.no_automatic_dependency_tracking def _create_variables(self, values, trainable): """Create a structure of variables from a structure of JAX arrays. @@ -296,14 +396,18 @@ def _create_variables(self, values, trainable): def create_variable(value): if backend.is_tensor(value) or isinstance( - value, (np.ndarray, np.generic) + value, (np.ndarray, np.generic, jax.Array) ): dtype = value.dtype if is_float_dtype(dtype): dtype = None # Use the layer dtype policy return self.add_weight( value.shape, - initializer=value, + initializer=( + backend.convert_to_tensor(value) + if value is not None + else None + ), dtype=dtype, trainable=trainable, ) @@ -328,8 +432,15 @@ def create_variable(value): else: self.state = variables - flat_variables, _ = jax.tree_util.tree_flatten(variables) - return flat_variables + if backend.backend() == "jax": + flat_variables, _ = jax.tree_util.tree_flatten(variables) + return flat_variables + elif backend.backend() == "tensorflow": + return variables + + def _split_jax_rng(self): + self.jax_rng, subkey = jax.random.split(self.jax_rng) + return subkey def _get_init_rng(self): """ @@ -343,7 +454,7 @@ def _get_init_rng(self): a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as the `rng` argument of `init_fn`. """ - return self.seed_generator.next() + return self._split_jax_rng() def _get_call_rng(self, training): """ @@ -359,23 +470,23 @@ def _get_call_rng(self, training): the `rng` argument of `call_fn`. """ if training: - return self.seed_generator.next() + return self._split_jax_rng() else: return None - def build(self, input_shape): - if self.params is not None or self.state is not None: - return - - if jax_utils.is_in_jax_tracing_scope(): + def _initialize_weights(self, input_shape): + if jax_utils.is_in_jax_tracing_scope() or tf.inside_function(): # This exception is not actually shown, it is caught and a detailed # warning about calling 'build' is printed. - raise ValueError("'JaxLayer' cannot be built in tracing scope") + raise ValueError( + "'JaxLayer' cannot be built in tracing scope" + "or inside tf function" + ) # Initialize `params` and `state` if needed by calling `init_fn`. 
         def create_input(shape):
             shape = [d if d is not None else 1 for d in shape]
-            return jax.numpy.ones(shape)
+            return ops.ones(shape)
 
         init_inputs = tree.map_shape_structure(create_input, input_shape)
         init_args = []
@@ -398,6 +509,45 @@ def create_input(shape):
             )
         self.tracked_state = self._create_variables(init_state, trainable=False)
 
+    def build(self, input_shape):
+        if self.params is None and self.state is None:
+            self._initialize_weights(input_shape)
+
+        if backend.backend() == "tensorflow":
+            polymorphic_shapes = []
+            for argument in self.call_fn_arguments:
+                if argument == "inputs":
+                    polymorphic_shapes.append(
+                        self._get_jax2tf_input_shape(input_shape)
+                    )
+                elif argument != "training":
+                    # params, state, rng
+                    polymorphic_shapes.append("...")
+
+            if "training" in self.call_fn_arguments:
+                training_argument_index = self.call_fn_arguments.index(
+                    "training"
+                )
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, False
+                    ),
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, True
+                    ),
+                    polymorphic_shapes,
+                )
+            else:
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self.call_fn,
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = None
+        super().build(input_shape)
+
     def call(self, inputs, training=False):
         def unwrap_variable(variable):
             return None if variable is None else variable.value
@@ -417,7 +567,8 @@ def unwrap_variable(variable):
             elif argument_name == "inputs":
                 call_args.append(inputs)
             elif argument_name == "training":
-                call_args.append(training)
+                if backend.backend() == "jax":
+                    call_args.append(training)
 
         def assign_state_to_variable(value, variable):
             # This exists only to make debugging this error case easier.
@@ -429,14 +580,39 @@ def assign_state_to_variable(value, variable):
                 )
             variable.assign(value)
 
-        if self.has_state:
-            predictions, new_state = self.call_fn(*call_args)
-            jax.tree_util.tree_map(
-                assign_state_to_variable, new_state, self.state
-            )
-            return predictions
+        def call_with_fn(fn):
+            if self.has_state:
+                predictions, new_state = fn(*call_args)
+                if backend.backend() == "jax":
+                    jax.tree_util.tree_map(
+                        assign_state_to_variable, new_state, self.state
+                    )
+                elif backend.backend() == "tensorflow":
+                    jax.tree_util.tree_map(
+                        assign_state_to_variable,
+                        standardize_pytree_collections(new_state),
+                        standardize_pytree_collections(self.state),
+                    )
+                return predictions
+            else:
+                return fn(*call_args)
+
+        if backend.backend() == "jax":
+            return call_with_fn(self.call_fn)
+        elif backend.backend() == "tensorflow":
+            if self.jax2tf_training_true_fn is None:
+                return call_with_fn(self.jax2tf_training_false_fn)
+            else:
+                if training:
+                    return call_with_fn(self.jax2tf_training_true_fn)
+                else:
+                    return call_with_fn(self.jax2tf_training_false_fn)
+
+    def compute_output_shape(self, input_shape):
+        if self.compute_output_shape_fn:
+            return self.compute_output_shape_fn(input_shape)
         else:
-            return self.call_fn(*call_args)
+            return super().compute_output_shape(input_shape)
 
     def get_config(self):
         config = {
@@ -549,6 +725,7 @@ def my_flax_module_wrapper(module, inputs, training):
     def __init__(
         self,
        module,
+        compute_output_shape_fn=None,
         method=None,
         variables=None,
         **kwargs,
     ):
@@ -556,12 +733,6 @@ def __init__(
         # Late import to only require Flax when this is used.
         from flax.core import scope as flax_scope
 
-        if backend.backend() != "jax":
-            raise ValueError(
-                "FlaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
-            )
-
         self.module = module
         self.method = method
@@ -618,6 +789,7 @@ def init_without_training(rng, inputs):
         super().__init__(
             call_fn=call_fn,
             init_fn=init_fn,
+            compute_output_shape_fn=compute_output_shape_fn,
             params=params,
             state=state,
             **kwargs,
         )
@@ -650,13 +822,13 @@ def _variables_to_params_and_state(self, variables):
 
     def _get_init_rng(self):
         return {
-            "params": self.seed_generator.next(),
-            "dropout": self.seed_generator.next(),
+            "params": self._split_jax_rng(),
+            "dropout": self._split_jax_rng(),
         }
 
     def _get_call_rng(self, training):
         if training:
-            return {"dropout": self.seed_generator.next()}
+            return {"dropout": self._split_jax_rng()}
         else:
             return {}
diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py
index 009ecd402e5..52fee536c65 100644
--- a/keras/src/utils/jax_layer_test.py
+++ b/keras/src/utils/jax_layer_test.py
@@ -19,6 +19,8 @@
 from keras.src.saving import object_registration
 from keras.src.utils.jax_layer import FlaxLayer
 from keras.src.utils.jax_layer import JaxLayer
+from keras.src import ops
+from keras.src import random
 
 try:
     import flax
@@ -69,6 +71,11 @@ def jax_stateful_apply(params, state, inputs, training):
     return outputs, state
 
 
+@object_registration.register_keras_serializable()
+def stateless_compute_output_shape(input_shape):
+    return (input_shape[0], num_classes)
+
+
 if flax is not None:
 
     @object_registration.register_keras_serializable()
@@ -179,8 +186,8 @@ def from_config(cls, config):
 
 
 @pytest.mark.skipif(
-    backend.backend() != "jax",
-    reason="JaxLayer and FlaxLayer are only supported with JAX backend",
+    backend.backend() not in ["jax", "tensorflow"],
+    reason="JaxLayer and FlaxLayer are only supported with JAX and TF backend",
 )
 class TestJaxLayer(testing.TestCase):
     def _test_layer(
@@ -194,16 +201,18 @@ def _test_layer(
         non_trainable_weights,
         non_trainable_params,
     ):
         # Fake MNIST data
-        x_train = np.random.uniform(size=(320, 28, 28, 1))
-        y_train = np.eye(num_classes, dtype="int32")[
-            (np.random.uniform(size=(320,)) * num_classes).astype("int32")
-        ]
-        x_test = np.random.uniform(size=(32, 28, 28, 1))
+        x_train = random.uniform(shape=(320, 28, 28, 1))
+        y_train_indices = ops.cast(
+            ops.random.uniform(shape=(320,), minval=0, maxval=num_classes),
+            dtype="int32",
+        )
+        y_train = ops.one_hot(y_train_indices, num_classes, dtype="int32")
+        x_test = random.uniform(shape=(32, 28, 28, 1))
 
         def _count_params(weights):
             count = 0
             for weight in weights:
-                count = count + np.prod(weight.shape)
+                count = count + ops.prod(ops.shape(weight))
             return count
 
         def verify_weights_and_params(layer):
@@ -257,7 +266,7 @@ def verify_weights_and_params(layer):
         for before, after in zip(ntw1_before_fit, ntw1_after_fit):
             self.assertNotAllClose(before, after)
 
-        expected_ouput_shape = (x_test.shape[0], num_classes)
+        expected_ouput_shape = (ops.shape(x_test)[0], num_classes)
         output1 = model1(x_test)
         self.assertEqual(output1.shape, expected_ouput_shape)
         predict1 = model1.predict(x_test, steps=1)
@@ -478,7 +487,12 @@ def create_wrapper(**kwargs):
             flax_model = flax_model_class()
             if flax_model_method:
                 kwargs["method"] = getattr(flax_model, flax_model_method)
-            return FlaxLayer(flax_model_class(), **kwargs)
+            if backend.backend() == "jax":
+                return FlaxLayer(flax_model_class(), **kwargs)
+            elif backend.backend() == "tensorflow":
+                return FlaxLayer(
+                    flax_model, stateless_compute_output_shape, **kwargs
+                )
 
         self._test_layer(
             flax_model_class.__name__,
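Below is a usage sketch (not part of the diff) of the new `compute_output_shape_fn` argument with Keras running on the TensorFlow backend, the path that exercises the `jax2tf` conversion added above. The wrapped functions `jax_dense_fn` and `jax_dense_init_fn`, and the output width of 10, are illustrative assumptions rather than code from this PR.

```python
# Illustrative sketch only; the wrapped functions are hypothetical.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # set before importing keras

import jax
import jax.numpy as jnp
import numpy as np
import keras


def jax_dense_init_fn(rng, inputs):
    # Initialize a single dense projection to 10 output units.
    in_dim, out_dim = inputs.shape[-1], 10
    return {
        "w": jax.random.normal(rng, (in_dim, out_dim)) * 0.01,
        "b": jnp.zeros((out_dim,)),
    }


def jax_dense_fn(params, inputs):
    return jnp.matmul(inputs, params["w"]) + params["b"]


layer = keras.layers.JaxLayer(
    call_fn=jax_dense_fn,
    init_fn=jax_dense_init_fn,
    # On the TensorFlow backend the output shape may not be inferable by
    # tracing the JAX function, so it is supplied explicitly here.
    compute_output_shape_fn=lambda input_shape: (input_shape[0], 10),
)

x = np.random.uniform(size=(8, 4)).astype("float32")
y = layer(x)  # runs through the jax2tf-converted call_fn
print(y.shape)  # (8, 10)
```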
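As a second sketch (again an assumption for illustration, not code from the PR), the polymorphic-shapes convention that `_get_jax2tf_input_shape` produces can be exercised directly with `jax2tf.convert`: each unknown dimension becomes a letter, the leading unknown dimension becomes `batch`, and the converted function then accepts any size for those dimensions.

```python
# Standalone illustration of the polymorphic-shapes string format used above.
import tensorflow as tf
from jax.experimental import jax2tf


def scale_fn(inputs):
    return inputs * 2.0


# For a Keras input shape of (None, None, 5), the layer would generate the
# string "(batch, a, 5)"; jax2tf treats `batch` and `a` as symbolic sizes.
converted = jax2tf.convert(scale_fn, polymorphic_shapes=["(batch, a, 5)"])
converted = tf.autograph.experimental.do_not_convert(converted)

print(converted(tf.ones((2, 7, 5))).shape)  # (2, 7, 5)
print(converted(tf.ones((3, 9, 5))).shape)  # (3, 9, 5)
```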