diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py
index aacee7818e..142b43ab56 100644
--- a/keras_hub/api/layers/__init__.py
+++ b/keras_hub/api/layers/__init__.py
@@ -99,6 +99,12 @@ from keras_hub.src.models.gemma3.gemma3_image_converter import (
     Gemma3ImageConverter as Gemma3ImageConverter,
 )
+from keras_hub.src.models.gemma3n.gemma3n_audio_converter import (
+    Gemma3nAudioConverter as Gemma3nAudioConverter,
+)
+from keras_hub.src.models.gemma3n.gemma3n_image_converter import (
+    Gemma3nImageConverter as Gemma3nImageConverter,
+)
 from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import (
     HGNetV2ImageConverter as HGNetV2ImageConverter,
 )
diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 650487dcb1..8d5a749450 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -312,6 +312,18 @@ from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
     Gemma3VisionEncoder as Gemma3VisionEncoder,
 )
+from keras_hub.src.models.gemma3n.gemma3n_backbone import (
+    Gemma3nBackbone as Gemma3nBackbone,
+)
+from keras_hub.src.models.gemma3n.gemma3n_causal_lm import (
+    Gemma3nCausalLM as Gemma3nCausalLM,
+)
+from keras_hub.src.models.gemma3n.gemma3n_causal_lm_preprocessor import (
+    Gemma3nCausalLMPreprocessor as Gemma3nCausalLMPreprocessor,
+)
+from keras_hub.src.models.gemma3n.gemma3n_tokenizer import (
+    Gemma3nTokenizer as Gemma3nTokenizer,
+)
 from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone as GPT2Backbone
 from keras_hub.src.models.gpt2.gpt2_causal_lm import (
     GPT2CausalLM as GPT2CausalLM,
diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py
index b155d0e6e1..d4692cb8c8 100644
--- a/keras_hub/api/tokenizers/__init__.py
+++ b/keras_hub/api/tokenizers/__init__.py
@@ -41,6 +41,9 @@ from keras_hub.src.models.gemma3.gemma3_tokenizer import (
     Gemma3Tokenizer as Gemma3Tokenizer,
 )
+from keras_hub.src.models.gemma3n.gemma3n_tokenizer import (
+    Gemma3nTokenizer as Gemma3nTokenizer,
+)
 from keras_hub.src.models.gpt2.gpt2_tokenizer import (
     GPT2Tokenizer as GPT2Tokenizer,
 )
diff --git a/keras_hub/src/models/gemma3n/__init__.py b/keras_hub/src/models/gemma3n/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/keras_hub/src/models/gemma3n/gemma3n_attention.py b/keras_hub/src/models/gemma3n/gemma3n_attention.py
new file mode 100644
index 0000000000..9b49163ac6
--- /dev/null
+++ b/keras_hub/src/models/gemma3n/gemma3n_attention.py
@@ -0,0 +1,715 @@
+import math
+
+import keras
+import numpy as np
+
+from keras_hub.src.models.gemma3n.gemma3n_utils import apply_rotary_pos_emb
+from keras_hub.src.models.gemma3n.gemma3n_utils import eager_attention_forward
+from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm
+
+
+class Gemma3nAudioRelativePositionEmbedding(keras.layers.Layer):
+    """A layer for learning relative position embeddings for audio sequences.
+
+    This layer implements the relative position embedding mechanism used in the
+    audio tower of the Gemma3n model. It computes position-aware attention
+    scores by generating a timing signal based on relative positions between
+    queries and keys, which is then projected and added to the content-based
+    attention logits.
+
+    Args:
+        hidden_size: int. The size of the hidden state.
+        conf_num_attention_heads: int. The number of attention heads.
+        conf_attention_context_left: int. The number of steps to attend to in
+            the past, including the current step.
+ conf_attention_context_right: int. The number of steps to attend to in + the future. + """ + + def __init__( + self, + hidden_size, + conf_num_attention_heads, + conf_attention_context_left, + conf_attention_context_right, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_context_right = conf_attention_context_right + self.num_heads = conf_num_attention_heads + self.channels = hidden_size + self.head_dim = self.channels // self.num_heads + self.max_backward = max(0, conf_attention_context_left - 1) + self.max_forward = conf_attention_context_right + self.pos_proj = keras.layers.Dense( + self.num_heads * self.head_dim, + use_bias=False, + name="pos_proj", + dtype=self.dtype_policy, + ) + min_timescale = 1.0 + max_timescale = 1.0e4 + num_timescales = self.channels // 2 + log_timescale_increment = math.log( + float(max_timescale) / float(min_timescale) + ) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * np.exp( + np.arange(num_timescales, dtype="float32") + * -log_timescale_increment + ) + self.inv_timescales = keras.ops.expand_dims( + keras.ops.expand_dims( + keras.ops.convert_to_tensor(inv_timescales, dtype="float32"), 0 + ), + 0, + ) + + def build(self, input_shape): + if not self.pos_proj.built: + self.pos_proj.build((None, self.channels)) + super().build(input_shape) + + def _get_timing_signal_1d_pos(self, position, dtype): + position = keras.ops.cast( + keras.ops.expand_dims(position, axis=-1), "float32" + ) + pos_shape = keras.ops.shape(position) + inv_shape = keras.ops.shape(self.inv_timescales) + target_shape = (pos_shape[0], pos_shape[1], inv_shape[2]) + position = keras.ops.broadcast_to(position, target_shape) + inv_timescales = keras.ops.broadcast_to( + self.inv_timescales, target_shape + ) + scaled_time = position * inv_timescales + timing_signal = keras.ops.concatenate( + [keras.ops.sin(scaled_time), keras.ops.cos(scaled_time)], axis=-1 + ) + return keras.ops.cast(timing_signal, dtype) + + def _relative_shift( + self, + term_bd_before_shift, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + max_span_plus_1, + ): + msp1_val = max_span_plus_1 + kcs_val = key_context_size + if not isinstance(msp1_val, int) and hasattr(msp1_val, "shape"): + msp1_val = keras.ops.shape(msp1_val)[-1] + if not isinstance(kcs_val, int) and hasattr(kcs_val, "shape"): + kcs_val = keras.ops.shape(kcs_val)[-1] + pad_amount_last_dim = (kcs_val + 1) - msp1_val + padding_tuple = [[0, 0]] * ( + len(keras.ops.shape(term_bd_before_shift)) - 1 + ) + [[0, pad_amount_last_dim]] + term_bd_padded = keras.ops.pad(term_bd_before_shift, padding_tuple) + shape_padded = keras.ops.shape(term_bd_padded) + B = shape_padded[0] + H = shape_padded[1] + U = shape_padded[2] + W = shape_padded[3] + C_plus_1 = shape_padded[4] + target_shape_1_last_dim = -1 + if W is not None and C_plus_1 is not None: + try: + target_shape_1_last_dim = W * C_plus_1 + except TypeError: + target_shape_1_last_dim = -1 + term_bd_reshaped = keras.ops.reshape( + term_bd_padded, + ( + B if B is not None else -1, + H if H is not None else -1, + U if U is not None else -1, + target_shape_1_last_dim, + ), + ) + slice_end = None + qbs_val = query_block_size + if not isinstance(qbs_val, int) and hasattr(qbs_val, "shape"): + qbs_val = keras.ops.shape(qbs_val)[0] + if qbs_val is not None and kcs_val is not 
None: + try: + slice_end = qbs_val * kcs_val + except TypeError: + slice_end = None + term_bd_reshaped = term_bd_reshaped[..., :slice_end] + term_bd_shifted = keras.ops.reshape( + term_bd_reshaped, + ( + B if B is not None else -1, + H if H is not None else -1, + U if U is not None else -1, + W if W is not None else -1, + kcs_val if kcs_val is not None else -1, + ), + ) + return term_bd_shifted + + def _int8_call(self, queries, keys): + original_dtype = queries.dtype + queries_calc = keras.ops.cast(queries, "float32") + keys_calc = keras.ops.cast(keys, "float32") + result_calc = self.call(queries_calc, keys_calc) + return keras.ops.cast(result_calc, original_dtype) + + def call(self, queries, keys): + batch_size = keras.ops.shape(queries)[0] + ( + _, + num_query_blocks, + query_block_size, + num_heads, + head_dim, + ) = queries.shape + _, _, key_context_size, _, _ = keys.shape + pos_indices = keras.ops.expand_dims( + keras.ops.arange( + self.max_backward, -self.max_forward - 1, -1, dtype="float32" + ), + 0, + ) + max_span_plus_1 = keras.ops.shape(pos_indices)[1] + sin_emb_timing_signal = self._get_timing_signal_1d_pos( + pos_indices, dtype=queries.dtype + ) + projected_sin_emb = self.pos_proj(sin_emb_timing_signal) + sin_emb = keras.ops.squeeze( + keras.ops.reshape( + projected_sin_emb, + (1, max_span_plus_1, self.num_heads, self.head_dim), + ), + axis=0, + ) + queries_p = keras.ops.transpose(queries, (0, 3, 1, 2, 4)) + keys_p_t = keras.ops.transpose(keys, (0, 3, 1, 4, 2)) + term_ac = keras.ops.matmul(queries_p, keys_p_t) + q_permuted = keras.ops.transpose(queries, (0, 3, 1, 2, 4)) + s_permuted = keras.ops.transpose(sin_emb, (1, 2, 0)) + term_bd_unshifed = keras.ops.einsum( + "bhuwd,hdf->bhuwf", q_permuted, s_permuted + ) + term_bd_shifted = self._relative_shift( + term_bd_unshifed, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + max_span_plus_1, + ) + return term_ac + term_bd_shifted + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + } + ) + return config + + +class Gemma3nTextAttention(keras.layers.Layer): + """A multi-head attention layer for text sequences. + + This layer implements the text attention mechanism for the Gemma3n model, + which is a standard multi-head attention architecture. It includes features + such as Grouped-Query Attention (GQA), RMS Normalization for query and key + states, and Rotary Position Embeddings (RoPE) to incorporate positional + information. + + Args: + hidden_size: int. The size of the hidden state. + num_attention_heads: int. The number of query attention heads. + num_key_value_heads: int. The number of key and value attention heads. + If `num_key_value_heads` is not equal to `num_attention_heads`, this + layer implements Grouped-Query Attention. + head_dim: int. The dimension of each attention head. + attention_dropout: float. Dropout probability for the attention scores. + attention_bias: bool. If `True`, dense layers for query, key, value, + and output projections will use a bias term. + rms_norm_eps: float. The epsilon value for RMS Normalization layers. + sliding_window: int, optional. The size of the sliding window for + local attention. If `None`, global attention is used. Defaults to + `None`. 
+ """ + + def __init__( + self, + hidden_size, + num_attention_heads, + num_key_value_heads, + head_dim, + attention_dropout, + attention_bias, + rms_norm_eps, + sliding_window=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.rms_norm_eps = rms_norm_eps + self.sliding_window = sliding_window + self.num_key_value_groups = ( + self.num_attention_heads // self.num_key_value_heads + ) + self.q_proj = keras.layers.Dense( + self.num_attention_heads * self.head_dim, + use_bias=self.attention_bias, + name="q_proj", + dtype=self.dtype_policy, + ) + self.k_proj = keras.layers.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=self.attention_bias, + name="k_proj", + dtype=self.dtype_policy, + ) + self.v_proj = keras.layers.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=self.attention_bias, + name="v_proj", + dtype=self.dtype_policy, + ) + self.o_proj = keras.layers.Dense( + self.hidden_size, + use_bias=self.attention_bias, + name="o_proj", + dtype=self.dtype_policy, + ) + self.q_norm = Gemma3nRMSNorm( + dim=self.head_dim, + eps=self.rms_norm_eps, + name="q_norm", + dtype=self.dtype_policy, + ) + self.k_norm = Gemma3nRMSNorm( + dim=self.head_dim, + eps=self.rms_norm_eps, + name="k_norm", + dtype=self.dtype_policy, + ) + self.v_norm = Gemma3nRMSNorm( + dim=self.head_dim, + eps=self.rms_norm_eps, + with_scale=False, + name="v_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.q_proj.build(input_shape) + self.k_proj.build(input_shape) + self.v_proj.build(input_shape) + self.o_proj.build( + input_shape[:-1] + (self.num_attention_heads * self.head_dim,) + ) + norm_shape = input_shape[:-1] + ( + self.num_attention_heads, + self.head_dim, + ) + self.q_norm.build(norm_shape) + k_norm_shape = input_shape[:-1] + ( + self.num_key_value_heads, + self.head_dim, + ) + self.k_norm.build(k_norm_shape) + self.v_norm.build(k_norm_shape) + super().build(input_shape) + + def call( + self, + hidden_states, + position_embeddings, + attention_mask, + cache=None, + cache_update_index=0, + cache_update_mask=None, + training=False, + ): + input_shape = keras.ops.shape(hidden_states)[:-1] + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states) + query_states = keras.ops.reshape( + query_states, + input_shape + (self.num_attention_heads, self.head_dim), + ) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb( + query_states, cos, sin, unsqueeze_dim=2 + ) + query_states = keras.ops.transpose(query_states, (0, 2, 1, 3)) + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] 
+ key_update = self.k_proj(hidden_states) + key_update = keras.ops.reshape( + key_update, + input_shape + (self.num_key_value_heads, self.head_dim), + ) + key_update = self.k_norm(key_update) + key_update = apply_rotary_pos_emb( + key_update, cos, sin, unsqueeze_dim=2 + ) + key_update = keras.ops.transpose(key_update, (0, 2, 1, 3)) + value_update = self.v_proj(hidden_states) + value_update = keras.ops.reshape( + value_update, + input_shape + (self.num_key_value_heads, self.head_dim), + ) + value_update = self.v_norm(value_update) + value_update = keras.ops.transpose(value_update, (0, 2, 1, 3)) + start = [0, 0, cache_update_index, 0] + if cache_update_mask is not None: + cache_update_mask = keras.ops.expand_dims( + keras.ops.expand_dims(cache_update_mask, axis=1), + axis=-1, + ) + key_original = keras.ops.slice( + key_cache, start, keras.ops.shape(key_update) + ) + value_original = keras.ops.slice( + value_cache, start, keras.ops.shape(value_update) + ) + key_update = keras.ops.where( + cache_update_mask, + key_update, + key_original, + ) + value_update = keras.ops.where( + cache_update_mask, + value_update, + value_original, + ) + key_states = keras.ops.slice_update(key_cache, start, key_update) + value_states = keras.ops.slice_update( + value_cache, start, value_update + ) + cache = keras.ops.stack((key_states, value_states), axis=1) + else: + key_states = self.k_proj(hidden_states) + key_states = keras.ops.reshape( + key_states, + input_shape + (self.num_key_value_heads, self.head_dim), + ) + key_states = self.k_norm(key_states) + key_states = apply_rotary_pos_emb( + key_states, cos, sin, unsqueeze_dim=2 + ) + key_states = keras.ops.transpose(key_states, (0, 2, 1, 3)) + value_states = self.v_proj(hidden_states) + value_states = keras.ops.reshape( + value_states, + input_shape + (self.num_key_value_heads, self.head_dim), + ) + value_states = self.v_norm(value_states) + value_states = keras.ops.transpose(value_states, (0, 2, 1, 3)) + attn_output, attn_weights = eager_attention_forward( + query_states, + key_states, + value_states, + self.num_key_value_groups, + self.head_dim, + attention_mask, + dropout=self.attention_dropout if training else 0.0, + training=training, + ) + attn_output = keras.ops.reshape(attn_output, input_shape + (-1,)) + attn_output = self.o_proj(attn_output) + if cache is not None: + return attn_output, attn_weights, cache + return attn_output, attn_weights + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + "attention_dropout": self.attention_dropout, + "attention_bias": self.attention_bias, + "rms_norm_eps": self.rms_norm_eps, + "sliding_window": self.sliding_window, + } + ) + return config + + +class Gemma3nAudioAttention(keras.layers.Layer): + """An attention layer specialized for audio sequences. + + This layer implements the attention mechanism for the audio tower of the + Gemma3n model. It is designed to handle long audio sequences by processing + the input in fixed-size chunks. For each chunk of queries, it attends to a + larger context of keys and values, defined by a left (past) and right + (future) context window. This allows the model to capture local and more + distant dependencies efficiently. + + Args: + hidden_size: int. The size of the hidden state. + conf_num_attention_heads: int. The number of attention heads. + conf_attention_chunk_size: int. 
The size of each processing chunk. + conf_attention_context_right: int. The number of steps to attend to in + the future. + conf_attention_context_left: int. The number of steps to attend to in + the past, including the current step. + conf_attention_logit_cap: float. The soft cap value to apply to the + attention logits. + """ + + def __init__( + self, + hidden_size, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_logit_cap = conf_attention_logit_cap + self.num_heads = conf_num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.chunk_size = conf_attention_chunk_size + self.max_future_horizon = conf_attention_context_right + self.max_past_horizon = max(0, conf_attention_context_left - 1) + self.attention_logits_soft_cap = conf_attention_logit_cap + self.context_size = ( + self.chunk_size + self.max_past_horizon + self.max_future_horizon + ) + self.relative_position_embedding = ( + Gemma3nAudioRelativePositionEmbedding( + hidden_size, + conf_num_attention_heads, + conf_attention_context_left, + conf_attention_context_right, + name="relative_position_embedding", + dtype=self.dtype_policy, + ) + ) + self.q_proj = keras.layers.Dense( + self.num_heads * self.head_dim, + use_bias=False, + name="q_proj", + dtype=self.dtype_policy, + ) + self.k_proj = keras.layers.Dense( + self.num_heads * self.head_dim, + use_bias=False, + name="k_proj", + dtype=self.dtype_policy, + ) + self.v_proj = keras.layers.Dense( + self.num_heads * self.head_dim, + use_bias=False, + name="v_proj", + dtype=self.dtype_policy, + ) + q_scale = self.head_dim**-0.5 + r_softplus_0 = 1.0 / np.log(1 + np.exp(0.0)) # softplus(0) for numpy + self.q_scale = q_scale * r_softplus_0 + + lower_causal_mask = np.tril( + np.ones((self.context_size, self.chunk_size), dtype=bool), k=0 + ).T + upper_causal_mask = np.tril( + np.ones((self.chunk_size, self.context_size), dtype=bool), + k=self.max_past_horizon + self.max_future_horizon, + ) + local_causal_valid_mask = np.ones( + (self.chunk_size, self.context_size), dtype=bool + ) + local_causal_valid_mask = ( + local_causal_valid_mask * lower_causal_mask * upper_causal_mask + ) + self.local_causal_valid_mask = keras.ops.convert_to_tensor( + local_causal_valid_mask + ) + self.softcap = keras.ops.convert_to_tensor( + self.attention_logits_soft_cap, dtype="float32" + ) + + def build(self, input_shape): + self.per_dim_scale = self.add_weight( + shape=(self.head_dim,), + initializer="zeros", + trainable=True, + name="per_dim_scale", + dtype=self.dtype_policy.variable_dtype, + ) + self.q_proj.build(input_shape) + self.k_proj.build(input_shape) + self.v_proj.build(input_shape) + q_build_shape = ( + None, + None, + self.chunk_size, + self.num_heads, + self.head_dim, + ) + k_build_shape = ( + None, + None, + self.context_size, + self.num_heads, + self.head_dim, + ) + self.relative_position_embedding.build((q_build_shape, k_build_shape)) + super().build(input_shape) + + def _pad_dim1(self, x, pad_left, pad_right): + paddings = [[0, 0], [pad_left, pad_right]] + [ + [0, 0] for _ in range(len(keras.ops.shape(x)) - 
2) + ] + return keras.ops.pad(x, paddings) + + def _convert_to_block(self, hidden_states): + b, t = keras.ops.shape(hidden_states)[:2] + tail_shape_list = list(hidden_states.shape[2:]) + num_blocks = (t + self.chunk_size - 1) // self.chunk_size + padding_len = num_blocks * self.chunk_size - t + hidden_states = self._pad_dim1(hidden_states, 0, padding_len) + permute_dims = [b, num_blocks, self.chunk_size] + tail_shape_list + return keras.ops.reshape(hidden_states, permute_dims) + + def _extract_block_context(self, hidden_states): + _, t = keras.ops.shape(hidden_states)[:2] + num_frames = (t + self.chunk_size - 1) // self.chunk_size + pad_left = self.max_past_horizon + pad_right = self.max_future_horizon + self.chunk_size - 1 + hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right) + frame_len = self.context_size + frame_step = self.chunk_size + + start_indices = keras.ops.arange(0, num_frames) * frame_step + frame_offsets = keras.ops.arange(0, frame_len) + indices = keras.ops.expand_dims( + start_indices, axis=1 + ) + keras.ops.expand_dims(frame_offsets, axis=0) + return keras.ops.take(hidden_states, indices, axis=1) + + def call(self, hidden_states, mask): + qkv_shape = keras.ops.shape(hidden_states)[:-1] + ( + self.num_heads, + self.head_dim, + ) + query_states = keras.ops.reshape(self.q_proj(hidden_states), qkv_shape) + key_states = keras.ops.reshape(self.k_proj(hidden_states), qkv_shape) + value_states = keras.ops.reshape(self.v_proj(hidden_states), qkv_shape) + per_dim_scale_sp = keras.ops.softplus(self.per_dim_scale) + query_states = query_states * self.q_scale * per_dim_scale_sp + batch_size, q_time = keras.ops.shape(query_states)[:2] + query_blocks = self._convert_to_block(query_states) + key_blocks = self._extract_block_context(key_states) + value_blocks = self._extract_block_context(value_states) + num_query_blocks = keras.ops.shape(query_blocks)[1] + original_valid_mask = keras.ops.logical_not(mask) + extracted_valid_mask_blocks = self._extract_block_context( + original_valid_mask + ) + mask_block_shape = keras.ops.shape(extracted_valid_mask_blocks) + if len(mask_block_shape) > 3: + axes_to_squeeze = [ + i + for i, dim in enumerate(mask_block_shape) + if i > 0 and i < len(mask_block_shape) - 1 and dim == 1 + ] + if axes_to_squeeze: + extracted_valid_mask_blocks = keras.ops.squeeze( + extracted_valid_mask_blocks, axis=axes_to_squeeze + ) + mask_block_shape = keras.ops.shape(extracted_valid_mask_blocks) + if ( + len(mask_block_shape) == 4 + and mask_block_shape[2] * mask_block_shape[3] == self.context_size + ): + extracted_valid_mask_blocks = keras.ops.reshape( + extracted_valid_mask_blocks, + (batch_size, num_query_blocks, self.context_size), + ) + condition_from_input_validity = keras.ops.expand_dims( + keras.ops.expand_dims(extracted_valid_mask_blocks, 1), -2 + ) + condition_from_causality = keras.ops.expand_dims( + keras.ops.expand_dims( + keras.ops.expand_dims(self.local_causal_valid_mask, 0), 0 + ), + 0, + ) + final_condition_for_where = keras.ops.logical_and( + condition_from_input_validity, + keras.ops.cast(condition_from_causality, "bool"), + ) + logits = self.relative_position_embedding(query_blocks, key_blocks) + softcap = keras.ops.cast(self.softcap, dtype=logits.dtype) + logits = logits / softcap + logits = keras.ops.tanh(logits) + logits = logits * softcap + compute_dtype = logits.dtype + dtype_str = str(compute_dtype) + if "float16" in dtype_str or "bfloat16" in dtype_str: + min_val = np.finfo(np.float16).min + else: + min_val = np.finfo(np.float32).min + 
min_val = keras.ops.convert_to_tensor(min_val, dtype=compute_dtype) + logits = keras.ops.where(final_condition_for_where, logits, min_val) + probabilities = keras.ops.softmax( + keras.ops.cast(logits, "float32"), axis=-1 + ) + probabilities = keras.ops.cast(probabilities, value_blocks.dtype) + context_vectors = keras.ops.einsum( + "bnuwc,bucnh->buwnh", probabilities, value_blocks + ) + context_vectors = keras.ops.reshape( + context_vectors, + ( + batch_size, + num_query_blocks * self.chunk_size, + self.num_heads, + self.head_dim, + ), + ) + context_vectors = context_vectors[:, :q_time] + return context_vectors + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_chunk_size": self.conf_attention_chunk_size, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_logit_cap": self.conf_attention_logit_cap, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_audio_converter.py b/keras_hub/src/models/gemma3n/gemma3n_audio_converter.py new file mode 100644 index 0000000000..7f94271bf0 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_audio_converter.py @@ -0,0 +1,473 @@ +import math + +import keras +import numpy as np + +try: + import tensorflow as tf +except ImportError: + tf = None +from keras_hub.src.api_export import keras_hub_export + + +@keras_hub_export("keras_hub.layers.Gemma3nAudioConverter") +class Gemma3nAudioConverter(keras.layers.Layer): + """Converts raw audio waveforms into log-mel spectrograms. + + This layer preprocesses 1D audio signals into 2D log-mel spectrograms + suitable for the Gemma3n audio encoder. The conversion process involves + padding or truncating the raw audio to a consistent length, applying + optional dithering, input scaling, and preemphasis, and then computing the + Short-Time Fourier Transform (STFT) with a Hann window. The resulting + magnitude spectrogram is converted to the mel scale using a mel filterbank, + after which the log-mel spectrogram is calculated by taking the logarithm. + Finally, the layer can optionally normalize these features using provided + per-bin mean and standard deviation statistics, and it returns both the + spectrogram and an attention mask indicating which frames are valid. + + Args: + feature_size: int. The number of mel bins to generate. + Defaults to 128. + sampling_rate: int. The expected sampling rate of the input audio. + Defaults to 16000. + padding_value: float. The value to use for padding the raw audio. + Defaults to 0.0. + return_attention_mask: bool. Whether to return an attention mask. + Defaults to True. + frame_length_ms: float. The length of each STFT frame in + milliseconds. Defaults to 32.0. + hop_length_ms: float. The step size between STFT frames in + milliseconds. Defaults to 10.0. + min_frequency: float. The lowest frequency for the mel filterbank. + Defaults to 125.0. + max_frequency: float. The highest frequency for the mel filterbank. + Defaults to 7600.0. + preemphasis: float. The coefficient for the preemphasis filter. + Set to 0.0 to disable. Defaults to 0.97. + preemphasis_htk_flavor: bool. Whether to use the HTK-style + preemphasis. Defaults to True. + fft_overdrive: bool. If True, doubles the FFT length. + Defaults to True. + dither: float. Amount of dithering to add to the waveform. + Set to 0.0 to disable. Defaults to 0.0. 
+ input_scale_factor: float. Factor to scale the input waveform by. + Defaults to 1.0. + mel_floor: float. A minimum value (floor) to apply before taking + the logarithm. Defaults to 1e-5. + per_bin_mean: list or None. A list of mean values for each mel + bin, used for normalization. Defaults to None. + per_bin_stddev: list or None. A list of standard deviation values + for each mel bin, used for normalization. Defaults to None. + padding_side: str. Which side to pad the audio on ('right' or + 'left'). Defaults to 'right'. + """ + + def __init__( + self, + feature_size=128, + sampling_rate=16000, + padding_value=0.0, + return_attention_mask=True, + frame_length_ms=32.0, + hop_length_ms=10.0, + min_frequency=125.0, + max_frequency=7600.0, + preemphasis=0.97, + preemphasis_htk_flavor=True, + fft_overdrive=True, + dither=0.0, + input_scale_factor=1.0, + mel_floor=1e-5, + per_bin_mean=None, + per_bin_stddev=None, + padding_side="right", + **kwargs, + ): + # === Config === + super().__init__(**kwargs) + self.feature_size = feature_size + self.sampling_rate = sampling_rate + self.padding_value = padding_value + self.return_attention_mask = return_attention_mask + self.padding_side = padding_side + self.min_frequency = min_frequency + self.max_frequency = max_frequency + self.preemphasis = preemphasis + self.preemphasis_htk_flavor = preemphasis_htk_flavor + self.fft_overdrive = fft_overdrive + self.dither = dither + self.input_scale_factor = input_scale_factor + self.frame_length_ms = frame_length_ms + self.hop_length_ms = hop_length_ms + self.mel_floor_arg = mel_floor + self.per_bin_mean_arg = per_bin_mean + self.per_bin_stddev_arg = per_bin_stddev + self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0)) + self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0)) + self.mel_floor = tf.constant(mel_floor, dtype=self.compute_dtype) + fft_length = 2 ** math.ceil(math.log2(self.frame_length)) + if self.fft_overdrive: + fft_length *= 2 + self.fft_length = fft_length + hann_arange = tf.range(self.frame_length, dtype=self.compute_dtype) + self.window = 0.5 * ( + 1 - tf.cos(2 * np.pi * hann_arange / self.frame_length) + ) + self.mel_filters = self._create_fb_matrix( + n_freqs=self.fft_length // 2 + 1, + f_min=min_frequency, + f_max=max_frequency, + n_mels=feature_size, + sample_rate=self.sampling_rate, + fft_length=fft_length, + ) + if per_bin_mean is not None: + self.per_bin_mean = tf.constant( + per_bin_mean, + shape=(1, 1, feature_size), + dtype=self.compute_dtype, + ) + else: + self.per_bin_mean = None + if per_bin_stddev is not None: + self.per_bin_stddev = tf.constant( + per_bin_stddev, + shape=(1, 1, feature_size), + dtype=self.compute_dtype, + ) + else: + self.per_bin_stddev = None + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + self.built = True + + def _create_fb_matrix( + self, + n_freqs, + f_min, + f_max, + n_mels, + sample_rate, + fft_length, + ): + all_freqs = tf.cast(tf.range(n_freqs), dtype=self.compute_dtype) * ( + sample_rate / fft_length + ) + m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0)) + m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0)) + m_pts = np.linspace(m_min, m_max, n_mels + 2, dtype=np.float32) + f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0) + f_pts = tf.constant(f_pts, dtype=self.compute_dtype) + f_diff = f_pts[1:] - f_pts[:-1] + slopes = tf.expand_dims(f_pts, 0) - tf.expand_dims(all_freqs, 1) + zero = tf.zeros(1, dtype=self.compute_dtype) + down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] + 
up_slopes = slopes[:, 2:] / f_diff[1:] + fb = tf.maximum(zero, tf.minimum(down_slopes, up_slopes)) + return tf.constant(fb, dtype=self.compute_dtype) + + def _extract_spectrogram(self, waveform, attention_mask): + waveform = tf.cast(waveform, dtype=self.compute_dtype) + if self.dither > 0.0: + waveform = waveform + self.dither * tf.random.normal( + tf.shape(waveform), dtype=waveform.dtype + ) + if self.input_scale_factor != 1.0: + waveform = waveform * self.input_scale_factor + if self.preemphasis > 0.0: + if self.preemphasis_htk_flavor: + first_sample = waveform[:, :1] * (1.0 - self.preemphasis) + rest_of_samples = ( + waveform[:, 1:] - self.preemphasis * waveform[:, :-1] + ) + waveform = tf.concat([first_sample, rest_of_samples], axis=-1) + else: + waveform = tf.concat( + [ + waveform[:, :1], + waveform[:, 1:] - self.preemphasis * waveform[:, :-1], + ], + axis=-1, + ) + frames = tf.signal.frame( + waveform, + frame_length=self.frame_length, + frame_step=self.hop_length, + pad_end=False, + ) + frames = frames * self.window + pad_length = self.fft_length - self.frame_length + paddings = [[0, 0], [0, 0], [0, pad_length]] + frames = tf.pad(frames, paddings) + stft = tf.signal.rfft(frames) + magnitude_spec = tf.abs(stft) + mel_spec = tf.matmul(magnitude_spec, self.mel_filters) + log_mel_spec = tf.math.log(tf.maximum(mel_spec, self.mel_floor)) + if self.per_bin_mean is not None: + log_mel_spec = log_mel_spec - self.per_bin_mean + if self.per_bin_stddev is not None: + log_mel_spec = log_mel_spec / self.per_bin_stddev + mel_spectrogram = tf.squeeze(log_mel_spec, axis=0) + mask = tf.cast(attention_mask[:: self.hop_length], dtype=tf.bool) + return mel_spectrogram, mask[: tf.shape(mel_spectrogram)[0]] + + def _get_padding_strategies(self, padding=False, max_length=None): + if padding is not False: + if padding is True: + padding_strategy = "longest" + else: + padding_strategy = padding + else: + padding_strategy = "do_not_pad" + if max_length is None: + if padding_strategy == "max_length": + raise ValueError( + "When setting padding='max_length', max_length must be " + "defined" + ) + if padding_strategy != "do_not_pad" and (self.padding_value is None): + raise ValueError("Padding requested but no padding_value defined") + return padding_strategy + + def _pad( + self, + input_features, + attention_mask=None, + max_length=None, + padding_strategy="do_not_pad", + pad_to_multiple_of=None, + return_attention_mask=None, + ): + required_input = input_features + if padding_strategy == "longest": + max_length = len(required_input) + if ( + max_length is not None + and pad_to_multiple_of is not None + and (max_length % pad_to_multiple_of != 0) + ): + max_length = ( + (max_length // pad_to_multiple_of) + 1 + ) * pad_to_multiple_of + needs_to_be_padded = ( + padding_strategy != "do_not_pad" + and len(required_input) < max_length + ) + if return_attention_mask and attention_mask is None: + attention_mask = np.ones(len(required_input), dtype=np.int32) + if needs_to_be_padded: + difference = max_length - len(required_input) + if self.padding_side == "right": + if return_attention_mask: + attention_mask = np.pad(attention_mask, (0, difference)) + if required_input.ndim > 1: + padding_shape = ((0, difference), (0, 0)) + else: + padding_shape = ((0, difference),) + input_features = np.pad( + required_input, + padding_shape, + "constant", + constant_values=self.padding_value, + ) + elif self.padding_side == "left": + if return_attention_mask: + attention_mask = np.pad(attention_mask, (difference, 0)) + if 
required_input.ndim > 1: + padding_shape = ((difference, 0), (0, 0)) + else: + padding_shape = ((difference, 0),) + input_features = np.pad( + required_input, + padding_shape, + "constant", + constant_values=self.padding_value, + ) + return input_features, attention_mask + + def _truncate( + self, + input_features, + attention_mask=None, + max_length=None, + pad_to_multiple_of=None, + truncation=None, + ): + if not truncation: + return input_features, attention_mask + elif truncation and max_length is None: + raise ValueError( + "When setting truncation=True, max_length must be defined" + ) + required_input = input_features + if ( + max_length is not None + and pad_to_multiple_of is not None + and (max_length % pad_to_multiple_of != 0) + ): + max_length = ( + (max_length // pad_to_multiple_of) + 1 + ) * pad_to_multiple_of + needs_to_be_truncated = len(required_input) > max_length + if needs_to_be_truncated: + input_features = input_features[:max_length] + if attention_mask is not None: + attention_mask = attention_mask[:max_length] + return input_features, attention_mask + + def pad( + self, + input_features, + padding=True, + max_length=None, + truncation=False, + pad_to_multiple_of=None, + return_attention_mask=None, + ): + required_input = input_features + return_attention_mask = ( + return_attention_mask + if return_attention_mask is not None + else self.return_attention_mask + ) + if len(required_input) == 0: + return [], [] if return_attention_mask else None + required_input = [np.asarray(v) for v in required_input] + padding_strategy = self._get_padding_strategies( + padding=padding, max_length=max_length + ) + batch_size = len(required_input) + truncated_inputs = [] + truncated_masks = [] + for i in range(batch_size): + inputs = required_input[i] + mask = ( + np.ones(len(inputs), dtype=np.int32) + if return_attention_mask + else None + ) + inputs_slice, mask_slice = self._truncate( + inputs, + attention_mask=mask, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + truncation=truncation, + ) + truncated_inputs.append(inputs_slice) + if mask_slice is not None: + truncated_masks.append(mask_slice) + if padding_strategy == "longest": + max_length = max( + len(input_slice) for input_slice in truncated_inputs + ) + padding_strategy = "max_length" + batch_outputs_features = [] + batch_outputs_masks = [] + for i in range(batch_size): + inputs = truncated_inputs[i] + mask = truncated_masks[i] if return_attention_mask else None + outputs_features, outputs_mask = self._pad( + inputs, + attention_mask=mask, + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + if outputs_features.dtype == np.dtype(np.float64): + outputs_features = outputs_features.astype(np.float32) + batch_outputs_features.append(outputs_features) + if outputs_mask is not None: + batch_outputs_masks.append(outputs_mask) + if not return_attention_mask: + return batch_outputs_features, None + return batch_outputs_features, batch_outputs_masks + + def call( + self, + raw_speech, + padding="longest", + max_length=480000, + truncation=True, + pad_to_multiple_of=128, + return_attention_mask=True, + ): + def _process_in_py(raw_speech_tensor): + raw_speech_np = raw_speech_tensor.numpy() + is_batched = raw_speech_np.ndim > 1 + if is_batched: + speech_list = [rs.reshape(-1, 1) for rs in raw_speech_np] + else: + raw_speech_np = np.atleast_1d(raw_speech_np) + speech_list = [raw_speech_np.reshape(-1, 1)] + 
input_features_list, attention_mask_list = self.pad( + speech_list, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + prepared_speech = [] + prepared_speech_mask = [] + for speech, mask in zip(input_features_list, attention_mask_list): + speech_tensor = tf.constant(speech.T, dtype=self.compute_dtype) + mask_tensor = tf.constant(mask, dtype=tf.int32) + features, feature_mask = self._extract_spectrogram( + speech_tensor, mask_tensor + ) + prepared_speech.append(features) + prepared_speech_mask.append(feature_mask) + input_features = tf.stack(prepared_speech) + input_features_mask = tf.stack(prepared_speech_mask) + if not is_batched: + input_features = tf.squeeze(input_features, axis=0) + input_features_mask = tf.squeeze(input_features_mask, axis=0) + return input_features, input_features_mask + + if not isinstance(raw_speech, (tf.Tensor, tf.RaggedTensor)): + was_batched = isinstance(raw_speech, (list, tuple)) + raw_speech = tf.convert_to_tensor( + raw_speech, dtype=self.compute_dtype + ) + else: + was_batched = raw_speech.shape.rank > 1 + input_features, input_features_mask = tf.py_function( + _process_in_py, + inp=[raw_speech], + Tout=[self.compute_dtype, tf.bool], + ) + num_frames = None + if was_batched: + input_features.set_shape([None, num_frames, self.feature_size]) + input_features_mask.set_shape([None, num_frames]) + else: + input_features.set_shape([num_frames, self.feature_size]) + input_features_mask.set_shape([num_frames]) + input_features_mask = tf.cast(input_features_mask, dtype="int32") + return input_features, input_features_mask + + def get_config(self): + config = super().get_config() + config.update( + { + "feature_size": self.feature_size, + "sampling_rate": self.sampling_rate, + "padding_value": self.padding_value, + "return_attention_mask": self.return_attention_mask, + "frame_length_ms": self.frame_length_ms, + "hop_length_ms": self.hop_length_ms, + "min_frequency": self.min_frequency, + "max_frequency": self.max_frequency, + "preemphasis": self.preemphasis, + "preemphasis_htk_flavor": self.preemphasis_htk_flavor, + "fft_overdrive": self.fft_overdrive, + "dither": self.dither, + "input_scale_factor": self.input_scale_factor, + "mel_floor": self.mel_floor_arg, + "per_bin_mean": self.per_bin_mean_arg, + "per_bin_stddev": self.per_bin_stddev_arg, + "padding_side": self.padding_side, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_audio_converter_test.py b/keras_hub/src/models/gemma3n/gemma3n_audio_converter_test.py new file mode 100644 index 0000000000..78dd912a0f --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_audio_converter_test.py @@ -0,0 +1,123 @@ +import numpy as np + +from keras_hub.src.models.gemma3n.gemma3n_audio_converter import ( + Gemma3nAudioConverter, +) +from keras_hub.src.tests.test_case import TestCase + + +class Gemma3nAudioConverterTest(TestCase): + def setUp(self): + super().setUp() + self.feature_size = 128 + self.sampling_rate = 16000 + self.hop_length_ms = 10.0 + self.frame_length_ms = 32.0 + # Dummy audio. 
+ self.input_data = [ + np.sin( + 2 + * np.pi + * 440 + * np.linspace(0, 1, self.sampling_rate, dtype=np.float32) + ) + ] + self.init_kwargs = { + "feature_size": self.feature_size, + "sampling_rate": self.sampling_rate, + "padding_value": 0.0, + "return_attention_mask": True, + "frame_length_ms": self.frame_length_ms, + "hop_length_ms": self.hop_length_ms, + "min_frequency": 125.0, + "max_frequency": 7600.0, + "preemphasis": 0.97, + "preemphasis_htk_flavor": True, + "fft_overdrive": True, + "dither": 0.0, + "input_scale_factor": 1.0, + "mel_floor": 1e-5, + "per_bin_mean": None, + "per_bin_stddev": None, + "padding_side": "right", + } + + def test_output_shape(self): + converter = Gemma3nAudioConverter(**self.init_kwargs) + outputs = converter(self.input_data[0]) + frame_length = int( + round(self.sampling_rate * self.frame_length_ms / 1000.0) + ) + hop_length = int( + round(self.sampling_rate * self.hop_length_ms / 1000.0) + ) + num_frames = (len(self.input_data[0]) - frame_length) // hop_length + 1 + expected_features_shape = (num_frames, self.feature_size) + expected_mask_shape = (num_frames,) + # Check that the outputs are tuples with two elements. + self.assertIsInstance(outputs, tuple) + self.assertEqual(len(outputs), 2) + input_features, input_features_mask = outputs + self.assertEqual(input_features.shape, expected_features_shape) + self.assertEqual(input_features_mask.shape, expected_mask_shape) + + def test_padding(self): + max_length = 20000 + pad_to_multiple_of = 128 + converter = Gemma3nAudioConverter(**self.init_kwargs) + outputs = converter( + self.input_data[0], + padding="max_length", + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + ) + # Calculate expectations. + if max_length % pad_to_multiple_of != 0: + padded_length = ( + (max_length // pad_to_multiple_of) + 1 + ) * pad_to_multiple_of + else: + padded_length = max_length + frame_length = int( + round(self.sampling_rate * self.frame_length_ms / 1000.0) + ) + hop_length = int( + round(self.sampling_rate * self.hop_length_ms / 1000.0) + ) + num_frames = (padded_length - frame_length) // hop_length + 1 + expected_features_shape = (num_frames, self.feature_size) + # Check that the outputs are tuples with two elements. + self.assertIsInstance(outputs, tuple) + self.assertEqual(len(outputs), 2) + input_features, _ = outputs + self.assertEqual(input_features.shape, expected_features_shape) + + def test_normalization(self): + mean = np.random.rand(self.feature_size).tolist() + stddev = np.random.rand(self.feature_size).tolist() + # One converter with normalization and one without. + converter_no_norm = Gemma3nAudioConverter(**self.init_kwargs) + norm_kwargs = self.init_kwargs.copy() + norm_kwargs["per_bin_mean"] = mean + norm_kwargs["per_bin_stddev"] = stddev + converter_norm = Gemma3nAudioConverter(**norm_kwargs) + outputs_no_norm = converter_no_norm(self.input_data) + outputs_norm = converter_norm(self.input_data) + # Check that the outputs are tuples with two elements. + self.assertIsInstance(outputs_no_norm, tuple) + self.assertEqual(len(outputs_no_norm), 2) + self.assertIsInstance(outputs_norm, tuple) + self.assertEqual(len(outputs_norm), 2) + features_no_norm, _ = outputs_no_norm + features_norm, _ = outputs_norm + # We would want outputs to be different. + self.assertNotAllClose(features_no_norm, features_norm) + # Manually normalize and check for closeness. 
+ manual_norm_features = (features_no_norm - np.array(mean)) / np.array( + stddev + ) + self.assertAllClose(manual_norm_features, features_norm) + + def test_serialization(self): + instance = Gemma3nAudioConverter(**self.init_kwargs) + self.run_serialization_test(instance=instance) diff --git a/keras_hub/src/models/gemma3n/gemma3n_audio_encoder.py b/keras_hub/src/models/gemma3n/gemma3n_audio_encoder.py new file mode 100644 index 0000000000..f1c8167a31 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_audio_encoder.py @@ -0,0 +1,580 @@ +import keras + +from keras_hub.src.models.gemma3n.gemma3n_audio_layers import ( + Gemma3nAudioConformerAttention, +) +from keras_hub.src.models.gemma3n.gemma3n_audio_layers import ( + Gemma3nAudioConformerFeedForward, +) +from keras_hub.src.models.gemma3n.gemma3n_audio_layers import ( + Gemma3nAudioConformerLightConv1d, +) +from keras_hub.src.models.gemma3n.gemma3n_audio_layers import ( + Gemma3nAudioSSCPConvBlock, +) +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nAudioSubSampleConvProjection(keras.layers.Layer): + """A convolutional projection layer that subsamples audio features. + + This layer applies two blocks of 2D convolutions to the input audio + spectrogram. Each block subsamples the input along the time and frequency + dimensions. The output is then flattened and projected to the model's + hidden size. + + Args: + input_feat_size: int. The number of frequency bins in the input + spectrogram. + hidden_size: int. The dimensionality of the output embeddings. + sscp_conv_channel_size: list of int. The number of output channels for + each of the two convolutional blocks. + sscp_conv_kernel_size: list of tuple of int. The kernel sizes for each + of the two convolutional blocks. + sscp_conv_stride_size: list of tuple of int. The stride sizes for each + of the two convolutional blocks. + sscp_conv_group_norm_eps: float. Epsilon value for the Group + Normalization layers within the convolutional blocks. 
+ """ + + def __init__( + self, + input_feat_size, + hidden_size, + sscp_conv_channel_size, + sscp_conv_kernel_size, + sscp_conv_stride_size, + sscp_conv_group_norm_eps, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.input_feat_size = input_feat_size + self.sscp_conv_channel_size = sscp_conv_channel_size + self.sscp_conv_kernel_size = sscp_conv_kernel_size + self.sscp_conv_stride_size = sscp_conv_stride_size + self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps + current_f_for_block_input = input_feat_size + self.calculated_block_padding = [] + self.calculated_f_out_dims = [] + for i in range(2): + kernel_h, kernel_w = sscp_conv_kernel_size[i] + _, stride_w = sscp_conv_stride_size[i] + pad_t_top, pad_t_bottom, pad_f_left, pad_f_right = ( + 0, + kernel_h - 1, + 1, + 1, + ) + manual_padding_tuple = ( + pad_f_left, + pad_f_right, + pad_t_top, + pad_t_bottom, + ) + self.calculated_block_padding.append(manual_padding_tuple) + f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right + f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 + self.calculated_f_out_dims.append(f_out_after_conv) + current_f_for_block_input = f_out_after_conv + self.conv_0 = Gemma3nAudioSSCPConvBlock( + idx=0, + input_freq_dim=input_feat_size, + sscp_conv_channel_size=sscp_conv_channel_size, + sscp_conv_kernel_size=sscp_conv_kernel_size, + sscp_conv_stride_size=sscp_conv_stride_size, + sscp_conv_group_norm_eps=sscp_conv_group_norm_eps, + manual_padding=self.calculated_block_padding[0], + name="conv_0", + dtype=self.dtype_policy, + ) + self.conv_1 = Gemma3nAudioSSCPConvBlock( + idx=1, + name="conv_1", + input_freq_dim=self.calculated_f_out_dims[0], + sscp_conv_channel_size=sscp_conv_channel_size, + sscp_conv_kernel_size=sscp_conv_kernel_size, + sscp_conv_stride_size=sscp_conv_stride_size, + sscp_conv_group_norm_eps=sscp_conv_group_norm_eps, + manual_padding=self.calculated_block_padding[1], + dtype=self.dtype_policy, + ) + self.input_proj_linear = keras.layers.Dense( + hidden_size, + use_bias=False, + name="input_proj_linear", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + _, t_in, f_in = input_shape + conv0_input_shape = (None, 1, t_in, f_in) + self.conv_0.build(conv0_input_shape) + if t_in is not None: + pad_t_top_0, pad_t_bottom_0 = self.calculated_block_padding[0][2:4] + kernel_h_0, _ = self.sscp_conv_kernel_size[0] + stride_h_0, _ = self.sscp_conv_stride_size[0] + t_padded_0 = t_in + pad_t_top_0 + pad_t_bottom_0 + t_out_0 = (t_padded_0 - kernel_h_0) // stride_h_0 + 1 + else: + t_out_0 = None + c_out_0 = self.sscp_conv_channel_size[0] + f_out_0 = self.calculated_f_out_dims[0] + conv1_input_shape = (None, c_out_0, t_out_0, f_out_0) + self.conv_1.build(conv1_input_shape) + if t_out_0 is not None: + t_padded_1 = ( + t_out_0 + + self.calculated_block_padding[1][2] + + self.calculated_block_padding[1][3] + ) + kernel_h_1, _ = self.sscp_conv_kernel_size[1] + stride_h_1, _ = self.sscp_conv_stride_size[1] + t_out_1 = (t_padded_1 - kernel_h_1) // stride_h_1 + 1 + else: + t_out_1 = None + c_out_1 = self.sscp_conv_channel_size[1] + f_out_1 = self.calculated_f_out_dims[1] + proj_input_shape = (None, t_out_1, f_out_1 * c_out_1) + self.input_proj_linear.build(proj_input_shape) + super().build(input_shape) + + def compute_output_shape(self, input_shape): + b, t_in, f_in = input_shape + if t_in is not None: + _, _, pad_t_top_0, pad_t_bottom_0 = self.calculated_block_padding[0] + kernel_h_0, _ = self.sscp_conv_kernel_size[0] + 
stride_h_0, _ = self.sscp_conv_stride_size[0] + t_padded_0 = t_in + pad_t_top_0 + pad_t_bottom_0 + t_out_0 = (t_padded_0 - kernel_h_0) // stride_h_0 + 1 + _, _, pad_t_top_1, pad_t_bottom_1 = self.calculated_block_padding[1] + kernel_h_1, _ = self.sscp_conv_kernel_size[1] + stride_h_1, _ = self.sscp_conv_stride_size[1] + t_padded_1 = t_out_0 + pad_t_top_1 + pad_t_bottom_1 + t_out_1 = (t_padded_1 - kernel_h_1) // stride_h_1 + 1 + else: + t_out_1 = None + return (b, t_out_1, self.hidden_size) + + def call(self, audio_encodings): + audio_encodings_reshaped = keras.ops.expand_dims(audio_encodings, 1) + x = self.conv_0(audio_encodings_reshaped) + x = self.conv_1(x) + b, c_out, t_out, f_out = keras.ops.shape(x) + x_permuted = keras.ops.transpose(x, (0, 2, 3, 1)) + output_flattened = keras.ops.reshape( + x_permuted, (b, t_out, f_out * c_out) + ) + return self.input_proj_linear(output_flattened) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_feat_size": self.input_feat_size, + "hidden_size": self.hidden_size, + "sscp_conv_channel_size": self.sscp_conv_channel_size, + "sscp_conv_kernel_size": self.sscp_conv_kernel_size, + "sscp_conv_stride_size": self.sscp_conv_stride_size, + "sscp_conv_group_norm_eps": self.sscp_conv_group_norm_eps, + } + ) + return config + + +class Gemma3nAudioConformerBlock(keras.layers.Layer): + """A single conformer block for processing audio sequences. + + This layer implements the conformer architecture, which consists of a + sequence of four modules: a feed-forward module, a multi-head + self-attention module, a convolution module, and a final feed-forward + module. The output of each module is added to its input through a residual + connection. + + Args: + hidden_size: int. The dimensionality of the input and output embeddings. + rms_norm_eps: float. Epsilon value for the Gemma 3n RMS normalization + layers. + gradient_clipping: float. The maximum absolute value for the gradient. + conf_residual_weight: float. The weight for the residual connection in + the feed-forward layers. + conf_num_attention_heads: int. The number of attention heads. + conf_attention_chunk_size: int. The size of chunks for local attention. + conf_attention_context_right: int. The right context size for local + attention. + conf_attention_context_left: int. The left context size for local + attention. + conf_attention_logit_cap: float. The maximum value for the attention + logits. + conf_conv_kernel_size: int. The kernel size for the 1D convolution + layer. 
+ """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + gradient_clipping, + conf_residual_weight, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + conf_conv_kernel_size, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.gradient_clipping = gradient_clipping + self.conf_residual_weight = conf_residual_weight + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_logit_cap = conf_attention_logit_cap + self.conf_conv_kernel_size = conf_conv_kernel_size + self.ffw_layer_start = Gemma3nAudioConformerFeedForward( + hidden_size=hidden_size, + gradient_clipping=gradient_clipping, + conf_residual_weight=conf_residual_weight, + rms_norm_eps=rms_norm_eps, + dtype=self.dtype_policy, + name="ffw_layer_start", + ) + self.attention = Gemma3nAudioConformerAttention( + hidden_size=hidden_size, + gradient_clipping=gradient_clipping, + conf_num_attention_heads=conf_num_attention_heads, + conf_attention_chunk_size=conf_attention_chunk_size, + conf_attention_context_right=conf_attention_context_right, + conf_attention_context_left=conf_attention_context_left, + conf_attention_logit_cap=conf_attention_logit_cap, + dtype=self.dtype_policy, + name="attention", + ) + self.lconv1d = Gemma3nAudioConformerLightConv1d( + hidden_size=hidden_size, + rms_norm_eps=rms_norm_eps, + conf_conv_kernel_size=conf_conv_kernel_size, + gradient_clipping=gradient_clipping, + dtype=self.dtype_policy, + name="lconv1d", + ) + self.ffw_layer_end = Gemma3nAudioConformerFeedForward( + hidden_size=hidden_size, + gradient_clipping=gradient_clipping, + conf_residual_weight=conf_residual_weight, + rms_norm_eps=rms_norm_eps, + dtype=self.dtype_policy, + name="ffw_layer_end", + ) + self.norm = Gemma3nRMSNorm( + hidden_size, eps=rms_norm_eps, name="norm", dtype=self.dtype_policy + ) + + def build(self, input_shape): + if ( + isinstance(input_shape, tuple) + and len(input_shape) == 2 + and isinstance(input_shape[0], tuple) + ): + audio_encodings_shape, _ = input_shape + elif isinstance(input_shape, tuple) and len(input_shape) >= 3: + audio_encodings_shape = input_shape + else: + raise ValueError( + f"Unexpected `input_shape` structure for " + f"Gemma3nAudioConformerBlock: {input_shape}" + ) + self.ffw_layer_start.build(audio_encodings_shape) + self.attention.build(audio_encodings_shape) + self.lconv1d.build(audio_encodings_shape) + self.ffw_layer_end.build(audio_encodings_shape) + self.norm.build(audio_encodings_shape) + super().build(input_shape) + + def compute_output_shape(self, input_shape): + audio_encodings_shape, _ = input_shape + return audio_encodings_shape + + def call(self, inputs): + audio_encodings, audio_mel_mask = inputs + audio_encodings = self.ffw_layer_start(audio_encodings) + audio_encodings = self.attention(audio_encodings, audio_mel_mask) + validity_mask_for_lconv = keras.ops.logical_not(audio_mel_mask) + mask_shape = keras.ops.shape(validity_mask_for_lconv) + enc_shape = keras.ops.shape(audio_encodings) + if len(mask_shape) < len(enc_shape): + validity_mask_for_lconv = keras.ops.expand_dims( + validity_mask_for_lconv, -1 + ) + audio_encodings_for_lconv_input = audio_encodings * keras.ops.cast( + 
validity_mask_for_lconv, + audio_encodings.dtype, + ) + audio_encodings = self.lconv1d(audio_encodings_for_lconv_input) + audio_encodings = self.ffw_layer_end(audio_encodings) + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + output = self.norm(audio_encodings) + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "gradient_clipping": self.gradient_clipping, + "conf_residual_weight": self.conf_residual_weight, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_chunk_size": self.conf_attention_chunk_size, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_logit_cap": self.conf_attention_logit_cap, + "conf_conv_kernel_size": self.conf_conv_kernel_size, + } + ) + return config + + +class Gemma3nAudioEncoder(keras.layers.Layer): + """The main audio encoder for the Gemma3n model. + + This layer combines a subsampling convolutional projection with a stack of + conformer blocks to encode audio spectrograms into a sequence of hidden + states. + + Args: + hidden_size: int. The dimensionality of the embeddings. + input_feat_size: int. The number of frequency bins in the input + spectrogram. + sscp_conv_channel_size: list of int. The number of output channels for + each of the two convolutional blocks in the subsampler. + sscp_conv_kernel_size: list of tuple of int. The kernel sizes for each + of the two convolutional blocks in the subsampler. + sscp_conv_stride_size: list of tuple of int. The stride sizes for each + of the two convolutional blocks in the subsampler. + sscp_conv_group_norm_eps: float. Epsilon value for the Group + Normalization layers in the subsampler. + conf_num_hidden_layers: int. The number of conformer blocks. + rms_norm_eps: float. Epsilon value for the Gemma 3n RMS normalization + layers. + gradient_clipping: float. The maximum absolute value for the gradient. + conf_residual_weight: float. The weight for the residual connection in + the feed-forward layers of the conformer blocks. + conf_num_attention_heads: int. The number of attention heads in the + conformer blocks. + conf_attention_chunk_size: int. The size of chunks for local attention + in the conformer blocks. + conf_attention_context_right: int. The right context size for local + attention in the conformer blocks. + conf_attention_context_left: int. The left context size for local + attention in the conformer blocks. + conf_attention_logit_cap: float. The maximum value for the attention + logits in the conformer blocks. + conf_conv_kernel_size: int. The kernel size for the 1D convolution + layer in the conformer blocks. + conf_reduction_factor: int. The factor by which to reduce the sequence + length of the final output. 
+ """ + + def __init__( + self, + hidden_size, + input_feat_size, + sscp_conv_channel_size, + sscp_conv_kernel_size, + sscp_conv_stride_size, + sscp_conv_group_norm_eps, + conf_num_hidden_layers, + rms_norm_eps, + gradient_clipping, + conf_residual_weight, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + conf_conv_kernel_size, + conf_reduction_factor, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.input_feat_size = input_feat_size + self.sscp_conv_channel_size = sscp_conv_channel_size + self.sscp_conv_kernel_size = sscp_conv_kernel_size + self.sscp_conv_stride_size = sscp_conv_stride_size + self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps + self.conf_num_hidden_layers = conf_num_hidden_layers + self.rms_norm_eps = rms_norm_eps + self.gradient_clipping = gradient_clipping + self.conf_residual_weight = conf_residual_weight + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_logit_cap = conf_attention_logit_cap + self.conf_conv_kernel_size = conf_conv_kernel_size + self.conf_reduction_factor = conf_reduction_factor + self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection( + input_feat_size, + hidden_size, + sscp_conv_channel_size, + sscp_conv_kernel_size, + sscp_conv_stride_size, + sscp_conv_group_norm_eps, + dtype=self.dtype_policy, + name="subsample_conv_projection", + ) + self.conformer = [ + Gemma3nAudioConformerBlock( + hidden_size, + rms_norm_eps, + gradient_clipping, + conf_residual_weight, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + conf_conv_kernel_size, + dtype=self.dtype_policy, + name=f"conformer_block_{i}", + ) + for i in range(conf_num_hidden_layers) + ] + + def build(self, input_shape): + if ( + isinstance(input_shape, tuple) + and len(input_shape) == 2 + and isinstance(input_shape[0], tuple) + ): + audio_mel_shape, _ = input_shape + else: + raise ValueError( + f"Unexpected `input_shape` structure for Gemma3nAudioEncoder: " + f"{input_shape}" + ) + self.subsample_conv_projection.build(audio_mel_shape) + encodings_shape = self.subsample_conv_projection.compute_output_shape( + audio_mel_shape + ) + t_sub = encodings_shape[1] + time_stride_product = 1 + for stride_pair in self.sscp_conv_stride_size: + time_stride_product *= stride_pair[0] + batch_size = audio_mel_shape[0] + current_mask_shape = ( + (batch_size, t_sub) if t_sub is not None else (batch_size, None) + ) + current_encodings_shape = encodings_shape + for block in self.conformer: + block.build((current_encodings_shape, current_mask_shape)) + current_encodings_shape = block.compute_output_shape( + (current_encodings_shape, current_mask_shape) + ) + super().build(input_shape) + + def compute_output_shape(self, input_shape): + audio_mel_shape, _ = input_shape + encodings_shape = self.subsample_conv_projection.compute_output_shape( + audio_mel_shape + ) + t_sub = encodings_shape[1] + time_stride_product = 1 + for stride_pair in self.sscp_conv_stride_size: + time_stride_product *= stride_pair[0] + batch_size = audio_mel_shape[0] + current_mask_shape = ( + (batch_size, t_sub) if t_sub is not None else (batch_size, None) 
+ ) + current_encodings_shape = encodings_shape + for block in self.conformer: + current_encodings_shape = block.compute_output_shape( + (current_encodings_shape, current_mask_shape) + ) + final_mask_shape = current_mask_shape + if self.conf_reduction_factor > 1: + t_sub = current_encodings_shape[1] + if t_sub is not None: + new_t = t_sub // self.conf_reduction_factor + current_encodings_shape = ( + current_encodings_shape[0], + new_t, + current_encodings_shape[2], + ) + final_mask_shape = ( + (current_mask_shape[0], new_t) + if current_mask_shape[1] is not None + else (current_mask_shape[0], None) + ) + return current_encodings_shape, final_mask_shape + + def call(self, inputs): + audio_mel, audio_mel_mask = inputs + audio_encodings = self.subsample_conv_projection(audio_mel) + t_sub = keras.ops.shape(audio_encodings)[1] + time_stride_product = 1 + for stride_pair in self.sscp_conv_stride_size: + time_stride_product *= stride_pair[0] + mask_rank = len(keras.ops.shape(audio_mel_mask)) + audio_mel_mask_to_take = audio_mel_mask + if mask_rank > 2: + audio_mel_mask_to_take = keras.ops.squeeze( + audio_mel_mask, axis=list(range(1, mask_rank - 1)) + ) + indices = keras.ops.arange(0, t_sub) * time_stride_product + indices = keras.ops.clip( + indices, 0, keras.ops.shape(audio_mel_mask_to_take)[1] - 1 + ) + current_mask = keras.ops.take(audio_mel_mask_to_take, indices, axis=1) + for block in self.conformer: + audio_encodings = block((audio_encodings, current_mask)) + + if self.conf_reduction_factor > 1: + audio_encodings = audio_encodings[:, :: self.conf_reduction_factor] + current_mask = current_mask[:, :: self.conf_reduction_factor] + mask_shape = keras.ops.shape(current_mask) + enc_shape = keras.ops.shape(audio_encodings) + if len(mask_shape) < len(enc_shape): + current_mask_expanded = keras.ops.expand_dims(current_mask, axis=-1) + else: + current_mask_expanded = current_mask + return audio_encodings * keras.ops.cast( + keras.ops.logical_not(current_mask_expanded), + audio_encodings.dtype, + ), current_mask + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "input_feat_size": self.input_feat_size, + "sscp_conv_channel_size": self.sscp_conv_channel_size, + "sscp_conv_kernel_size": self.sscp_conv_kernel_size, + "sscp_conv_stride_size": self.sscp_conv_stride_size, + "sscp_conv_group_norm_eps": self.sscp_conv_group_norm_eps, + "conf_num_hidden_layers": self.conf_num_hidden_layers, + "rms_norm_eps": self.rms_norm_eps, + "gradient_clipping": self.gradient_clipping, + "conf_residual_weight": self.conf_residual_weight, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_chunk_size": self.conf_attention_chunk_size, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_logit_cap": self.conf_attention_logit_cap, + "conf_conv_kernel_size": self.conf_conv_kernel_size, + "conf_reduction_factor": self.conf_reduction_factor, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_audio_layers.py b/keras_hub/src/models/gemma3n/gemma3n_audio_layers.py new file mode 100644 index 0000000000..11d15813b9 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_audio_layers.py @@ -0,0 +1,526 @@ +import keras + +from keras_hub.src.models.gemma3n.gemma3n_attention import Gemma3nAudioAttention +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class 
Gemma3nAudioCumulativeGroupNorm(keras.layers.Layer): + """A cumulative group normalization layer for audio features. + + This layer normalizes the input hidden states based on cumulative statistics + calculated over the time dimension. It is designed to process audio + spectrograms or similar sequential data. + + Args: + num_channels: int. The number of channels for normalization. + feature_dims: tuple. The dimensions of the features to be normalized. + eps: float. A small epsilon value to add to the variance to avoid + division by zero. + """ + + def __init__( + self, + num_channels, + feature_dims, + eps=1e-3, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_channels = num_channels + self.feature_dims = tuple(feature_dims) + self.eps = eps + self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1)) + + def build(self, input_shape): + self.scale = self.add_weight( + shape=(self.num_channels,), + initializer="ones", + trainable=True, + name="scale", + dtype=self.dtype_policy.variable_dtype, + ) + super().build(input_shape) + + def _int8_call(self, hidden_states): + original_dtype = hidden_states.dtype + x_calc = keras.ops.cast(hidden_states, "float32") + result_calc = self.call(x_calc) + return keras.ops.cast(result_calc, original_dtype) + + def call(self, hidden_states): + input_dtype = hidden_states.dtype + x_calc = keras.ops.cast(hidden_states, "float32") + mask_calc = keras.ops.ones_like(x_calc, dtype="float32") + sum_values_at_t = keras.ops.sum( + x_calc, axis=self.reduction_axes, keepdims=True + ) + cum_sum_values = keras.ops.cumsum(sum_values_at_t, axis=1) + elements_in_group_at_t = keras.ops.sum( + mask_calc, axis=self.reduction_axes, keepdims=True + ) + cum_count_elements = keras.ops.cumsum(elements_in_group_at_t, axis=1) + safe_cum_count_elements = keras.ops.maximum(cum_count_elements, 1.0) + cum_mean = cum_sum_values / safe_cum_count_elements + squared_diff_from_mean = keras.ops.square(x_calc - cum_mean) + sum_sq_diff_at_t = keras.ops.sum( + squared_diff_from_mean, axis=self.reduction_axes, keepdims=True + ) + cum_sum_sq_diff = keras.ops.cumsum(sum_sq_diff_at_t, axis=1) + cum_variance = cum_sum_sq_diff / safe_cum_count_elements + normalized_x = (x_calc - cum_mean) * keras.ops.rsqrt( + cum_variance + self.eps + ) + scale_view_shape = [1] * (len(hidden_states.shape) - 1) + [ + self.num_channels + ] + reshaped_scale = keras.ops.reshape(self.scale, scale_view_shape) + normalized_x = normalized_x * keras.ops.cast(reshaped_scale, "float32") + final_output = normalized_x * mask_calc + return keras.ops.cast(final_output, input_dtype) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_channels": self.num_channels, + "feature_dims": self.feature_dims, + "eps": self.eps, + } + ) + return config + + +class Gemma3nAudioSSCPConvBlock(keras.layers.Layer): + """A single SSCP (Spectrogram Sub-sampling Convolutional Preprocessor) + block. + + This block consists of a 2D convolution, a cumulative group normalization + layer, and a ReLU activation. It is used to process and downsample audio + spectrograms. + + Args: + idx: int. The index of the convolutional block. + input_freq_dim: int. The frequency dimension of the input spectrogram. + sscp_conv_channel_size: list or tuple. A sequence containing the number + of output channels for each convolutional block in the SSCP stack. + sscp_conv_kernel_size: list or tuple. A sequence of kernel sizes for + each convolutional block. + sscp_conv_stride_size: list or tuple. 
A sequence of stride sizes for + each convolutional block. + sscp_conv_group_norm_eps: float. The epsilon value for the cumulative + group normalization layer. + manual_padding: tuple. A tuple of 4 integers specifying the manual + padding to be applied as (pad_w_left, pad_w_right, pad_h_top, + pad_h_bottom). + """ + + def __init__( + self, + idx, + input_freq_dim, + sscp_conv_channel_size, + sscp_conv_kernel_size, + sscp_conv_stride_size, + sscp_conv_group_norm_eps, + manual_padding=(0, 0, 0, 0), + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.idx = idx + self.input_freq_dim = input_freq_dim + self.sscp_conv_channel_size = sscp_conv_channel_size + self.sscp_conv_kernel_size = sscp_conv_kernel_size + self.sscp_conv_stride_size = sscp_conv_stride_size + self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps + self.manual_padding = manual_padding + out_channels = sscp_conv_channel_size[idx] + kernel_h, kernel_w = sscp_conv_kernel_size[idx] + stride_h, stride_w = sscp_conv_stride_size[idx] + self.conv = keras.layers.Conv2D( + filters=out_channels, + kernel_size=(kernel_h, kernel_w), + strides=(stride_h, stride_w), + padding="valid", + use_bias=False, + data_format="channels_last", + name="conv", + dtype=self.dtype_policy, + ) + f_in_padded = ( + input_freq_dim + self.manual_padding[0] + self.manual_padding[1] + ) + f_out_conv = (f_in_padded - kernel_w) // stride_w + 1 + self.norm = Gemma3nAudioCumulativeGroupNorm( + num_channels=out_channels, + feature_dims=(f_out_conv,), + eps=sscp_conv_group_norm_eps, + name="norm", + dtype=self.dtype_policy, + ) + self.activation = keras.layers.ReLU( + name="activation", dtype=self.dtype_policy + ) + + def build(self, input_shape): + _, c_in, h, w = input_shape + if h is not None: + padded_h = h + self.manual_padding[2] + self.manual_padding[3] + else: + padded_h = None + padded_w = w + self.manual_padding[0] + self.manual_padding[1] + conv_input_shape = (None, padded_h, padded_w, c_in) + if not self.conv.built: + self.conv.build(conv_input_shape) + if h is not None: + h_out = (padded_h - self.conv.kernel_size[0]) // self.conv.strides[ + 0 + ] + 1 + else: + h_out = None + w_out = (padded_w - self.conv.kernel_size[1]) // self.conv.strides[ + 1 + ] + 1 + norm_input_shape = (None, h_out, w_out, self.conv.filters) + if not self.norm.built: + self.norm.build(norm_input_shape) + super().build(input_shape) + + def call(self, audio_encodings): + audio_encodings_nhwc = keras.ops.transpose( + audio_encodings, (0, 2, 3, 1) + ) + keras_padding = [ + [0, 0], + [self.manual_padding[2], self.manual_padding[3]], + [self.manual_padding[0], self.manual_padding[1]], + [0, 0], + ] + audio_encodings_padded = keras.ops.pad( + audio_encodings_nhwc, + keras_padding, + mode="constant", + constant_values=0.0, + ) + audio_encodings_conv = self.conv(audio_encodings_padded) + x_normed = self.norm(audio_encodings_conv) + audio_encodings_normed = keras.ops.transpose(x_normed, (0, 3, 1, 2)) + return self.activation(audio_encodings_normed) + + def get_config(self): + config = super().get_config() + config.update( + { + "idx": self.idx, + "input_freq_dim": self.input_freq_dim, + "sscp_conv_channel_size": self.sscp_conv_channel_size, + "sscp_conv_kernel_size": self.sscp_conv_kernel_size, + "sscp_conv_stride_size": self.sscp_conv_stride_size, + "sscp_conv_group_norm_eps": self.sscp_conv_group_norm_eps, + "manual_padding": self.manual_padding, + } + ) + return config + + +class Gemma3nAudioConformerFeedForward(keras.layers.Layer): + """The feed-forward module 
for the Conformer block. + + This module implements the feed-forward sub-layer of a Conformer block, + which consists of pre-layer normalization, two dense layers with a SiLU + activation function in between, post-layer normalization, and a residual + connection. + + Args: + hidden_size: int. The hidden size of the input and output tensors. + gradient_clipping: float. The maximum absolute value for gradient + clipping. + conf_residual_weight: float. The weight applied to the output of the + sub-layer before adding the residual connection. + rms_norm_eps: float. The epsilon value for the RMS normalization layers. + """ + + def __init__( + self, + hidden_size, + gradient_clipping, + conf_residual_weight, + rms_norm_eps, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.gradient_clipping = gradient_clipping + self.conf_residual_weight = conf_residual_weight + self.rms_norm_eps = rms_norm_eps + self.pre_layer_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="pre_layer_norm", + dtype=self.dtype_policy, + ) + self.ffw_layer_1 = keras.layers.Dense( + hidden_size * 4, + use_bias=False, + name="ffw_layer_1", + dtype=self.dtype_policy, + ) + self.ffw_layer_2 = keras.layers.Dense( + hidden_size, + use_bias=False, + name="ffw_layer_2", + dtype=self.dtype_policy, + ) + self.post_layer_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_layer_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.pre_layer_norm.build(input_shape) + self.ffw_layer_1.build(input_shape) + ffw1_output_shape = input_shape[:-1] + (self.hidden_size * 4,) + self.ffw_layer_2.build(ffw1_output_shape) + self.post_layer_norm.build(input_shape) + super().build(input_shape) + + def call(self, audio_encodings): + residual = audio_encodings + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings = self.ffw_layer_1(audio_encodings) + audio_encodings = keras.activations.silu(audio_encodings) + audio_encodings = self.ffw_layer_2(audio_encodings) + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.post_layer_norm(audio_encodings) + return residual + (audio_encodings * self.conf_residual_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "gradient_clipping": self.gradient_clipping, + "conf_residual_weight": self.conf_residual_weight, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config + + +class Gemma3nAudioConformerLightConv1d(keras.layers.Layer): + """The lightweight 1D convolution module for the Conformer block. + + This module implements the convolution sub-layer of a Conformer block, + which consists of pre-layer normalization, a gated linear unit (GLU), a + lightweight depthwise 1D convolution, and a final projection, followed by a + residual connection. + + Args: + hidden_size: int. The hidden size of the input and output tensors. + rms_norm_eps: float. The epsilon value for the RMS normalization layers. + conf_conv_kernel_size: int. The kernel size for the depthwise 1D + convolution. + gradient_clipping: float. The maximum absolute value for gradient + clipping. 
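+
+    Example:
+
+    An illustrative sketch with toy sizes (the hidden size, kernel size, and
+    sequence length below are arbitrary). Because of the causal padding, the
+    output keeps the input's sequence length.
+    ```python
+    import numpy as np
+
+    from keras_hub.src.models.gemma3n.gemma3n_audio_layers import (
+        Gemma3nAudioConformerLightConv1d,
+    )
+
+    lconv = Gemma3nAudioConformerLightConv1d(
+        hidden_size=8,
+        rms_norm_eps=1e-6,
+        conf_conv_kernel_size=5,
+        gradient_clipping=1.0,
+    )
+    audio_encodings = np.random.rand(1, 16, 8).astype("float32")
+    outputs = lconv(audio_encodings)  # Shape: (1, 16, 8).
+    ```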
+ """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + conf_conv_kernel_size, + gradient_clipping, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.conf_conv_kernel_size = conf_conv_kernel_size + self.gradient_clipping = gradient_clipping + self.pre_layer_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="pre_layer_norm", + dtype=self.dtype_policy, + ) + self.linear_start = keras.layers.Dense( + hidden_size * 2, + use_bias=False, + name="linear_start", + dtype=self.dtype_policy, + ) + self.depthwise_conv1d = keras.layers.DepthwiseConv1D( + kernel_size=conf_conv_kernel_size, + strides=1, + padding="valid", + use_bias=False, + data_format="channels_last", + name="depthwise_conv1d", + dtype=self.dtype_policy, + ) + self.conv_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="conv_norm", + dtype=self.dtype_policy, + ) + self.linear_end = keras.layers.Dense( + hidden_size, + use_bias=False, + name="linear_end", + dtype=self.dtype_policy, + ) + self.causal_padding = conf_conv_kernel_size - 1 + + def build(self, input_shape): + self.pre_layer_norm.build(input_shape) + self.linear_start.build(input_shape) + glu_output_shape = input_shape[:-1] + (self.hidden_size,) + self.depthwise_conv1d.build(glu_output_shape) + self.conv_norm.build(glu_output_shape) + self.linear_end.build(glu_output_shape) + super().build(input_shape) + + def call(self, audio_encodings): + residual = audio_encodings + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings = self.linear_start(audio_encodings) + gated, activated = keras.ops.split(audio_encodings, 2, axis=-1) + audio_encodings = gated * keras.activations.sigmoid(activated) + + padded = keras.ops.pad( + audio_encodings, + [[0, 0], [self.causal_padding, 0], [0, 0]], + ) + audio_encodings = self.depthwise_conv1d(padded) + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.conv_norm(audio_encodings) + audio_encodings = keras.activations.silu(audio_encodings) + audio_encodings = self.linear_end(audio_encodings) + return audio_encodings + residual + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "conf_conv_kernel_size": self.conf_conv_kernel_size, + "gradient_clipping": self.gradient_clipping, + } + ) + return config + + +class Gemma3nAudioConformerAttention(keras.layers.Layer): + """The attention module for the Conformer block. + + This module implements the multi-head self-attention sub-layer of a + Conformer block. It wraps the core attention mechanism with pre and post + layer normalization, a final dense projection, and a residual connection. + + Args: + hidden_size: int. The hidden size of the input and output tensors. + gradient_clipping: float. The maximum absolute value for gradient + clipping. + conf_num_attention_heads: int. The number of attention heads. + conf_attention_chunk_size: int. The chunk size for attention + computation, used for memory efficiency. + conf_attention_context_right: int. The right context size for attention. + conf_attention_context_left: int. The left context size for attention. + conf_attention_logit_cap: float. The value to which attention logits + are capped. 
+ """ + + def __init__( + self, + hidden_size, + gradient_clipping, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.gradient_clipping = gradient_clipping + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_logit_cap = conf_attention_logit_cap + self.pre_attn_norm = Gemma3nRMSNorm( + hidden_size, name="pre_attn_norm", dtype=self.dtype_policy + ) + self.attn = Gemma3nAudioAttention( + hidden_size, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + dtype=self.dtype_policy, + name="attn", + ) + self.post = keras.layers.Dense( + hidden_size, use_bias=False, name="post", dtype=self.dtype_policy + ) + self.post_norm = Gemma3nRMSNorm( + hidden_size, name="post_norm", dtype=self.dtype_policy + ) + + def build(self, input_shape): + self.pre_attn_norm.build(input_shape) + self.attn.build(input_shape) + self.post.build(input_shape) + self.post_norm.build(input_shape) + super().build(input_shape) + + def call(self, audio_encodings, audio_mel_mask): + residual = audio_encodings + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings_norm = self.pre_attn_norm(audio_encodings) + audio_encodings_attn_out = self.attn( + audio_encodings_norm, audio_mel_mask + ) + b, t, num_heads, head_dim = keras.ops.shape(audio_encodings_attn_out) + audio_encodings_reshaped = keras.ops.reshape( + audio_encodings_attn_out, (b, t, num_heads * head_dim) + ) + audio_encodings = self.post(audio_encodings_reshaped) + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + return residual + self.post_norm(audio_encodings) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "gradient_clipping": self.gradient_clipping, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_chunk_size": self.conf_attention_chunk_size, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_logit_cap": self.conf_attention_logit_cap, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_backbone.py b/keras_hub/src/models/gemma3n/gemma3n_backbone.py new file mode 100644 index 0000000000..1941a2b74f --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_backbone.py @@ -0,0 +1,918 @@ +import inspect + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import ( + Gemma3nAudioEncoder, +) +from keras_hub.src.models.gemma3n.gemma3n_text_model import Gemma3nTextModel +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, +) + + +class Gemma3nMultimodalEmbedder(keras.layers.Layer): + """A layer for handling multimodal embeddings. 
+ + This layer manages embeddings for different modalities (here, vision, text, + and audio). It can take either token IDs or pre-computed embedding vectors + as input. The embeddings are normalized and projected to match the text + model's hidden size. + + Args: + multimodal_hidden_size: int. The hidden size of the multimodal + embeddings. + text_hidden_size: int. The hidden size of the text model. + rms_norm_eps: float. The epsilon value for the Gemma 3n RMS + normalization layers. + vocab_offset: int. The vocabulary offset for the specific modality. + vocab_size: int. The vocabulary size for the specific modality. + """ + + def __init__( + self, + multimodal_hidden_size, + text_hidden_size, + rms_norm_eps, + vocab_offset, + vocab_size, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.multimodal_hidden_size = multimodal_hidden_size + self.text_hidden_size = text_hidden_size + self.rms_norm_eps = rms_norm_eps + self.vocab_offset = vocab_offset + self.vocab_size = vocab_size + self.embedding = keras.layers.Embedding( + vocab_size, + multimodal_hidden_size, + name="embedding", + dtype=self.dtype_policy, + ) + self.hard_embedding_norm = Gemma3nRMSNorm( + multimodal_hidden_size, + eps=rms_norm_eps, + name="hard_embedding_norm", + dtype=self.dtype_policy, + ) + self.soft_embedding_norm = Gemma3nRMSNorm( + multimodal_hidden_size, + eps=rms_norm_eps, + name="soft_embedding_norm", + dtype=self.dtype_policy, + ) + self.embedding_projection = keras.layers.Dense( + text_hidden_size, + use_bias=False, + name="embedding_projection", + dtype=self.dtype_policy, + ) + self.embedding_post_projection_norm = Gemma3nRMSNorm( + text_hidden_size, + eps=rms_norm_eps, + with_scale=False, + name="embedding_post_projection_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + embeds_shape = (None, None, self.multimodal_hidden_size) + self.hard_embedding_norm.build(embeds_shape) + self.soft_embedding_norm.build(embeds_shape) + self.embedding_projection.build(embeds_shape) + proj_shape = (None, None, self.text_hidden_size) + self.embedding_post_projection_norm.build(proj_shape) + self.embedding.build((None, None)) + super().build(input_shape) + + def call(self, inputs): + input_ids, inputs_embeds = None, None + if isinstance(inputs, list): + input_ids, inputs_embeds = inputs + elif "int" in str(inputs.dtype): + input_ids = inputs + else: + inputs_embeds = inputs + if (input_ids is None) and (inputs_embeds is None): + raise ValueError( + "You must specify either input_ids or inputs_embeds" + ) + if (input_ids is not None) and (inputs_embeds is not None): + raise ValueError( + "You can only specify one of input_ids or inputs_embeds" + ) + if inputs_embeds is not None: + emb_norm = self.soft_embedding_norm(inputs_embeds) + else: + index_to_lookup = input_ids - self.vocab_offset + hard_emb = self.embedding(index_to_lookup) + emb_norm = self.hard_embedding_norm(hard_emb) + + emb_norm_proj = self.embedding_projection(emb_norm) + return self.embedding_post_projection_norm(emb_norm_proj) + + def get_config(self): + config = super().get_config() + config.update( + { + "multimodal_hidden_size": self.multimodal_hidden_size, + "text_hidden_size": self.text_hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "vocab_offset": self.vocab_offset, + "vocab_size": self.vocab_size, + } + ) + return config + + +class Gemma3nMultimodalEmbeddingProcessor(keras.layers.Layer): + """Processes and interleaves text, vision, and audio embeddings. 
+ + This layer takes raw token IDs and multimodal inputs (pixel values, audio + features) and produces a final sequence of embeddings ready for the + decoder. It handles the embedding lookup for text and special tokens, + and replaces the special tokens with the processed features from the + vision and audio encoders. + + Args: + language_model: `keras_hub.models.gemma3n.Gemma3nTextModel`. The + underlying text model containing embedding layers. + vision_encoder: `keras.Model`. The vision encoder model. + embed_vision: `keras_hub.models.gemma3n.Gemma3nMultimodalEmbedder`. The + embedder for vision. + audio_encoder: `keras_hub.models.gemma3n.Gemma3nAudioEncoder`. The audio + encoder model. + embed_audio: `keras_hub.models.gemma3n.Gemma3nMultimodalEmbedder`. The + embedder for audio. + vision_soft_tokens_per_image: int. Number of tokens to represent an + image. + audio_soft_tokens_per_image: int. Number of tokens to represent an + audio clip. + image_token_id: int. The special token ID for images. + audio_token_id: int. The special token ID for audio. + vocab_size_per_layer_input: int. The vocabulary size for per-layer + inputs. + """ + + def __init__( + self, + language_model, + vision_encoder, + embed_vision, + audio_encoder, + embed_audio, + vision_soft_tokens_per_image, + audio_soft_tokens_per_image, + image_token_id, + audio_token_id, + vocab_size_per_layer_input, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.language_model = language_model + self.vision_encoder = vision_encoder + self.embed_vision = embed_vision + self.audio_encoder = audio_encoder + self.embed_audio = embed_audio + self.vision_soft_tokens_per_image = vision_soft_tokens_per_image + self.audio_soft_tokens_per_image = audio_soft_tokens_per_image + self.image_token_id = image_token_id + self.audio_token_id = audio_token_id + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.text_hidden_size = language_model.embed_tokens.embedding_dim + + def build(self, input_shape): + super().build(input_shape) + + def compute_output_spec(self, inputs): + input_ids_spec = inputs["token_ids"] + batch_size = input_ids_spec.shape[0] + seq_len = input_ids_spec.shape[1] + inputs_embeds_spec = keras.KerasTensor( + shape=(batch_size, seq_len, self.text_hidden_size), + dtype=input_ids_spec.dtype + if hasattr(input_ids_spec.dtype, "name") + else "float32", + ) + num_layers = self.language_model.num_hidden_layers + per_layer_hidden_size = self.language_model.hidden_size_per_layer_input + per_layer_inputs_spec = keras.KerasTensor( + shape=(batch_size, seq_len, num_layers, per_layer_hidden_size), + dtype=input_ids_spec.dtype + if hasattr(input_ids_spec.dtype, "name") + else "float32", + ) + return inputs_embeds_spec, per_layer_inputs_spec + + def call(self, inputs): + input_ids = inputs["token_ids"] + pixel_values = inputs.get("pixel_values") + input_features = inputs.get("input_features") + input_features_mask = inputs.get("input_features_mask") + inputs_embeds = self.language_model.embed_tokens(input_ids) + per_layer_inputs_mask = keras.ops.logical_and( + input_ids >= 0, input_ids < self.vocab_size_per_layer_input + ) + per_layer_inputs_tokens = keras.ops.where( + per_layer_inputs_mask, input_ids, keras.ops.zeros_like(input_ids) + ) + per_layer_inputs = self.language_model.get_per_layer_inputs( + per_layer_inputs_tokens + ) + if self.vision_encoder and self.embed_vision: + if self.embed_audio: + vision_upper_bound = self.embed_audio.vocab_offset + else: + vision_upper_bound = ( + 
self.embed_vision.vocab_offset + + self.embed_vision.vocab_size + ) + vision_mask = keras.ops.logical_and( + input_ids >= self.embed_vision.vocab_offset, + input_ids < vision_upper_bound, + ) + dummy_vision_token_id = ( + self.embed_vision.vocab_offset + + self.embed_vision.embedding.input_dim + - 1 + ) + vision_input_ids = keras.ops.where( + vision_mask, input_ids, dummy_vision_token_id + ) + vision_embeds_from_vocab = self.embed_vision(vision_input_ids) + expanded_vision_mask = keras.ops.expand_dims(vision_mask, axis=-1) + inputs_embeds = keras.ops.where( + expanded_vision_mask, + vision_embeds_from_vocab, + inputs_embeds, + ) + if self.audio_encoder and self.embed_audio: + audio_mask = input_ids >= self.embed_audio.vocab_offset + dummy_audio_token_id = ( + self.embed_audio.vocab_offset + + self.embed_audio.embedding.input_dim + - 1 + ) + audio_input_ids = keras.ops.where( + audio_mask, input_ids, dummy_audio_token_id + ) + audio_embeds_from_vocab = self.embed_audio(audio_input_ids) + expanded_audio_mask = keras.ops.expand_dims(audio_mask, axis=-1) + inputs_embeds = keras.ops.where( + expanded_audio_mask, audio_embeds_from_vocab, inputs_embeds + ) + + if pixel_values is not None and self.vision_encoder: + reshape_target = (-1,) + tuple(self.vision_encoder.image_shape) + pixel_values = keras.ops.reshape(pixel_values, reshape_target) + vision_features = self.vision_encoder(pixel_values) + if self.vision_encoder.data_format == "channels_first": + vision_features = keras.ops.transpose( + vision_features, (0, 2, 3, 1) + ) + shape = keras.ops.shape(vision_features) + vision_features = keras.ops.reshape( + vision_features, (shape[0], shape[1] * shape[2], shape[3]) + ) + vision_features *= keras.ops.sqrt( + keras.ops.cast( + self.vision_encoder.num_features, dtype=inputs_embeds.dtype + ) + ) + vision_embeds = self.embed_vision(vision_features) + image_token_mask = keras.ops.equal(input_ids, self.image_token_id) + + def scatter_vision_features(): + batch_size, seq_len, hidden_size = keras.ops.shape( + inputs_embeds + ) + flat_vision_embeds = keras.ops.reshape( + vision_embeds, [-1, hidden_size] + ) + flat_full_mask = keras.ops.reshape(image_token_mask, [-1]) + gather_indices = ( + keras.ops.cumsum(keras.ops.cast(flat_full_mask, "int32")) + - 1 + ) + gather_indices = keras.ops.where( + flat_full_mask, gather_indices, 0 + ) + replacement_values = keras.ops.take( + flat_vision_embeds, gather_indices, axis=0 + ) + replacement_tensor = keras.ops.reshape( + replacement_values, (batch_size, seq_len, hidden_size) + ) + expanded_full_mask = keras.ops.expand_dims( + image_token_mask, axis=-1 + ) + return keras.ops.where( + expanded_full_mask, replacement_tensor, inputs_embeds + ) + + inputs_embeds = keras.ops.cond( + keras.ops.any(image_token_mask), + scatter_vision_features, + lambda: inputs_embeds, + ) + + if ( + input_features is not None + and input_features_mask is not None + and self.audio_encoder + ): + original_shape = keras.ops.shape(input_features) + b, n, t, f = ( + original_shape[0], + original_shape[1], + original_shape[2], + original_shape[3], + ) + input_features = keras.ops.reshape(input_features, (b * n, t, f)) + input_features_mask = keras.ops.reshape( + input_features_mask, (b * n, t) + ) + audio_features, _ = self.audio_encoder( + (input_features, input_features_mask) + ) + audio_embeds = self.embed_audio(audio_features) + audio_embeds_shape = keras.ops.shape(audio_embeds) + t_out, h = audio_embeds_shape[1], audio_embeds_shape[2] + audio_embeds = keras.ops.reshape(audio_embeds, (b, n, 
t_out, h)) + shape = keras.ops.shape(audio_embeds) + audio_batch_size, audio_num_clips, audio_seq_len, hidden_size = ( + shape[0], + shape[1], + shape[2], + shape[3], + ) + target_len = self.audio_soft_tokens_per_image + last_audio_token_id = ( + self.embed_audio.vocab_offset + + self.embed_audio.embedding.input_dim + - 1 + ) + padding_toks = keras.ops.convert_to_tensor( + [[last_audio_token_id]], dtype="int64" + ) + padding_embs = self.embed_audio(padding_toks) + padding_token = keras.ops.squeeze(padding_embs, axis=[0]) + flat_audio_embeds = keras.ops.reshape( + audio_embeds, [-1, hidden_size] + ) + vocab = keras.ops.concatenate( + [flat_audio_embeds, padding_token], axis=0 + ) + pad_token_index = keras.ops.shape(flat_audio_embeds)[0] + indices = keras.ops.arange(target_len) + is_real_token = indices < audio_seq_len + batch_offsets = ( + keras.ops.arange(audio_batch_size * audio_num_clips) + * audio_seq_len + ) + real_indices = keras.ops.expand_dims( + indices, 0 + ) + keras.ops.expand_dims(batch_offsets, 1) + final_indices = keras.ops.where( + keras.ops.expand_dims(is_real_token, 0), + real_indices, + pad_token_index, + ) + audio_embeds = keras.ops.take(vocab, final_indices, axis=0) + audio_embeds = keras.ops.reshape( + audio_embeds, + (audio_batch_size, audio_num_clips * target_len, hidden_size), + ) + audio_token_mask = keras.ops.equal(input_ids, self.audio_token_id) + + def scatter_audio_features(): + batch_size, seq_len, hidden_size = keras.ops.shape( + inputs_embeds + ) + flat_audio_embeds = keras.ops.reshape( + audio_embeds, [-1, hidden_size] + ) + flat_full_mask = keras.ops.reshape(audio_token_mask, [-1]) + gather_indices = ( + keras.ops.cumsum(keras.ops.cast(flat_full_mask, "int32")) + - 1 + ) + gather_indices = keras.ops.where( + flat_full_mask, gather_indices, 0 + ) + replacement_values = keras.ops.take( + flat_audio_embeds, gather_indices, axis=0 + ) + replacement_tensor = keras.ops.reshape( + replacement_values, (batch_size, seq_len, hidden_size) + ) + expanded_full_mask = keras.ops.expand_dims( + audio_token_mask, axis=-1 + ) + return keras.ops.where( + expanded_full_mask, replacement_tensor, inputs_embeds + ) + + inputs_embeds = keras.ops.cond( + keras.ops.any(audio_token_mask), + scatter_audio_features, + lambda: inputs_embeds, + ) + projected_per_layer_inputs = ( + self.language_model.project_per_layer_inputs( + inputs_embeds, per_layer_inputs + ) + ) + return inputs_embeds, projected_per_layer_inputs + + def get_config(self): + config = super().get_config() + config.update( + { + "language_model": keras.layers.serialize(self.language_model), + "vision_encoder": keras.layers.serialize(self.vision_encoder), + "embed_vision": keras.layers.serialize(self.embed_vision), + "audio_encoder": keras.layers.serialize(self.audio_encoder), + "embed_audio": keras.layers.serialize(self.embed_audio), + "vision_soft_tokens_per_image": self.vision_soft_tokens_per_image, # noqa: E501 + "audio_soft_tokens_per_image": self.audio_soft_tokens_per_image, + "image_token_id": self.image_token_id, + "audio_token_id": self.audio_token_id, + "vocab_size_per_layer_input": self.vocab_size_per_layer_input, + } + ) + return config + + @classmethod + def from_config(cls, config): + config = config.copy() + language_model = keras.layers.deserialize(config.pop("language_model")) + vision_encoder = keras.layers.deserialize(config.pop("vision_encoder")) + embed_vision = keras.layers.deserialize(config.pop("embed_vision")) + audio_encoder = keras.layers.deserialize(config.pop("audio_encoder")) + embed_audio = 
keras.layers.deserialize(config.pop("embed_audio")) + return cls( + language_model=language_model, + vision_encoder=vision_encoder, + embed_vision=embed_vision, + audio_encoder=audio_encoder, + embed_audio=embed_audio, + **config, + ) + + +@keras_hub_export("keras_hub.models.Gemma3nBackbone") +class Gemma3nBackbone(Backbone): + """The Gemma3n model backbone. + + This model is a multimodal transformer that can process text, image, and + audio inputs. It consists of a text decoder and optional vision and audio + encoders. + + Args: + text_vocab_size: int. The size of the text vocabulary. + text_hidden_size: int. The hidden size of the text model. + num_hidden_layers: int. The number of hidden layers in the text model. + pad_token_id: int. The ID of the padding token. + num_attention_heads: int. The number of attention heads in the text + model. + num_key_value_heads: int. The number of key-value heads for GQA. + head_dim: int. The dimension of each attention head. + intermediate_size: list[int]. A list of intermediate sizes for the MLP + layers. + hidden_activation: str. The activation function for the MLP layers. + layer_types: list[str]. A list of layer types ('full_attention' or + 'sliding_attention'). + sliding_window: int. The sliding window size for sliding window + attention. + rope_theta: float. The theta value for RoPE. + max_position_embeddings: int. The maximum sequence length. + vocab_size_per_layer_input: int. The vocab size for per-layer inputs. + hidden_size_per_layer_input: int. The hidden size for per-layer inputs. + altup_num_inputs: int. The number of inputs for the AltUp mechanism. + laurel_rank: int. The rank for the Laurel block. + attention_bias: bool. Whether to use a bias in the attention + projections. + attention_dropout: float. The dropout rate for attention weights. + rope_scaling: float. The scaling factor for RoPE. + rope_local_base_freq: float. The base frequency for local RoPE. + activation_sparsity_pattern: list[float]. The sparsity pattern for MLP + activations. + altup_coef_clip: float. The coefficient clipping value for AltUp. + altup_active_idx: int. The active index for AltUp. + altup_correct_scale: bool. Whether to correct the scale in AltUp. + num_kv_shared_layers: int. The number of shared KV layers. + vision_encoder_config: dict. The config for the vision encoder. + vision_hidden_size: int. The hidden size of the vision embeddings. + vision_vocab_size: int. The vocabulary size for vision tokens. + vision_vocab_offset: int. The vocabulary offset for vision tokens. + vision_soft_tokens_per_image: int. The number of tokens per image. + image_token_id: int. The special token ID for images. + audio_encoder_config: dict. The config for the audio encoder. + audio_hidden_size: int. The hidden size of the audio embeddings. + audio_vocab_size: int. The vocabulary size for audio tokens. + audio_vocab_offset: int. The vocabulary offset for audio tokens. + audio_soft_tokens_per_image: int. The number of tokens per audio clip. + audio_token_id: int. The special token ID for audio. + rms_norm_eps: float. The epsilon value for RMS normalization. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. Defaults to `None`. 
+
+    Example:
+    ```python
+    import numpy as np
+    from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import (
+        Gemma3nAudioEncoder,
+    )
+    from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone
+    from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import (
+        MobileNetV5Backbone,
+    )
+    from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import (
+        convert_arch_def_to_stackwise,
+    )
+
+    # Vision encoder config.
+    vision_arch_def = [["er_r1_k3_s1_e1_c16"]]
+    stackwise_params = convert_arch_def_to_stackwise(vision_arch_def)
+    vision_encoder = MobileNetV5Backbone(
+        **stackwise_params,
+        num_features=4,
+        image_shape=(224, 224, 3),
+        use_msfa=False,
+    )
+
+    # Audio encoder config.
+    audio_encoder = Gemma3nAudioEncoder(
+        hidden_size=8,
+        input_feat_size=32,
+        sscp_conv_channel_size=[4, 8],
+        sscp_conv_kernel_size=[(3, 3), (3, 3)],
+        sscp_conv_stride_size=[(2, 2), (2, 2)],
+        sscp_conv_group_norm_eps=1e-5,
+        conf_num_hidden_layers=1,
+        rms_norm_eps=1e-6,
+        gradient_clipping=1.0,
+        conf_residual_weight=0.5,
+        conf_num_attention_heads=1,
+        conf_attention_chunk_size=4,
+        conf_attention_context_right=5,
+        conf_attention_context_left=5,
+        conf_attention_logit_cap=50.0,
+        conf_conv_kernel_size=5,
+        conf_reduction_factor=1,
+    )
+
+    # Backbone config.
+    backbone = Gemma3nBackbone(
+        text_vocab_size=50,
+        text_hidden_size=8,
+        num_hidden_layers=1,
+        pad_token_id=0,
+        num_attention_heads=1,
+        num_key_value_heads=1,
+        head_dim=8,
+        intermediate_size=[16],
+        hidden_activation="gelu_approximate",
+        layer_types=["full_attention"],
+        sliding_window=4,
+        rope_theta=10000.0,
+        max_position_embeddings=16,
+        vocab_size_per_layer_input=50,
+        hidden_size_per_layer_input=2,
+        altup_num_inputs=2,
+        laurel_rank=1,
+        vision_encoder_config=vision_encoder.get_config(),
+        vision_hidden_size=16,
+        audio_encoder_config=audio_encoder.get_config(),
+        audio_hidden_size=8,
+    )
+
+    # Create dummy inputs. The keys match the backbone's input names, and
+    # the image and audio tensors carry an extra "number of clips" axis.
+    input_data = {
+        "token_ids": np.random.randint(0, 50, size=(1, 16), dtype="int32"),
+        "padding_mask": np.ones((1, 16), dtype=bool),
+        "images": np.random.rand(1, 1, 224, 224, 3).astype("float32"),
+        "input_features": np.random.rand(1, 1, 16, 32).astype("float32"),
+        "input_features_mask": np.zeros((1, 1, 16), dtype=bool),
+    }
+
+    # Forward pass.
+ outputs = backbone(input_data) + ``` + """ + + def __init__( + self, + text_vocab_size, + text_hidden_size, + num_hidden_layers, + pad_token_id, + num_attention_heads, + num_key_value_heads, + head_dim, + intermediate_size, + hidden_activation, + layer_types, + sliding_window, + rope_theta, + max_position_embeddings, + vocab_size_per_layer_input, + hidden_size_per_layer_input, + altup_num_inputs, + laurel_rank, + attention_bias=False, + attention_dropout=0.0, + rope_scaling=None, + rope_local_base_freq=10000.0, + activation_sparsity_pattern=None, + altup_coef_clip=None, + altup_active_idx=0, + altup_correct_scale=True, + num_kv_shared_layers=0, + vision_encoder_config=None, + vision_hidden_size=2048, + vision_vocab_size=128, + vision_vocab_offset=100, + vision_soft_tokens_per_image=256, + image_token_id=98, + audio_encoder_config=None, + audio_hidden_size=32, + audio_vocab_size=128, + audio_vocab_offset=228, + audio_soft_tokens_per_image=188, + audio_token_id=99, + rms_norm_eps=1e-6, + dtype=None, + **kwargs, + ): + # === Layers === + self.vision_encoder = None + if vision_encoder_config: + local_vision_encoder_config = vision_encoder_config.copy() + local_vision_encoder_config["dtype"] = dtype + self.vision_encoder = MobileNetV5Backbone.from_config( + local_vision_encoder_config + ) + if not self.vision_encoder.built: + input_shape = (None,) + tuple(self.vision_encoder.image_shape) + self.vision_encoder.build(input_shape) + self.audio_encoder = None + if audio_encoder_config: + audio_encoder_sig = inspect.signature(Gemma3nAudioEncoder.__init__) + audio_encoder_args = { + p.name for p in audio_encoder_sig.parameters.values() + } + keras_layer_sig = inspect.signature(keras.layers.Layer.__init__) + keras_layer_args = { + p.name for p in keras_layer_sig.parameters.values() + } + valid_args = audio_encoder_args.union(keras_layer_args) + filtered_kwargs = { + key: value + for key, value in audio_encoder_config.items() + if key in valid_args + } + filtered_kwargs.pop("dtype", None) + self.audio_encoder = Gemma3nAudioEncoder( + dtype=dtype, **filtered_kwargs + ) + if not self.audio_encoder.built: + mel_shape = ( + None, + None, + self.audio_encoder.input_feat_size, + ) + mask_shape = (None, None) + self.audio_encoder.build((mel_shape, mask_shape)) + self.language_model = Gemma3nTextModel( + pad_token_id=pad_token_id, + vocab_size=text_vocab_size, + hidden_size=text_hidden_size, + num_hidden_layers=num_hidden_layers, + rms_norm_eps=rms_norm_eps, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + layer_types=layer_types, + sliding_window=sliding_window, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_local_base_freq=rope_local_base_freq, + max_position_embeddings=max_position_embeddings, + intermediate_size=intermediate_size, + hidden_activation=hidden_activation, + activation_sparsity_pattern=activation_sparsity_pattern, + altup_num_inputs=altup_num_inputs, + altup_coef_clip=altup_coef_clip, + altup_active_idx=altup_active_idx, + altup_correct_scale=altup_correct_scale, + laurel_rank=laurel_rank, + hidden_size_per_layer_input=hidden_size_per_layer_input, + vocab_size_per_layer_input=vocab_size_per_layer_input, + num_kv_shared_layers=num_kv_shared_layers, + dtype=dtype, + name="text_model", + ) + self.embed_vision = None + if self.vision_encoder: + self.embed_vision = Gemma3nMultimodalEmbedder( + multimodal_hidden_size=vision_hidden_size, + 
text_hidden_size=text_hidden_size, + rms_norm_eps=rms_norm_eps, + vocab_offset=vision_vocab_offset, + vocab_size=vision_vocab_size, + dtype=dtype, + name="vision_embedder", + ) + if not self.embed_vision.built: + self.embed_vision.build((None, None)) + self.embed_audio = None + if self.audio_encoder: + self.embed_audio = Gemma3nMultimodalEmbedder( + multimodal_hidden_size=audio_hidden_size, + text_hidden_size=text_hidden_size, + rms_norm_eps=rms_norm_eps, + vocab_offset=audio_vocab_offset, + vocab_size=audio_vocab_size, + dtype=dtype, + name="audio_embedder", + ) + if not self.embed_audio.built: + self.embed_audio.build((None, None)) + self.embedding_processor = Gemma3nMultimodalEmbeddingProcessor( + language_model=self.language_model, + vision_encoder=self.vision_encoder, + embed_vision=self.embed_vision, + audio_encoder=self.audio_encoder, + embed_audio=self.embed_audio, + vision_soft_tokens_per_image=vision_soft_tokens_per_image, + audio_soft_tokens_per_image=audio_soft_tokens_per_image, + image_token_id=image_token_id, + audio_token_id=audio_token_id, + vocab_size_per_layer_input=vocab_size_per_layer_input, + dtype=dtype, + name="multimodal_embedding_processor", + ) + + # === Functional Model === + # === Model Inputs === + token_ids_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="bool", name="padding_mask" + ) + processor_inputs = { + "token_ids": token_ids_input, + } + model_inputs_list = [token_ids_input, padding_mask_input] + model_inputs_dict = { + "token_ids": token_ids_input, + "padding_mask": padding_mask_input, + } + + # === Modality Feature Extraction and Interleaving === + if self.vision_encoder: + input_shape = (None,) + tuple(self.vision_encoder.image_shape) + images_input = keras.Input( + shape=input_shape, + dtype="float32", + name="images", + ) + processor_inputs["pixel_values"] = images_input + model_inputs_list.append(images_input) + model_inputs_dict["images"] = images_input + if self.audio_encoder: + input_features_input = keras.Input( + shape=(None, None, self.audio_encoder.input_feat_size), + dtype="float32", + name="input_features", + ) + input_features_mask_input = keras.Input( + shape=(None, None), dtype="bool", name="input_features_mask" + ) + processor_inputs["input_features"] = input_features_input + processor_inputs["input_features_mask"] = input_features_mask_input + model_inputs_list.append(input_features_input) + model_inputs_list.append(input_features_mask_input) + model_inputs_dict["input_features"] = input_features_input + model_inputs_dict["input_features_mask"] = input_features_mask_input + final_embeds, per_layer_inputs = self.embedding_processor( + processor_inputs + ) + + # === Decoder layers === + # The Gemma3nTextModel encapsulates the decoder loop and final norm. + # It requires `input_ids` for its internal per-layer logic. 
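+        # Broadcast the 1D `padding_mask` to a (batch, 1, 1, sequence)
+        # attention mask before passing it to the text model.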
+ attention_mask = keras.ops.expand_dims(padding_mask_input, axis=1) + attention_mask = keras.ops.expand_dims(attention_mask, axis=1) + sequence_output = self.language_model( + token_ids_input, + attention_mask, + final_embeds, + per_layer_inputs, + ) + super().__init__( + inputs=model_inputs_list, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self._model_inputs_dict = model_inputs_dict + self.text_vocab_size = text_vocab_size + self.text_hidden_size = text_hidden_size + self.num_hidden_layers = num_hidden_layers + self.pad_token_id = pad_token_id + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.layer_types = layer_types + self.sliding_window = sliding_window + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.altup_num_inputs = altup_num_inputs + self.laurel_rank = laurel_rank + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.rope_local_base_freq = rope_local_base_freq + self.activation_sparsity_pattern = activation_sparsity_pattern + self.altup_coef_clip = altup_coef_clip + self.altup_active_idx = altup_active_idx + self.altup_correct_scale = altup_correct_scale + self.num_kv_shared_layers = num_kv_shared_layers + self.vision_encoder_config = vision_encoder_config + self.vision_hidden_size = vision_hidden_size + self.vision_vocab_size = vision_vocab_size + self.vision_vocab_offset = vision_vocab_offset + self.vision_soft_tokens_per_image = vision_soft_tokens_per_image + self.image_token_id = image_token_id + self.audio_encoder_config = audio_encoder_config + self.audio_hidden_size = audio_hidden_size + self.audio_vocab_size = audio_vocab_size + self.audio_vocab_offset = audio_vocab_offset + self.audio_soft_tokens_per_image = audio_soft_tokens_per_image + self.audio_token_id = audio_token_id + self.rms_norm_eps = rms_norm_eps + + def get_config(self): + config = super().get_config() + config.update( + { + "text_vocab_size": self.text_vocab_size, + "text_hidden_size": self.text_hidden_size, + "num_hidden_layers": self.num_hidden_layers, + "pad_token_id": self.pad_token_id, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "layer_types": self.layer_types, + "sliding_window": self.sliding_window, + "rope_theta": self.rope_theta, + "max_position_embeddings": self.max_position_embeddings, + "vocab_size_per_layer_input": self.vocab_size_per_layer_input, + "hidden_size_per_layer_input": self.hidden_size_per_layer_input, + "altup_num_inputs": self.altup_num_inputs, + "laurel_rank": self.laurel_rank, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "rope_scaling": self.rope_scaling, + "rope_local_base_freq": self.rope_local_base_freq, + "activation_sparsity_pattern": self.activation_sparsity_pattern, + "altup_coef_clip": self.altup_coef_clip, + "altup_active_idx": self.altup_active_idx, + "altup_correct_scale": self.altup_correct_scale, + "num_kv_shared_layers": self.num_kv_shared_layers, + "vision_encoder_config": 
self.vision_encoder_config, + "vision_hidden_size": self.vision_hidden_size, + "vision_vocab_size": self.vision_vocab_size, + "vision_vocab_offset": self.vision_vocab_offset, + "vision_soft_tokens_per_image": self.vision_soft_tokens_per_image, # noqa: E501 + "image_token_id": self.image_token_id, + "audio_encoder_config": self.audio_encoder_config, + "audio_hidden_size": self.audio_hidden_size, + "audio_vocab_size": self.audio_vocab_size, + "audio_vocab_offset": self.audio_vocab_offset, + "audio_soft_tokens_per_image": self.audio_soft_tokens_per_image, + "audio_token_id": self.audio_token_id, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_backbone_test.py b/keras_hub/src/models/gemma3n/gemma3n_backbone_test.py new file mode 100644 index 0000000000..fdf513f334 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_backbone_test.py @@ -0,0 +1,170 @@ +from copy import deepcopy + +import numpy as np +from absl.testing import parameterized + +from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import ( + Gemma3nAudioEncoder, +) +from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import ( + convert_arch_def_to_stackwise, +) +from keras_hub.src.tests.test_case import TestCase + + +class Gemma3nBackboneTest(TestCase): + def setUp(self): + self.batch_size = 1 + self.text_vocab_size = 10 + self.text_sequence_length = 8 + self.image_height = 32 + self.image_width = 32 + self.audio_sequence_length = 8 + self.audio_feature_size = 16 + + # === Vision Encoder === + vision_arch_def = [["er_r1_k3_s1_e1_c8"]] + stackwise_params = convert_arch_def_to_stackwise(vision_arch_def) + vision_encoder = MobileNetV5Backbone( + **stackwise_params, + num_features=4, + image_shape=(self.image_height, self.image_width, 3), + use_msfa=False, + ) + vision_encoder_config = vision_encoder.get_config() + + # === Audio Encoder === + audio_encoder = Gemma3nAudioEncoder( + hidden_size=4, + input_feat_size=self.audio_feature_size, + sscp_conv_channel_size=[2, 4], + sscp_conv_kernel_size=[(1, 1), (1, 1)], + sscp_conv_stride_size=[(2, 2), (2, 2)], + sscp_conv_group_norm_eps=1e-5, + conf_num_hidden_layers=1, + rms_norm_eps=1e-6, + gradient_clipping=1.0, + conf_residual_weight=0.5, + conf_num_attention_heads=1, + conf_attention_chunk_size=2, + conf_attention_context_right=1, + conf_attention_context_left=1, + conf_attention_logit_cap=50.0, + conf_conv_kernel_size=3, + conf_reduction_factor=1, + ) + + # === Multimodal === + self.multimodal_init_kwargs = { + "text_vocab_size": self.text_vocab_size, + "text_hidden_size": 4, + "num_hidden_layers": 1, + "pad_token_id": 0, + "num_attention_heads": 1, + "num_key_value_heads": 1, + "head_dim": 4, # hidden_size / num_attention_heads + "intermediate_size": [8], + "hidden_activation": "gelu_approximate", + "layer_types": ["full_attention"], + "sliding_window": 4, + "rope_theta": 10000.0, + "max_position_embeddings": self.text_sequence_length, + "vocab_size_per_layer_input": 10, + "hidden_size_per_layer_input": 2, + "altup_num_inputs": 2, + "laurel_rank": 1, + "vision_encoder_config": vision_encoder_config, + "vision_hidden_size": 8, + "audio_encoder_config": audio_encoder.get_config(), + "audio_hidden_size": 4, + } + self.multimodal_input_data = { + "token_ids": np.random.randint( + 0, + self.text_vocab_size, + size=(self.batch_size, 
self.text_sequence_length), + dtype="int32", + ), + "padding_mask": np.ones( + (self.batch_size, self.text_sequence_length), dtype=bool + ), + "images": np.random.rand( + self.batch_size, 1, self.image_height, self.image_width, 3 + ).astype("float32"), + "input_features": np.random.rand( + self.batch_size, + 1, + self.audio_sequence_length, + self.audio_feature_size, + ).astype("float32"), + "input_features_mask": np.zeros( + (self.batch_size, 1, self.audio_sequence_length), dtype=bool + ), + } + + # === Text-Only === + self.text_init_kwargs = deepcopy(self.multimodal_init_kwargs) + del self.text_init_kwargs["vision_encoder_config"] + del self.text_init_kwargs["audio_encoder_config"] + del self.text_init_kwargs["vision_hidden_size"] + del self.text_init_kwargs["audio_hidden_size"] + self.text_input_data = deepcopy(self.multimodal_input_data) + del self.text_input_data["images"] + del self.text_input_data["input_features"] + del self.text_input_data["input_features_mask"] + + @parameterized.named_parameters( + ("multimodal", "multimodal"), ("text_only", "text_only") + ) + def test_backbone_basics(self, backbone_type): + if backbone_type == "multimodal": + init_kwargs = self.multimodal_init_kwargs + input_data = self.multimodal_input_data + else: + init_kwargs = self.text_init_kwargs + input_data = self.text_input_data + self.run_backbone_test( + cls=Gemma3nBackbone, + init_kwargs=init_kwargs, + input_data=input_data, + expected_output_shape=( + self.batch_size, + self.text_sequence_length, + init_kwargs["text_hidden_size"], + ), + ) + + @parameterized.named_parameters( + ("multimodal", "multimodal"), ("text_only", "text_only") + ) + def test_saved_model(self, backbone_type): + if backbone_type == "multimodal": + init_kwargs = self.multimodal_init_kwargs + input_data = self.multimodal_input_data + else: + init_kwargs = self.text_init_kwargs + input_data = self.text_input_data + self.run_model_saving_test( + cls=Gemma3nBackbone, + init_kwargs=init_kwargs, + input_data=input_data, + ) + + @parameterized.named_parameters( + ("multimodal", "multimodal", 5450, 7), + ("text_only", "text_only", 350, 4), + ) + def test_architecture_characteristics( + self, backbone_type, num_params, num_layers + ): + if backbone_type == "multimodal": + init_kwargs = self.multimodal_init_kwargs + else: + init_kwargs = self.text_init_kwargs + model = Gemma3nBackbone(**init_kwargs) + self.assertEqual(model.count_params(), num_params) + self.assertEqual(len(model.layers), num_layers) diff --git a/keras_hub/src/models/gemma3n/gemma3n_causal_lm.py b/keras_hub/src/models/gemma3n/gemma3n_causal_lm.py new file mode 100644 index 0000000000..7b6288aab5 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_causal_lm.py @@ -0,0 +1,488 @@ +import keras +import numpy as np + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone +from keras_hub.src.models.gemma3n.gemma3n_causal_lm_preprocessor import ( + Gemma3nCausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.Gemma3nCausalLM") +class Gemma3nCausalLM(CausalLM): + """An end-to-end multimodal Gemma3n model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. 
This task setup can be used to train the model unsupervised on + images, audio, and plain text inputs, or to autoregressively generate plain + text similar to the data used for training. Note that the model is + image-audio-text in, text out. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. By + default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default + when creating the model with `from_preset()`. + + Args: + preprocessor: A `keras_hub.models.Gemma3nCausalLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + backbone: A `keras_hub.models.Gemma3nBackbone` instance. + + Examples: + ```python + import numpy as np + from keras_hub.models import Gemma3nCausalLM + + # === Text-only usage === + # Load a text-only Gemma3n model from preset. + causal_lm = Gemma3nCausalLM.from_preset("gemma3n_instruct_1b") + + # Generate text. + causal_lm.generate("What is the capital of France?", max_length=128) + + # === Vision + Text usage === + # Load a vision-text Gemma3n model from preset. + causal_lm = Gemma3nCausalLM.from_preset("gemma3n_instruct_4b") + + # Generate with image input. + image = np.ones((768, 768, 3), dtype="float32") + causal_lm.generate({ + "prompts": "Describe this image: <start_of_image>", + "images": image + }) + + # === Audio + Text usage === + # Load an audio-text Gemma3n model from preset. + causal_lm = Gemma3nCausalLM.from_preset("gemma3n_instruct_4b_audio") + + # Generate with audio input. + audio = np.ones((16000,), dtype="float32") + causal_lm.generate({ + "prompts": "Transcribe this audio: <start_of_audio>", + "audios": audio + }) + + # === Vision + Audio + Text usage === + # Generate with both image and audio.
+ causal_lm.generate({ + "prompts": "Image: , Audio: ", + "images": image, + "audios": audio + }) + ``` + """ + + backbone_cls = Gemma3nBackbone + preprocessor_cls = Gemma3nCausalLMPreprocessor + + def __init__( + self, + preprocessor, + backbone, + **kwargs, + ): + # === Layers === + self.preprocessor = preprocessor + self.backbone = backbone + + # === Functional Model === + inputs = backbone._model_inputs_dict.copy() + if "images" in inputs: + if "vision_indices" not in inputs: + inputs["vision_indices"] = keras.Input( + shape=(None,), dtype="int32", name="vision_indices" + ) + if "vision_mask" not in inputs: + inputs["vision_mask"] = keras.Input( + shape=(None,), dtype="bool", name="vision_mask" + ) + if "input_features" in inputs: + if "audio_indices" not in inputs: + inputs["audio_indices"] = keras.Input( + shape=(None,), dtype="int32", name="audio_indices" + ) + if "audio_mask" not in inputs: + inputs["audio_mask"] = keras.Input( + shape=(None,), dtype="bool", name="audio_mask" + ) + hidden_state = backbone(inputs) + outputs = backbone.language_model.token_embedding( + hidden_state, reverse=True + ) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def compile( + self, + optimizer="auto", + loss="auto", + *, + weighted_metrics="auto", + sampler="greedy", + **kwargs, + ): + super().compile( + optimizer=optimizer, + loss=loss, + weighted_metrics=weighted_metrics, + sampler=sampler, + **kwargs, + ) + + def _normalize_generate_inputs( + self, + inputs, + ): + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + if self.preprocessor is None: + return [inputs], False + + def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + if isinstance(inputs, dict): + inputs["prompts"], input_is_scalar = normalize(inputs["prompts"]) + # Handle unbatched image input. + if "images" in inputs and input_is_scalar: + x = inputs["images"] + if isinstance(x, np.ndarray) and len(x.shape) == 3: + inputs["images"] = [x] + elif tf and isinstance(x, tf.Tensor) and x.shape.rank == 3: + inputs["images"] = x[tf.newaxis] + elif isinstance(x, list): + inputs["images"] = [x] + # Handle unbatched audio input. + if "audios" in inputs and input_is_scalar: + x = inputs["audios"] + if isinstance(x, np.ndarray) and len(x.shape) == 1: + inputs["audios"] = [x] + elif tf and isinstance(x, tf.Tensor) and x.shape.rank == 1: + inputs["audios"] = x[tf.newaxis] + elif isinstance(x, list): + inputs["audios"] = [x] + if "responses" in inputs: + inputs["responses"], _ = normalize(inputs["responses"]) + else: + inputs, input_is_scalar = normalize(inputs) + + return [inputs], input_is_scalar + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + pixel_values=None, + input_features=None, + input_features_mask=None, + vision_indices=None, + audio_indices=None, + vision_mask=None, + audio_mask=None, + padding_mask=None, + cache_update_mask=None, + ): + """Forward pass of `Gemma3nCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: A dense int Tensor with shape `(batch_size, max_length)`. + cache: A dense float Tensor, the cache of key and value. 
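+                Typically created by `_build_cache()`, with shape
+                `(batch_size, num_layers, 2, num_key_value_heads, max_length,
+                head_dim)`.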
+ cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + pixel_values: A dense float Tensor with shape + `(batch_size, num_images, height, width, channels)`. + input_features: A dense float Tensor with shape + `(batch_size, num_audios, audio_seq_len, feature_size)`. + input_features_mask: A dense bool Tensor with shape + `(batch_size, num_audios, audio_seq_len)`. + vision_indices: A dense int Tensor with shape + `(batch_size, num_vision_tokens)`. + audio_indices: A dense int Tensor with shape + `(batch_size, num_audio_tokens)`. + vision_mask: A dense bool Tensor with shape + `(batch_size, max_length)`. + audio_mask: A dense bool Tensor with shape + `(batch_size, max_length)`. + padding_mask: A dense int Tensor with shape + `(batch_size, max_length)`. + cache_update_mask: A dense bool Tensor for masking cache updates. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + # Build inputs dict for embedding processor. + processor_inputs = {"token_ids": token_ids} + if pixel_values is not None: + processor_inputs["pixel_values"] = pixel_values + processor_inputs["vision_indices"] = vision_indices + processor_inputs["vision_mask"] = vision_mask + if input_features is not None: + processor_inputs["input_features"] = input_features + processor_inputs["input_features_mask"] = input_features_mask + processor_inputs["audio_indices"] = audio_indices + processor_inputs["audio_mask"] = audio_mask + # Get embeddings and per-layer inputs. + inputs_embeds, per_layer_inputs = self.backbone.embedding_processor( + processor_inputs + ) + # Prepare attention mask for caching. + batch_size = keras.ops.shape(token_ids)[0] + max_length = keras.ops.shape(token_ids)[1] + # Create causal attention mask. + if padding_mask is None: + padding_mask = keras.ops.ones( + (batch_size, max_length), dtype="bool" + ) + attention_mask = keras.ops.cast(padding_mask, dtype="bool") + attention_mask = keras.ops.expand_dims(attention_mask, axis=1) + attention_mask = keras.ops.expand_dims(attention_mask, axis=1) + # Each decoder layer has a cache; we update them separately. + hidden_states, new_cache = self.backbone.language_model( + input_ids=token_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + cache=cache, + cache_update_index=cache_update_index, + cache_update_mask=cache_update_mask, + ) + logits = self.backbone.language_model.token_embedding( + hidden_states, reverse=True + ) + return logits, hidden_states, new_cache + + def _build_cache( + self, + token_ids, + pixel_values, + input_features, + input_features_mask, + vision_indices, + audio_indices, + vision_mask, + audio_mask, + padding_mask, + ): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = keras.ops.shape(token_ids)[0] + max_length = keras.ops.shape(token_ids)[1] + num_layers = self.backbone.num_hidden_layers + num_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.head_dim + shape = [batch_size, num_layers, 2, num_heads, max_length, head_dim] + cache = keras.ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. 
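+        # A single forward pass over the full prompt at `cache_update_index=0`
+        # fills the key/value cache for every prompt token before sampling.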
+ _, hidden_states, cache = self.call_with_cache( + token_ids=token_ids, + cache=cache, + cache_update_index=0, + pixel_values=pixel_values, + input_features=input_features, + input_features_mask=input_features_mask, + vision_indices=vision_indices, + audio_indices=audio_indices, + vision_mask=vision_mask, + audio_mask=audio_mask, + padding_mask=padding_mask, + cache_update_mask=None, + ) + return hidden_states, cache + + def generate_step(self, inputs, stop_token_ids=[106]): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys for token_ids, padding_mask, and + optionally images, audios, vision_mask, audio_mask, etc. + + Args: + inputs: A dictionary with keys for the model inputs including + `"token_ids"`, `"padding_mask"`, and optionally `"images"`, + `"audios"`, `"input_features"`, `"input_features_mask"`, + `"vision_mask"`, `"audio_mask"`, `"vision_indices"`, + `"audio_indices"`. + stop_token_ids: Tuple of id's of end token's to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids = inputs["token_ids"] + padding_mask = inputs["padding_mask"] + # Extract multimodal inputs. + images = inputs.get("images", None) + pixel_values = images + input_features = inputs.get("input_features", None) + input_features_mask = inputs.get("input_features_mask", None) + vision_indices = inputs.get("vision_indices", None) + audio_indices = inputs.get("audio_indices", None) + vision_mask = inputs.get("vision_mask", None) + audio_mask = inputs.get("audio_mask", None) + audios = inputs.get("audios", None) + # Handle unbatched inputs by adding batch dimension. + if pixel_values is not None and len(keras.ops.shape(pixel_values)) == 4: + pixel_values = keras.ops.expand_dims(pixel_values, axis=0) + if audios is not None and len(keras.ops.shape(audios)) == 2: + audios = keras.ops.expand_dims(audios, axis=0) + if vision_mask is not None and len(keras.ops.shape(vision_mask)) == 1: + vision_mask = keras.ops.expand_dims(vision_mask, axis=0) + if ( + vision_indices is not None + and len(keras.ops.shape(vision_indices)) == 1 + ): + vision_indices = keras.ops.expand_dims(vision_indices, axis=0) + if ( + input_features is not None + and len(keras.ops.shape(input_features)) == 2 + ): + input_features = keras.ops.expand_dims(input_features, axis=0) + if ( + input_features_mask is not None + and len(keras.ops.shape(input_features_mask)) == 1 + ): + input_features_mask = keras.ops.expand_dims( + input_features_mask, axis=0 + ) + if audio_mask is not None and len(keras.ops.shape(audio_mask)) == 1: + audio_mask = keras.ops.expand_dims(audio_mask, axis=0) + if ( + audio_indices is not None + and len(keras.ops.shape(audio_indices)) == 1 + ): + audio_indices = keras.ops.expand_dims(audio_indices, axis=0) + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache( + token_ids, + pixel_values, + input_features, + input_features_mask, + vision_indices, + audio_indices, + vision_mask, + audio_mask, + padding_mask, + ) + # Compute the lengths of all user inputted tokens ids. + row_lengths = keras.ops.sum( + keras.ops.cast(padding_mask, "int32"), axis=-1 + ) + # Start at the first index that has no user inputted id. + index = keras.ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. 
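+            # For example, to generate the token at position `index`, we feed
+            # the token at `index - 1` and write its key/value into cache
+            # slot `index - 1`.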
+ cache_update_index = index - 1 + batch_size = keras.ops.shape(prompt)[0] + prompt = keras.ops.slice(prompt, [0, index - 1], [batch_size, 1]) + sliced_cache_update_mask = keras.ops.slice( + ~padding_mask, [0, index - 1], [batch_size, 1] + ) + logits, hidden_states, cache = self.call_with_cache( + token_ids=prompt, + cache=cache, + cache_update_index=cache_update_index, + cache_update_mask=sliced_cache_update_mask, + ) + return ( + keras.ops.squeeze(logits, axis=1), + keras.ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of `stop_token_ids` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, keras.ops.logical_not(padding_mask) + ) + end_locations = keras.ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = keras.ops.cast( + keras.ops.cumsum(end_locations, axis=-1), "int32" + ) + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = keras.ops.logical_not( + keras.ops.cast(overflow, "bool") + ) + else: + # Without early stopping, all locations will have been updated. + padding_mask = keras.ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate( + self, + inputs, + max_length=None, + stop_token_ids="auto", + strip_prompt=False, + ): + # If `auto`, add end_of_turn as a stop token too. + if self.preprocessor is None and stop_token_ids == "auto": + raise ValueError( + "A `preprocessor` must be attached to the model if " + '`stop_token_ids="auto"`. Currently `preprocessor=None`. To ' + "call `generate()` with preprocessing detached, either pass " + "`stop_token_ids=None` to always generate until `max_length` " + "or pass a tuple of token ids that should terminate generation " + "as `stop_token_ids`." + ) + elif stop_token_ids == "auto": + stop_token_ids = [ + self.preprocessor.tokenizer.end_token_id, + ] + # Add end_of_turn token if available. 
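+            # Gemma instruction-tuned checkpoints close each model turn with
+            # an `<end_of_turn>` token, so it is treated as an extra stop
+            # token here.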
+ end_of_turn_id = self.preprocessor.tokenizer.token_to_id( + "<end_of_turn>" + ) + if end_of_turn_id is not None: + stop_token_ids.append(end_of_turn_id) + return super().generate( + inputs, + max_length=max_length, + stop_token_ids=stop_token_ids, + strip_prompt=strip_prompt, + ) diff --git a/keras_hub/src/models/gemma3n/gemma3n_causal_lm_preprocessor.py b/keras_hub/src/models/gemma3n/gemma3n_causal_lm_preprocessor.py new file mode 100644 index 0000000000..95a5b7fd0c --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_causal_lm_preprocessor.py @@ -0,0 +1,998 @@ +import keras +import numpy as np +import tensorflow as tf + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, +) +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone +from keras_hub.src.models.gemma3n.gemma3n_image_converter import ( + Gemma3nImageConverter, +) +from keras_hub.src.models.gemma3n.gemma3n_tokenizer import Gemma3nTokenizer +from keras_hub.src.utils.tensor_utils import preprocessing_function +from keras_hub.src.utils.tensor_utils import strip_to_ragged + + +@keras_hub_export("keras_hub.models.Gemma3nCausalLMPreprocessor") +class Gemma3nCausalLMPreprocessor(CausalLMPreprocessor): + """Gemma3n Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.Gemma3nCausalLM`. It can be configured in three ways: + text-only, text + vision, and text + vision + audio, based on whether the + passed values of `image_converter` and `audio_converter` are None. For + text-only, it takes in batches of strings. For text + vision, it takes in + batches of images and strings. For text + vision + audio, it takes in + batches of images, audio, and strings. It returns outputs in a + `(x, y, sample_weight)` format, where the `y` label is the next token id in + the `x` sequence. `sample_weight` is 0 for "prompt" tokens, and 1 for + "response" tokens, so that the loss is computed only on the "response" + tokens. + + For the text + vision case, this layer replaces instances of the + `<start_of_image>` token in the prompt with `num_vision_tokens_per_image` + placeholder tokens. It also returns indices of where these vision tokens + are present so that in the model, image embeddings can be placed in the + right position in the sequence of text embeddings. + + For the text + audio case, this layer replaces instances of the + `<start_of_audio>` token in the prompt with `num_audio_tokens_per_audio` + placeholder tokens. It also returns indices of where these audio tokens + are present so that in the model, audio embeddings can be placed in the + right position in the sequence of text embeddings. + + Note that if `max_images_per_prompt` is 2, you can pass 0, 1, or 2 + images per sample. The value 0 corresponds to text-only input. Similarly, + if `max_audios_per_prompt` is 2, you can pass 0, 1, or 2 audio clips + per sample. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.Gemma3nCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.Gemma3nTokenizer` instance. + image_converter: A `keras_hub.layers.ImageConverter` instance. Defaults + to `None`.
+ audio_converter: A `keras_hub.layers.AudioConverter` instance. Defaults + to `None`. + sequence_length: The length of the packed inputs. Defaults to 1024. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Defaults to `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Defaults to `True`. + max_images_per_prompt: int. Permissible number of images per sample in + the batch. Defaults to 2. + num_vision_tokens_per_image: int. Number of vision placeholder tokens + per image. Defaults to 256. + max_audios_per_prompt: int. Permissible number of audio clips per sample + in the batch. Defaults to 2. + num_audio_tokens_per_audio: int. Number of audio placeholder tokens + per audio clip. Defaults to 188. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # === Language === + # Load the preprocessor from a preset. + preprocessor = keras_hub.models.Gemma3nCausalLMPreprocessor.from_preset( + "gemma3n_instruct_1b" + ) + + # Unbatched inputs. + preprocessor( + { + "prompts": "What is the capital of India?", + "responses": "New Delhi", + } + ) + + # Batched inputs. + preprocessor( + { + "prompts": [ + "What is the capital of India?", + "What is the capital of Spain?" + ], + "responses": ["New Delhi", "Madrid"], + } + ) + + # Apply preprocessing to a `tf.data.Dataset`. + features = { + "prompts": [ + "What is the capital of India?", + "What is the capital of Spain?" + ], + "responses": ["New Delhi", "Madrid"], + } + + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Prepare tokens for generation (no end token). + preprocessor.generate_preprocess(["The quick brown fox jumped."]) + + # Map generation outputs back to strings. + preprocessor.generate_postprocess({ + 'token_ids': np.array([[2, 818, 3823, 8864, 37423, 32694, 236761, 0]]), + 'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]), + }) + + # === Vision and Language === + # Load the preprocessor from a preset. + preprocessor = keras_hub.models.Gemma3nCausalLMPreprocessor.from_preset( + "gemma3n_instruct_4b" + ) + + # Text-only inputs (unbatched). + preprocessor( + { + "prompts": "What is the capital of India?", + "responses": "New Delhi", + } + ) + + # Text-only inputs (batched). + preprocessor( + { + "prompts": [ + "What is the capital of India?", + "What is the capital of Spain?" + ], + "responses": ["New Delhi", "Madrid"], + } + ) + + # Unbatched inputs, with one image. + preprocessor( + { + "prompts": "this is a lily ", + "responses": "pristine!", + "images": np.ones((768, 768, 3), dtype="float32") + } + ) + + # Unbatched inputs, with two images. + preprocessor( + { + "prompts": "lily: , sunflower: ", + "responses": "pristine!", + "images": [ + np.ones((768, 768, 3), dtype="float32"), + np.ones((768, 768, 3), dtype="float32") + ], + } + ) + + # Batched inputs, one image per prompt. 
+ preprocessor( + { + "prompts": [ + "this is a lily: ", + "this is a sunflower: " + ], + "responses": ["pristine!", "radiant!"], + "images": [ + np.ones((768, 768, 3), dtype="float32"), + np.ones((768, 768, 3), dtype="float32") + ] + } + ) + + # === Audio and Language === + # Unbatched inputs, with one audio clip. + preprocessor( + { + "prompts": "transcribe this: ", + "responses": "hello world", + "audios": np.ones((16000,), dtype="float32") + } + ) + + # === Vision, Audio and Language === + # Unbatched inputs, with one image and one audio. + preprocessor( + { + "prompts": "image: , audio: ", + "responses": "multimodal!", + "images": np.ones((768, 768, 3), dtype="float32"), + "audios": np.ones((16000,), dtype="float32") + } + ) + ``` + """ + + backbone_cls = Gemma3nBackbone + tokenizer_cls = Gemma3nTokenizer + image_converter_cls = Gemma3nImageConverter + + def __init__( + self, + tokenizer, + image_converter=None, + audio_converter=None, + sequence_length=1024, + add_start_token=True, + add_end_token=True, + max_images_per_prompt=2, + num_vision_tokens_per_image=256, + max_audios_per_prompt=2, + num_audio_tokens_per_audio=188, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + sequence_length=sequence_length, + add_start_token=add_start_token, + add_end_token=add_end_token, + **kwargs, + ) + # Validate sequence_length for multimodal inputs. + total_multimodal_tokens = ( + max_images_per_prompt * num_vision_tokens_per_image + + max_audios_per_prompt * num_audio_tokens_per_audio + ) + if ( + image_converter is not None or audio_converter is not None + ) and sequence_length <= total_multimodal_tokens: + raise ValueError( + "`sequence_length` should be greater than " + "`max_images_per_prompt * num_vision_tokens_per_image + " + "max_audios_per_prompt * num_audio_tokens_per_audio`. " + f"Received: `sequence_length` = {sequence_length}, " + f"`max_images_per_prompt` = {max_images_per_prompt}, " + f"`num_vision_tokens_per_image` = {num_vision_tokens_per_image}, " # noqa: E501 + f"`max_audios_per_prompt` = {max_audios_per_prompt}, " + f"`num_audio_tokens_per_audio` = {num_audio_tokens_per_audio}" + ) + self.image_converter = image_converter + self.audio_converter = audio_converter + self.max_images_per_prompt = max_images_per_prompt + self.num_vision_tokens_per_image = num_vision_tokens_per_image + self.max_audios_per_prompt = max_audios_per_prompt + self.num_audio_tokens_per_audio = num_audio_tokens_per_audio + # Determine model type. + self.text_only_model = ( + self.image_converter is None and self.audio_converter is None + ) + # Special tokens for images. + self.image_placeholder = self.tokenizer.image_placeholder + self.start_of_image_token = self.tokenizer.start_of_image_token + self.end_of_image_token = self.tokenizer.end_of_image_token + # Special tokens for audio. + self.audio_placeholder = self.tokenizer.audio_placeholder + self.start_of_audio_token = self.tokenizer.start_of_audio_token + self.end_of_audio_token = self.tokenizer.end_of_audio_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = MultiSegmentPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sep_value=[], + sequence_length=self.sequence_length, + ) + self.built = True + + def _get_vision_indices(self, vision_mask): + """Computes indices given vision mask, and pads with 0. 
+ + If `vision_mask` is + + ``` + [ + [False, True, True], [False, True, False], [False, False, False] + ] + ``` + + , then the output will be: + + ``` + [ + [1, 2, 0], [1, 0, 0], [0, 0, 0] + ] + ``` + """ + batch_size, sequence_length = vision_mask.shape + vision_mask_flattened = tf.reshape(vision_mask, [-1]) + vision_indices = tf.where(vision_mask_flattened)[..., 0] + vision_indices = tf.cast(vision_indices, dtype=tf.int32) + row_lengths = tf.math.reduce_sum( + tf.cast(vision_mask, dtype=vision_indices.dtype), axis=1 + ) + batched_vision_indices = tf.RaggedTensor.from_row_lengths( + values=vision_indices, + row_lengths=row_lengths, + ) + to_subtract = tf.math.scalar_mul( + scalar=tf.cast(sequence_length, dtype=tf.int32), + x=tf.range( + start=0, + limit=tf.shape(vision_mask)[0], + dtype=tf.int32, + ), + ) + # All indices should be independent of other samples in the batch. + batched_vision_indices = tf.math.subtract( + batched_vision_indices, + tf.expand_dims(to_subtract, axis=-1), + ) + # Pad the indices. + batched_vision_indices = batched_vision_indices.to_tensor( + shape=[ + batch_size, + self.max_images_per_prompt * self.num_vision_tokens_per_image, + ], + default_value=0, + ) + return batched_vision_indices + + def _get_audio_indices(self, audio_mask): + """Computes indices given audio mask, and pads with 0. + + Similar to _get_vision_indices but for audio tokens. + """ + batch_size, sequence_length = audio_mask.shape + audio_mask_flattened = tf.reshape(audio_mask, [-1]) + audio_indices = tf.where(audio_mask_flattened)[..., 0] + audio_indices = tf.cast(audio_indices, dtype=tf.int32) + row_lengths = tf.math.reduce_sum( + tf.cast(audio_mask, dtype=audio_indices.dtype), axis=1 + ) + batched_audio_indices = tf.RaggedTensor.from_row_lengths( + values=audio_indices, + row_lengths=row_lengths, + ) + to_subtract = tf.math.scalar_mul( + scalar=tf.cast(sequence_length, dtype=tf.int32), + x=tf.range( + start=0, + limit=tf.shape(audio_mask)[0], + dtype=tf.int32, + ), + ) + # All indices should be independent of other samples in the batch. + batched_audio_indices = tf.math.subtract( + batched_audio_indices, + tf.expand_dims(to_subtract, axis=-1), + ) + # Pad the indices. + batched_audio_indices = batched_audio_indices.to_tensor( + shape=[ + batch_size, + self.max_audios_per_prompt * self.num_audio_tokens_per_audio, + ], + default_value=0, + ) + return batched_audio_indices + + def _format_output( + self, + images, + audios, + input_features, + input_features_mask, + token_ids, + vision_mask, + audio_mask, + response_mask, + padding_mask, + return_labels=False, + text_only_input=False, + batched=False, + ): + if return_labels: + # Target `y` will be the next token. + y = token_ids[..., 1:] + # Only compute the loss for labels in the response. + sample_weight = response_mask[..., 1:] + # The last token does not have a next token. So, remove it. 
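+            # For example, with `token_ids` `[a, b, c, d]`, the model is fed
+            # `[a, b, c]` and trained to predict `y = [b, c, d]`.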
+ token_ids = token_ids[..., :-1] + vision_mask = vision_mask[..., :-1] + audio_mask = audio_mask[..., :-1] + response_mask = response_mask[..., :-1] + padding_mask = padding_mask[..., :-1] + x = { + "token_ids": token_ids + if batched + else tf.squeeze(token_ids, axis=0), + "padding_mask": padding_mask + if batched + else tf.squeeze(padding_mask, axis=0), + } + if self.image_converter is not None: + vision_indices = self._get_vision_indices(vision_mask=vision_mask) + x["images"] = images if batched else tf.squeeze(images, axis=0) + x["vision_indices"] = ( + vision_indices + if batched + else tf.squeeze(vision_indices, axis=0) + ) + x["vision_mask"] = ( + vision_mask if batched else tf.squeeze(vision_mask, axis=0) + ) + if self.audio_converter is not None: + audio_indices = self._get_audio_indices(audio_mask=audio_mask) + x["input_features"] = ( + input_features + if batched + else tf.squeeze(input_features, axis=0) + ) + x["input_features_mask"] = ( + input_features_mask + if batched + else tf.squeeze(input_features_mask, axis=0) + ) + x["audio_indices"] = ( + audio_indices if batched else tf.squeeze(audio_indices, axis=0) + ) + x["audio_mask"] = ( + audio_mask if batched else tf.squeeze(audio_mask, axis=0) + ) + # For generation only. + if not return_labels: + x["audios"] = audios if batched else tf.squeeze(audios, axis=0) + if return_labels: + if not batched: + y = tf.squeeze(y, axis=0) + sample_weight = tf.squeeze(sample_weight, 0) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + else: + return x + + def _preprocess_images(self, images, batched): + desired_height = self.image_converter.image_size[0] + desired_width = self.image_converter.image_size[1] + # Images can be lists/ragged tensors. We need to pad them/truncate them. + if isinstance(images, (list, np.ndarray)): + images = tf.ragged.constant(images) + elif isinstance(images, tf.RaggedTensor): + pass + elif isinstance(images, tf.Tensor): + images = tf.RaggedTensor.from_tensor(images) + else: + # Attempt to convert anyway. + try: + images = tf.RaggedTensor.from_tensor(images) + except: # noqa: E722 + raise ValueError( + "`images` should be a list, ragged tensor, dense tensor. " + f"Received: `type(images)` = {type(images)}" + ) + if not batched: + images = tf.expand_dims(images, axis=0) + # If the input is a list of images, instead of list of lists of images. + if len(images.shape) == 4: + images = tf.expand_dims(images, axis=1) + # Convert to dense tensor. + images = images.to_tensor( + shape=[None, self.max_images_per_prompt, None, None, 3], + default_value=0, + ) + # Resize, rescale, etc. the images. + original_images_shape = tf.shape(images) + # Before passing through image converter, we need to collapse the + # first two dimensions. + images = tf.reshape( + images, + [ + -1, + original_images_shape[-3], + original_images_shape[-2], + original_images_shape[-1], + ], + ) + images = self.image_converter(images) + if keras.config.backend() == "torch" and not isinstance( + images, tf.Tensor + ): + images = images.cpu() + # Recover the rank. + images = tf.reshape( + images, + [ + original_images_shape[0], + self.max_images_per_prompt, + desired_height, + desired_width, + original_images_shape[-1], + ], + ) + return images + + def _preprocess_audios(self, audios, batched): + if hasattr(audios, "cpu") and hasattr(audios, "numpy"): + audios = audios.cpu().numpy() + # Audios can be lists/ragged tensors. We need to pad them/truncate them. 
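+        # Accepted layouts: a single 1D waveform, a list of waveforms (one
+        # per sample), or a list of lists of waveforms (several clips per
+        # sample).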
+ if isinstance(audios, (list, np.ndarray)): + if isinstance(audios, np.ndarray) and audios.ndim == 1: + audios = [audios] + audios = tf.ragged.constant(audios, dtype=tf.float32) + elif isinstance(audios, tf.RaggedTensor): + pass + elif isinstance(audios, tf.Tensor): + if len(audios.shape) > 1: + audios = tf.RaggedTensor.from_tensor(audios) + else: + audios = tf.ragged.constant([audios.numpy()], dtype=tf.float32) + else: + # Attempt to convert anyway. + try: + audios = tf.convert_to_tensor(audios, dtype=tf.float32) + if len(audios.shape) == 1: + audios = tf.ragged.constant( + [audios.numpy()], dtype=tf.float32 + ) + else: + audios = tf.RaggedTensor.from_tensor(audios) + except: # noqa: E722 + raise ValueError( + "`audios` should be a list, ragged tensor, dense tensor. " + f"Received: `type(audios)` = {type(audios)}" + ) + if not batched: + audios = tf.expand_dims(audios, axis=0) + # If the input is a list of audio arrays, instead of list of lists. + if len(audios.shape) == 2: + audios = tf.expand_dims(audios, axis=1) + # Convert to dense tensor. + audios = audios.to_tensor( + shape=[None, self.max_audios_per_prompt, None], + default_value=0, + ) + # Process through audio converter. + original_audios_shape = tf.shape(audios) + batch_size = original_audios_shape[0] + num_audios = original_audios_shape[1] + # Flatten batch and audio dimensions for processing. + audios_flat = tf.reshape(audios, [-1, original_audios_shape[-1]]) + # Process audio through converter. + input_features, input_features_mask = self.audio_converter( + audios_flat, + padding="longest", + ) + # Reshape back to [batch_size, max_audios_per_prompt, ...]. + feature_shape = tf.shape(input_features) + input_features = tf.reshape( + input_features, + [batch_size, num_audios, feature_shape[1], feature_shape[2]], + ) + mask_shape = tf.shape(input_features_mask) + input_features_mask = tf.reshape( + input_features_mask, + [batch_size, num_audios, mask_shape[1]], + ) + return audios, input_features, input_features_mask + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + sequence_length = sequence_length or self.sequence_length + + # === Input extraction and validation === + # Extract text part of the input. + prompts, responses = x["prompts"], x["responses"] + tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))]) + # Find out if the input is batched/not batched. + batched = True + if isinstance(prompts, str): + batched = False + prompts = [prompts] + responses = [responses] + if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0: + batched = False + prompts = tf.expand_dims(prompts, axis=0) + responses = tf.expand_dims(responses, axis=0) + # Extract images and audios from the input. + images = x.get("images", None) + audios = x.get("audios", None) + # Validate multimodal inputs. + if self.text_only_model and (images is not None or audios is not None): + raise ValueError( + "The initialized preprocessor/model is text-only, but " + "`images` or `audios` is not `None`." + ) + # Add image placeholder tokens. + if not self.text_only_model and self.image_converter is not None: + prompts = tf.strings.regex_replace( + prompts, + self.start_of_image_token, + f"\n\n{self.start_of_image_token}" + + self.image_placeholder * self.num_vision_tokens_per_image + + f"{self.end_of_image_token}\n\n", + ) + # Add audio placeholder tokens. 
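+        # Mirrors the image expansion above: each start-of-audio marker is
+        # replaced by the start token, `num_audio_tokens_per_audio`
+        # placeholder tokens, and the end token, wrapped in newlines.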
+ if not self.text_only_model and self.audio_converter is not None: + prompts = tf.strings.regex_replace( + prompts, + self.start_of_audio_token, + f"\n\n{self.start_of_audio_token}" + + self.audio_placeholder * self.num_audio_tokens_per_audio + + f"{self.end_of_audio_token}\n\n", + ) + + # === Tokenization, padding, etc. === + # Tokenise the inputs. + prompts = self.tokenizer(prompts) + responses = self.tokenizer(responses) + # Padding. + token_ids, segment_ids = self.packer( + (prompts, responses), + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + response_mask = segment_ids == 1 + padding_mask = token_ids != self.tokenizer.pad_token_id + + # === Text Model === + if self.text_only_model: + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y = token_ids[..., 1:] + # Only compute the loss for labels in the response. + sample_weight = response_mask[..., 1:] + # Squeeze if not batched. + if not batched: + x["token_ids"] = tf.squeeze(x["token_ids"], axis=0) + x["padding_mask"] = tf.squeeze(x["padding_mask"], axis=0) + y = tf.squeeze(y, axis=0) + sample_weight = tf.squeeze(sample_weight, axis=0) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + # === Multimodal processing === + batch_size = tf.shape(prompts)[0] + desired_height = ( + self.image_converter.image_size[0] if self.image_converter else 0 + ) + desired_width = ( + self.image_converter.image_size[1] if self.image_converter else 0 + ) + # Process vision. + if images is None and self.image_converter is not None: + images = tf.ones( + shape=[ + batch_size, + 0, + desired_height, + desired_width, + 3, + ], + dtype="float32", + ) + vision_mask = tf.zeros_like(token_ids, dtype=bool) + elif images is not None and self.image_converter is not None: + images = self._preprocess_images(images=images, batched=batched) + vision_mask = token_ids == self.tokenizer.image_placeholder_id + else: + # No image converter. + images = tf.ones( + shape=[ + batch_size, + 0, + 0, + 0, + 3, + ], + dtype="float32", + ) + vision_mask = tf.zeros_like(token_ids, dtype=bool) + # Process audio. + if audios is None and self.audio_converter is not None: + audios = tf.ones( + shape=[batch_size, 0, 0], + dtype="float32", + ) + input_features = tf.ones( + shape=[batch_size, 0, 0, self.audio_converter.feature_size], + dtype="float32", + ) + input_features_mask = tf.ones( + shape=[batch_size, 0, 0], + dtype="bool", + ) + audio_mask = tf.zeros_like(token_ids, dtype=bool) + elif audios is not None and self.audio_converter is not None: + audios, input_features, input_features_mask = ( + self._preprocess_audios(audios=audios, batched=batched) + ) + audio_mask = token_ids == self.tokenizer.audio_placeholder_id + else: + # No audio converter. 
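+            # With no audio converter attached, build zero-sized placeholder
+            # tensors so `_format_output` below can be called uniformly.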
+ feature_size = ( + self.audio_converter.feature_size + if self.audio_converter is not None + else 128 + ) + audios = tf.ones( + shape=[batch_size, 0, 0], + dtype="float32", + ) + input_features = tf.ones( + shape=[batch_size, 0, 0, feature_size], + dtype="float32", + ) + input_features_mask = tf.ones( + shape=[batch_size, 0, 0], + dtype="bool", + ) + audio_mask = tf.zeros_like(token_ids, dtype=bool) + + return self._format_output( + images=images, + audios=audios, + input_features=input_features, + input_features_mask=input_features_mask, + token_ids=token_ids, + vision_mask=vision_mask, + audio_mask=audio_mask, + response_mask=response_mask, + padding_mask=padding_mask, + return_labels=True, + text_only_input=(images is None and audios is None), + batched=batched, + ) + + @preprocessing_function + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + # Extract inputs. + if isinstance(x, dict): + images = x.get("images", None) + audios = x.get("audios", None) + responses = x.get("responses", None) + prompts = x["prompts"] + else: + images = None + audios = None + responses = None + prompts = x + # Find out if the input is batched/not batched. + batched = True + if isinstance(prompts, str): + batched = False + prompts = [prompts] + if responses is not None: + responses = [responses] + if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0: + batched = False + prompts = tf.expand_dims(prompts, axis=0) + if responses is not None: + responses = tf.expand_dims(responses, axis=0) + # Validate multimodal inputs. + if self.text_only_model and (images is not None or audios is not None): + raise ValueError( + "The initialized preprocessor/model is text-only, but " + "`images` or `audios` is not `None`." + ) + # Add image placeholder tokens. + if not self.text_only_model and self.image_converter is not None: + prompts = tf.strings.regex_replace( + prompts, + self.start_of_image_token, + f"\n\n{self.start_of_image_token}" + + self.image_placeholder * self.num_vision_tokens_per_image + + f"{self.end_of_image_token}\n\n", + ) + # Add audio placeholder tokens. + if not self.text_only_model and self.audio_converter is not None: + prompts = tf.strings.regex_replace( + prompts, + self.start_of_audio_token, + f"\n\n{self.start_of_audio_token}" + + self.audio_placeholder * self.num_audio_tokens_per_audio + + f"{self.end_of_audio_token}\n\n", + ) + + # === Tokenization, padding, etc. === + prompts = self.tokenizer(prompts) + if responses is not None: + responses = self.tokenizer(responses) + segments = (prompts, responses) + else: + segments = (prompts,) + # Padding. 
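+        # Note: `add_end_value=False` below leaves the sequence open-ended,
+        # since generation continues from the end of the prompt.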
+ token_ids, segment_ids = self.packer( + segments, + sequence_length=sequence_length, + add_end_value=False, + ) + response_mask = segment_ids == 1 + padding_mask = token_ids != self.tokenizer.pad_token_id + + # === Text Model === + if self.text_only_model: + return { + "token_ids": ( + token_ids if batched else tf.squeeze(token_ids, axis=0) + ), + "padding_mask": ( + padding_mask + if batched + else tf.squeeze(padding_mask, axis=0) + ), + } + + # === Multimodal processing === + batch_size = tf.shape(prompts)[0] + desired_height = ( + self.image_converter.image_size[0] if self.image_converter else 0 + ) + desired_width = ( + self.image_converter.image_size[1] if self.image_converter else 0 + ) + # Process vision. + if images is None and self.image_converter is not None: + images = tf.ones( + shape=[ + batch_size, + 0, + desired_height, + desired_width, + 3, + ], + dtype="float32", + ) + vision_mask = tf.zeros_like(token_ids, dtype=bool) + elif images is not None and self.image_converter is not None: + images = self._preprocess_images(images=images, batched=batched) + vision_mask = token_ids == self.tokenizer.image_placeholder_id + else: + # No image converter. + images = tf.ones( + shape=[ + batch_size, + 0, + 0, + 0, + 3, + ], + dtype="float32", + ) + vision_mask = tf.zeros_like(token_ids, dtype=bool) + # Process audio. + if audios is None and self.audio_converter is not None: + audios = tf.ones( + shape=[batch_size, 0, 0], + dtype="float32", + ) + input_features = tf.ones( + shape=[batch_size, 0, 0, self.audio_converter.feature_size], + dtype="float32", + ) + input_features_mask = tf.ones( + shape=[batch_size, 0, 0], + dtype="bool", + ) + audio_mask = tf.zeros_like(token_ids, dtype=bool) + elif audios is not None and self.audio_converter is not None: + audios, input_features, input_features_mask = ( + self._preprocess_audios(audios=audios, batched=batched) + ) + audio_mask = token_ids == self.tokenizer.audio_placeholder_id + else: + # No audio converter. + feature_size = ( + self.audio_converter.feature_size + if self.audio_converter is not None + else 128 + ) + audios = tf.ones( + shape=[batch_size, 0, 0], + dtype="float32", + ) + input_features = tf.ones( + shape=[batch_size, 0, 0, feature_size], + dtype="float32", + ) + input_features_mask = tf.ones( + shape=[batch_size, 0, 0], + dtype="bool", + ) + audio_mask = tf.zeros_like(token_ids, dtype=bool) + + return self._format_output( + images=images, + audios=audios, + input_features=input_features, + input_features_mask=input_features_mask, + token_ids=token_ids, + vision_mask=vision_mask, + audio_mask=audio_mask, + response_mask=response_mask, + padding_mask=padding_mask, + return_labels=False, + text_only_input=(images is None and audios is None), + batched=batched, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_vision_tokens_per_image": self.num_vision_tokens_per_image, + "max_images_per_prompt": self.max_images_per_prompt, + "num_audio_tokens_per_audio": self.num_audio_tokens_per_audio, + "max_audios_per_prompt": self.max_audios_per_prompt, + } + ) + return config + + @preprocessing_function + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. 
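+
+        The start-of-image and start-of-audio token ids are intentionally not
+        stripped, so multimodal markers survive the round trip.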
+ """ + if not self.built: + self.build(None) + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + ids_to_strip = self.tokenizer.special_token_ids + if self.tokenizer.start_of_image_token_id in ids_to_strip: + ids_to_strip.remove(self.tokenizer.start_of_image_token_id) + if self.tokenizer.start_of_audio_token_id in ids_to_strip: + ids_to_strip.remove(self.tokenizer.start_of_audio_token_id) + token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) + return self.tokenizer.detokenize(token_ids) + + @property + def max_images_per_prompt(self): + return self._max_images_per_prompt + + @max_images_per_prompt.setter + def max_images_per_prompt(self, value): + self._max_images_per_prompt = value + + @property + def max_audios_per_prompt(self): + return self._max_audios_per_prompt + + @max_audios_per_prompt.setter + def max_audios_per_prompt(self, value): + self._max_audios_per_prompt = value diff --git a/keras_hub/src/models/gemma3n/gemma3n_causal_lm_preprocessor_test.py b/keras_hub/src/models/gemma3n/gemma3n_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..2ba99fe9cf --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_causal_lm_preprocessor_test.py @@ -0,0 +1,312 @@ +import numpy as np + +from keras_hub.src.models.gemma3n.gemma3n_audio_converter import ( + Gemma3nAudioConverter, +) +from keras_hub.src.models.gemma3n.gemma3n_causal_lm_preprocessor import ( + Gemma3nCausalLMPreprocessor, +) +from keras_hub.src.models.gemma3n.gemma3n_image_converter import ( + Gemma3nImageConverter, +) +from keras_hub.src.tests.mocks.mock_gemma3n_tokenizer import ( + MockGemma3nTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class Gemma3nCausalLMPreprocessorTest(TestCase): + def setUp(self): + # Easier to use a mock here, instead of trying to figure out why + # SentencePiece cannot tokenize and detokenize special tokens + # properly. 
+ self.tokenizer = MockGemma3nTokenizer() + + # === Text Preprocessor === + self.init_text_kwargs = { + "tokenizer": self.tokenizer, + "image_converter": None, + "audio_converter": None, + "sequence_length": 8, + "max_images_per_prompt": 0, + "num_vision_tokens_per_image": 0, + "max_audios_per_prompt": 0, + "num_audio_tokens_per_audio": 0, + } + self.text_preprocessor = Gemma3nCausalLMPreprocessor( + tokenizer=self.tokenizer, + image_converter=None, + audio_converter=None, + sequence_length=100, + max_images_per_prompt=0, + num_vision_tokens_per_image=0, + max_audios_per_prompt=0, + num_audio_tokens_per_audio=0, + ) + + # === Text + Image Preprocessor === + self.image_converter = Gemma3nImageConverter( + image_size=(4, 4), + ) + self.init_vision_kwargs = { + "tokenizer": self.tokenizer, + "image_converter": self.image_converter, + "audio_converter": None, + "sequence_length": 20, + "max_images_per_prompt": 2, + "num_vision_tokens_per_image": 5, + "max_audios_per_prompt": 0, + "num_audio_tokens_per_audio": 0, + } + + # === Text + Audio Preprocessor === + self.audio_converter = Gemma3nAudioConverter( + feature_size=16, + sampling_rate=16000, + ) + self.init_audio_kwargs = { + "tokenizer": self.tokenizer, + "image_converter": None, + "audio_converter": self.audio_converter, + "sequence_length": 20, + "max_images_per_prompt": 0, + "num_vision_tokens_per_image": 0, + "max_audios_per_prompt": 2, + "num_audio_tokens_per_audio": 3, + } + + # === Text + Image + Audio Preprocessor === + self.init_multimodal_kwargs = { + "tokenizer": self.tokenizer, + "image_converter": self.image_converter, + "audio_converter": self.audio_converter, + "sequence_length": 30, + "max_images_per_prompt": 2, + "num_vision_tokens_per_image": 5, + "max_audios_per_prompt": 2, + "num_audio_tokens_per_audio": 3, + } + + def test_text_preprocessor_basics(self): + input_data = { + "prompts": ["the quick brown fox"], + "responses": ["round"], + } + self.run_preprocessing_layer_test( + cls=Gemma3nCausalLMPreprocessor, + init_kwargs=self.init_text_kwargs, + input_data=input_data, + expected_output=( + { + "token_ids": [[1, 9, 14, 10, 12, 15, 2, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + [[9, 14, 10, 12, 15, 2, 0, 0]], # Labels shifted. + [[0, 0, 0, 0, 1, 1, 0, 0]], # Zero out unlabeled examples. + ), + ) + + def test_vision_preprocessor_basics(self): + input_data = { + "prompts": ["the quick brown fox "], + "responses": ["round"], + "images": [[np.ones((8, 8, 3))]], + } + output = self.run_preprocessing_layer_test( + cls=Gemma3nCausalLMPreprocessor, + init_kwargs=self.init_vision_kwargs, + input_data=input_data, + return_output=True, + ) + expected_x = { + "token_ids": [ + [1, 9, 14, 10, 12, 16, 4] + [8] * 5 + [5, 16, 15, 2] + [0] * 4 + ], + "padding_mask": [[1] * 16 + [0] * 4], + } + expected_y = [ + [9, 14, 10, 12, 16, 4] + [8] * 5 + [5, 16, 15, 2] + [0] * 5 + ] # Labels shifted. + expected_sw = [ + [0] * 13 + [1] * 2 + [0] * 5 + ] # Zero out unlabeled examples. + # Check shape for images. 
+ self.assertIn("images", output[0]) + self.assertAllEqual(output[0]["images"].shape, [1, 2, 4, 4, 3]) + self.assertNotIn("audios", output[0]) + self.assertNotIn("input_features", output[0]) + self.assertNotIn("input_features_mask", output[0]) + self.assertIn("vision_indices", output[0]) + self.assertNotIn("audio_indices", output[0]) + self.assertIn("vision_mask", output[0]) + self.assertNotIn("audio_mask", output[0]) + self.assertAllEqual(output[0]["token_ids"], expected_x["token_ids"]) + self.assertAllEqual( + output[0]["padding_mask"], expected_x["padding_mask"] + ) + self.assertAllEqual(output[1], expected_y) + self.assertAllEqual(output[2], expected_sw) + + def test_audio_preprocessor_basics(self): + input_data = { + "prompts": ["the quick "], + "responses": ["brown"], + "audios": [[np.ones((16000,))]], + } + preprocessor = Gemma3nCausalLMPreprocessor(**self.init_audio_kwargs) + output = preprocessor(input_data) + # Check that we have the right keys. + self.assertIn("token_ids", output[0]) + self.assertIn("padding_mask", output[0]) + self.assertIn("input_features", output[0]) + self.assertIn("input_features_mask", output[0]) + self.assertNotIn("images", output[0]) + self.assertNotIn("audios", output[0]) + self.assertNotIn("vision_indices", output[0]) + self.assertIn("audio_indices", output[0]) + self.assertNotIn("vision_mask", output[0]) + self.assertIn("audio_mask", output[0]) + self.assertEqual(output[0]["input_features"].shape[0], 1) + self.assertEqual(output[0]["input_features_mask"].shape[0], 1) + self.assertAllEqual(output[0]["input_features"].shape[0:2], [1, 2]) + self.assertGreater(output[0]["input_features"].shape[2], 0) + self.assertGreater(output[0]["input_features_mask"].shape[2], 0) + + def test_multimodal_preprocessor_basics(self): + input_data = { + "prompts": ["image audio "], + "responses": ["test"], + "images": [[np.ones((8, 8, 3))]], + "audios": [[np.ones((16000,))]], + } + preprocessor = Gemma3nCausalLMPreprocessor( + **self.init_multimodal_kwargs + ) + output = preprocessor(input_data) + # Check that we have all the right keys. + self.assertIn("token_ids", output[0]) + self.assertIn("padding_mask", output[0]) + self.assertIn("images", output[0]) + self.assertIn("input_features", output[0]) + self.assertIn("input_features_mask", output[0]) + self.assertIn("vision_indices", output[0]) + self.assertIn("audio_indices", output[0]) + self.assertIn("vision_mask", output[0]) + self.assertIn("audio_mask", output[0]) + self.assertNotIn("audios", output[0]) + # Check shapes for images. + self.assertAllEqual(output[0]["images"].shape, [1, 2, 4, 4, 3]) + # Check shapes for audios. 
+ self.assertEqual(output[0]["input_features"].shape[0], 1) + self.assertEqual(output[0]["input_features_mask"].shape[0], 1) + self.assertAllEqual(output[0]["images"].shape[0:2], [1, 2]) + self.assertAllEqual(output[0]["input_features"].shape[0:2], [1, 2]) + self.assertGreater(output[0]["input_features"].shape[2], 0) + self.assertGreater(output[0]["input_features_mask"].shape[2], 0) + + def test_text_no_start_end_token(self): + input_data = { + "prompts": ["the quick brown fox"] * 4, + "responses": ["round"] * 4, + } + preprocessor = Gemma3nCausalLMPreprocessor( + **self.init_text_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[9, 14, 10, 12, 15, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[14, 10, 12, 15, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[0, 0, 0, 1, 0, 0, 0, 0]] * 4) + + def test_text_generate_preprocess(self): + input_data = "the quick brown fox" + preprocessor = Gemma3nCausalLMPreprocessor(**self.init_text_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 9, 14, 10, 12, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_vision_generate_preprocess(self): + input_data = { + "prompts": "the quick brown fox ", + "images": np.ones((8, 8, 3)), + } + preprocessor = Gemma3nCausalLMPreprocessor(**self.init_vision_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual( + x["token_ids"], + [1, 9, 14, 10, 12, 16, 4] + [8] * 5 + [5, 16] + [0] * 6, + ) + self.assertAllEqual(x["padding_mask"], [1] * 14 + [0] * 6) + self.assertAllEqual(x["vision_indices"], list(range(7, 12)) + [0] * 5) + self.assertAllEqual(x["vision_mask"], [0] * 7 + [1] * 5 + [0] * 8) + self.assertAllEqual(x["images"].shape, [2, 4, 4, 3]) + + def test_audio_generate_preprocess(self): + input_data = { + "prompts": "the quick ", + "audios": np.ones((16000,)), + } + preprocessor = Gemma3nCausalLMPreprocessor(**self.init_audio_kwargs) + x = preprocessor.generate_preprocess(input_data) + # Check that we have the right keys. 
+ self.assertIn("token_ids", x) + self.assertIn("audio_indices", x) + self.assertIn("audio_mask", x) + self.assertIn("audios", x) + self.assertIn("input_features", x) + self.assertIn("input_features_mask", x) + + def test_text_generate_postprocess(self): + input_data = { + "token_ids": [1, 9, 14, 10, 12, 0, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0], + } + preprocessor = Gemma3nCausalLMPreprocessor(**self.init_text_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox") + + def test_vision_generate_postprocess(self): + input_data = { + "token_ids": [1, 9, 14, 10, 12, 16, 4] + + [8] * 5 + + [5, 16] + + [0] * 6, + "padding_mask": [1] * 14 + [0] * 6, + } + preprocessor = Gemma3nCausalLMPreprocessor(**self.init_text_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox \n\n ") + + def test_invalid_shape(self): + with self.assertRaises(ValueError): + input_data = { + "prompts": ["hello world", "this is testing"], + "responses": [""], + } + self.text_preprocessor(input_data) + with self.assertRaises(ValueError): + input_data = { + "prompts": ["hello world", "this is testing"], + "responses": ["hello", "", ""], + } + self.text_preprocessor(input_data) + + def test_text_only_with_images_raises_error(self): + with self.assertRaises(ValueError): + input_data = { + "prompts": ["hello"], + "responses": ["world"], + "images": [np.ones((8, 8, 3))], + } + self.text_preprocessor(input_data) + + def test_text_only_with_audios_raises_error(self): + with self.assertRaises(ValueError): + input_data = { + "prompts": ["hello"], + "responses": ["world"], + "audios": [np.ones((16000,))], + } + self.text_preprocessor(input_data) diff --git a/keras_hub/src/models/gemma3n/gemma3n_causal_lm_test.py b/keras_hub/src/models/gemma3n/gemma3n_causal_lm_test.py new file mode 100644 index 0000000000..c8beeaae3a --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_causal_lm_test.py @@ -0,0 +1,364 @@ +import copy +from unittest.mock import patch + +import keras +import numpy as np +from absl.testing import parameterized +from keras import ops + +from keras_hub.src.models.gemma3n.gemma3n_audio_converter import ( + Gemma3nAudioConverter, +) +from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import ( + Gemma3nAudioEncoder, +) +from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone +from keras_hub.src.models.gemma3n.gemma3n_causal_lm import Gemma3nCausalLM +from keras_hub.src.models.gemma3n.gemma3n_causal_lm_preprocessor import ( + Gemma3nCausalLMPreprocessor, +) +from keras_hub.src.models.gemma3n.gemma3n_image_converter import ( + Gemma3nImageConverter, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import ( + convert_arch_def_to_stackwise, +) +from keras_hub.src.tests.mocks.mock_gemma3n_tokenizer import ( + MockGemma3nTokenizer, +) +from keras_hub.src.tests.test_case import TestCase +from keras_hub.src.utils.keras_utils import fused_attention_op_available +from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op +from keras_hub.src.utils.keras_utils import running_on_gpu + + +class Gemma3nCausalLMTest(TestCase, parameterized.TestCase): + def setUp(self): + self.tokenizer = MockGemma3nTokenizer() + + # === Vision Encoder === + vision_arch_def = [["er_r1_k3_s1_e1_c8"]] + stackwise_params = convert_arch_def_to_stackwise(vision_arch_def) + vision_encoder = 
MobileNetV5Backbone( + **stackwise_params, + num_features=4, + image_shape=(16, 16, 3), + use_msfa=False, + ) + + # === Audio Encoder === + audio_encoder = Gemma3nAudioEncoder( + hidden_size=4, + input_feat_size=16, + sscp_conv_channel_size=[2, 4], + sscp_conv_kernel_size=[(1, 1), (1, 1)], + sscp_conv_stride_size=[(2, 2), (2, 2)], + sscp_conv_group_norm_eps=1e-5, + conf_num_hidden_layers=1, + rms_norm_eps=1e-6, + gradient_clipping=1.0, + conf_residual_weight=0.5, + conf_num_attention_heads=1, + conf_attention_chunk_size=2, + conf_attention_context_right=1, + conf_attention_context_left=1, + conf_attention_logit_cap=50.0, + conf_conv_kernel_size=3, + conf_reduction_factor=1, + ) + + # === Text-Only === + self.text_preprocessor = Gemma3nCausalLMPreprocessor( + tokenizer=self.tokenizer, + image_converter=None, + audio_converter=None, + sequence_length=20, + max_images_per_prompt=0, + num_vision_tokens_per_image=0, + max_audios_per_prompt=0, + num_audio_tokens_per_audio=0, + ) + text_backbone_init_kwargs = { + "text_vocab_size": self.text_preprocessor.tokenizer.vocabulary_size(), # noqa: E501 + "text_hidden_size": 4, + "num_hidden_layers": 1, + "pad_token_id": 0, + "num_attention_heads": 1, + "num_key_value_heads": 1, + "head_dim": 4, + "intermediate_size": [8], + "hidden_activation": "gelu_approximate", + "layer_types": ["full_attention"], + "sliding_window": 4, + "rope_theta": 10000.0, + "max_position_embeddings": 20, + "vocab_size_per_layer_input": 10, + "hidden_size_per_layer_input": 2, + "altup_num_inputs": 2, + "laurel_rank": 1, + } + self.text_backbone = Gemma3nBackbone(**text_backbone_init_kwargs) + self.text_init_kwargs = { + "preprocessor": self.text_preprocessor, + "backbone": self.text_backbone, + } + self.text_train_data = ( + { + "prompts": ["the quick brown fox", "the quick brown fox"], + "responses": ["the earth is round", "the earth is round"], + }, + ) + self.text_input_data = self.text_preprocessor(*self.text_train_data)[0] + + # === Vision + Text === + self.image_converter = Gemma3nImageConverter( + image_size=(16, 16), + ) + self.vision_preprocessor = Gemma3nCausalLMPreprocessor( + tokenizer=self.tokenizer, + image_converter=self.image_converter, + audio_converter=None, + sequence_length=20, + max_images_per_prompt=2, + num_vision_tokens_per_image=4, + max_audios_per_prompt=0, + num_audio_tokens_per_audio=0, + ) + vision_backbone_init_kwargs = copy.deepcopy(text_backbone_init_kwargs) + vision_backbone_init_kwargs["vision_encoder_config"] = ( + vision_encoder.get_config() + ) + vision_backbone_init_kwargs["vision_hidden_size"] = 8 + self.vision_backbone = Gemma3nBackbone(**vision_backbone_init_kwargs) + self.vision_init_kwargs = { + "preprocessor": self.vision_preprocessor, + "backbone": self.vision_backbone, + } + self.vision_train_data = ( + { + "prompts": [ + "the quick brown fox ", + "the quick brown fox", + ], + "responses": ["the earth is round", "the earth is round"], + "images": [np.ones((8, 8, 3)), np.ones((8, 8, 3))], + }, + ) + self.vision_input_data = self.vision_preprocessor( + *self.vision_train_data + )[0] + + # === Audio + Text === + self.audio_converter = Gemma3nAudioConverter( + feature_size=16, + sampling_rate=16000, + ) + self.audio_preprocessor = Gemma3nCausalLMPreprocessor( + tokenizer=self.tokenizer, + image_converter=None, + audio_converter=self.audio_converter, + sequence_length=20, + max_images_per_prompt=0, + num_vision_tokens_per_image=0, + max_audios_per_prompt=2, + num_audio_tokens_per_audio=3, + ) + audio_backbone_init_kwargs = 
copy.deepcopy(text_backbone_init_kwargs) + audio_backbone_init_kwargs["audio_encoder_config"] = ( + audio_encoder.get_config() + ) + audio_backbone_init_kwargs["audio_hidden_size"] = 4 + self.audio_backbone = Gemma3nBackbone(**audio_backbone_init_kwargs) + self.audio_init_kwargs = { + "preprocessor": self.audio_preprocessor, + "backbone": self.audio_backbone, + } + self.audio_train_data = ( + { + "prompts": [ + "the quick ", + "the quick brown fox", + ], + "responses": ["brown", "the earth is round"], + "audios": [np.ones((16000,)), np.ones((16000,))], + }, + ) + self.audio_input_data = self.audio_preprocessor(*self.audio_train_data)[ + 0 + ] + + # === Multimodal (Vision + Audio + Text) === + self.multimodal_preprocessor = Gemma3nCausalLMPreprocessor( + tokenizer=self.tokenizer, + image_converter=self.image_converter, + audio_converter=self.audio_converter, + sequence_length=30, + max_images_per_prompt=2, + num_vision_tokens_per_image=4, + max_audios_per_prompt=2, + num_audio_tokens_per_audio=3, + ) + multimodal_backbone_init_kwargs = copy.deepcopy( + text_backbone_init_kwargs + ) + multimodal_backbone_init_kwargs["vision_encoder_config"] = ( + vision_encoder.get_config() + ) + multimodal_backbone_init_kwargs["vision_hidden_size"] = 8 + multimodal_backbone_init_kwargs["audio_encoder_config"] = ( + audio_encoder.get_config() + ) + multimodal_backbone_init_kwargs["audio_hidden_size"] = 4 + multimodal_backbone_init_kwargs["max_position_embeddings"] = 30 + self.multimodal_backbone = Gemma3nBackbone( + **multimodal_backbone_init_kwargs + ) + self.multimodal_init_kwargs = { + "preprocessor": self.multimodal_preprocessor, + "backbone": self.multimodal_backbone, + } + self.multimodal_train_data = ( + { + "prompts": [ + "image audio ", + "the quick brown fox", + ], + "responses": ["test", "the earth is round"], + "images": [np.ones((8, 8, 3)), np.ones((8, 8, 3))], + "audios": [np.ones((16000,)), np.ones((16000,))], + }, + ) + self.multimodal_input_data = self.multimodal_preprocessor( + *self.multimodal_train_data + )[0] + + @parameterized.named_parameters( + ("text_only", "text_only"), + ("vision_text", "vision_text"), + ("audio_text", "audio_text"), + ("multimodal", "multimodal"), + ) + def test_causal_lm_basics(self, modality_type): + if modality_type == "text_only": + init_kwargs = self.text_init_kwargs + train_data = self.text_train_data + expected_vocab_size = self.tokenizer.vocabulary_size() + elif modality_type == "vision_text": + init_kwargs = self.vision_init_kwargs + train_data = self.vision_train_data + expected_vocab_size = self.tokenizer.vocabulary_size() + elif modality_type == "audio_text": + init_kwargs = self.audio_init_kwargs + train_data = self.audio_train_data + expected_vocab_size = self.tokenizer.vocabulary_size() + else: # multimodal + init_kwargs = self.multimodal_init_kwargs + train_data = self.multimodal_train_data + expected_vocab_size = self.tokenizer.vocabulary_size() + self.run_task_test( + cls=Gemma3nCausalLM, + init_kwargs=init_kwargs, + train_data=train_data, + expected_output_shape=( + 2, + 20 if modality_type != "multimodal" else 30, + expected_vocab_size, + ), + ) + + def test_text_flash_attention_call(self): + if ( + keras.config.backend() != "jax" + or not fused_attention_op_available() + or not gpu_supports_fused_attention_op() + ): + self.skipTest("`flash_attention` testing requires the JAX backend.") + + with patch("keras.src.backend.nn.dot_product_attention") as mock_func: + causal_lm = Gemma3nCausalLM(**self.text_init_kwargs) + causal_lm.generate("the 
quick brown fox") + if running_on_gpu(): + mock_func.assert_called() + else: + mock_func.assert_not_called() + + def test_text_early_stopping(self): + causal_lm = Gemma3nCausalLM(**self.text_init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.text_preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the quick"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_text_multitoken_stopping(self): + causal_lm = Gemma3nCausalLM(**self.text_init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.text_preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the quick"] + output = causal_lm.generate(prompt, stop_token_ids=(3,)) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_text_generate_compilation(self): + causal_lm = Gemma3nCausalLM(**self.text_init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + def test_vision_generate(self): + causal_lm = Gemma3nCausalLM(**self.vision_init_kwargs) + inputs = { + "prompts": "this is a lily ", + "images": np.ones((8, 8, 3), dtype="float32"), + } + output = causal_lm.generate(inputs) + self.assertIsInstance(output, str) + + def test_audio_generate(self): + causal_lm = Gemma3nCausalLM(**self.audio_init_kwargs) + inputs = { + "prompts": "transcribe this ", + "audios": np.ones((16000,), dtype="float32"), + } + output = causal_lm.generate(inputs) + self.assertIsInstance(output, str) + + def test_multimodal_generate(self): + causal_lm = Gemma3nCausalLM(**self.multimodal_init_kwargs) + inputs = { + "prompts": "image audio ", + "images": np.ones((8, 8, 3), dtype="float32"), + "audios": np.ones((16000,), dtype="float32"), + } + output = causal_lm.generate(inputs) + self.assertIsInstance(output, str) diff --git a/keras_hub/src/models/gemma3n/gemma3n_image_converter.py b/keras_hub/src/models/gemma3n/gemma3n_image_converter.py new file mode 100644 index 0000000000..1901e7db0e --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_image_converter.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone + + +@keras_hub_export("keras_hub.layers.Gemma3nImageConverter") +class Gemma3nImageConverter(ImageConverter): + backbone_cls = Gemma3nBackbone + + def __init__(self, **kwargs): + # Always do image preprocessing in float32 + kwargs.pop("dtype", None) + dtype = "float32" + super().__init__(dtype=dtype, **kwargs) diff --git a/keras_hub/src/models/gemma3n/gemma3n_text_decoder.py b/keras_hub/src/models/gemma3n/gemma3n_text_decoder.py new file mode 100644 index 0000000000..aff1276000 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_text_decoder.py @@ -0,0 +1,300 @@ +import math + +import keras + +from keras_hub.src.models.gemma3n.gemma3n_attention import Gemma3nTextAttention +from keras_hub.src.models.gemma3n.gemma3n_text_layers import Gemma3nTextAltUp +from keras_hub.src.models.gemma3n.gemma3n_text_layers import ( + Gemma3nTextLaurelBlock, +) +from keras_hub.src.models.gemma3n.gemma3n_text_layers import Gemma3nTextMLP +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nTextDecoderBlock(keras.layers.Layer): + """A layer that implements a single Gemma3n decoder block. + + This layer combines self-attention, feed-forward networks, and normalization + to process sequences. It includes specialized components like AltUp and + Laurel blocks for enhanced performance. + + Args: + hidden_size: int. The size of the hidden states. + rms_norm_eps: float. The epsilon value for the Gemma 3n RMS + normalization layers. + num_attention_heads: int. The number of attention heads. + num_key_value_heads: int. The number of key and value heads for + Grouped-Query Attention. + head_dim: int. The dimension of each attention head. + attention_bias: bool. If `True`, attention layers will use a bias. + attention_dropout: float. The dropout rate for the attention mechanism. + is_sliding: bool. If `True`, enables sliding window attention. + sliding_window: int. The size of the sliding window for attention. + intermediate_size: int. The size of the intermediate layer in the MLP. + hidden_activation: str. The activation function for the MLP. + activation_sparsity: float. 
Sparsity factor for the activation function. + altup_num_inputs: int. The number of inputs for the AltUp layer. + altup_coef_clip: float. Coefficient clipping value for the AltUp layer. + altup_active_idx: int. The index of the active prediction in the + AltUp layer. + altup_correct_scale: bool. Whether to scale the corrected output from + the AltUp layer. + laurel_rank: int. The rank for the Laurel block. + hidden_size_per_layer_input: int. The hidden size for the per-layer + input projection. + """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + head_dim, + attention_bias, + attention_dropout, + is_sliding, + sliding_window, + intermediate_size, + hidden_activation, + activation_sparsity, + altup_num_inputs, + altup_coef_clip, + altup_active_idx, + altup_correct_scale, + laurel_rank, + hidden_size_per_layer_input, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.is_sliding = is_sliding + self.sliding_window = sliding_window + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.activation_sparsity = activation_sparsity + self.altup_num_inputs = altup_num_inputs + self.altup_coef_clip = altup_coef_clip + self.altup_active_idx = altup_active_idx + self.altup_correct_scale = altup_correct_scale + self.laurel_rank = laurel_rank + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.attention = Gemma3nTextAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window if is_sliding else None, + name="attention", + dtype=self.dtype_policy, + ) + self.mlp = Gemma3nTextMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_activation=hidden_activation, + activation_sparsity=activation_sparsity, + name="mlp", + dtype=self.dtype_policy, + ) + self.input_layernorm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="input_layernorm", + dtype=self.dtype_policy, + ) + self.post_attention_layernorm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_attention_layernorm", + dtype=self.dtype_policy, + ) + self.pre_feedforward_layernorm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="pre_feedforward_layernorm", + dtype=self.dtype_policy, + ) + self.post_feedforward_layernorm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_feedforward_layernorm", + dtype=self.dtype_policy, + ) + self.altup = Gemma3nTextAltUp( + hidden_size=hidden_size, + altup_num_inputs=altup_num_inputs, + altup_coef_clip=altup_coef_clip, + altup_active_idx=altup_active_idx, + rms_norm_eps=rms_norm_eps, + altup_correct_scale=altup_correct_scale, + name="altup", + dtype=self.dtype_policy, + ) + self.laurel = Gemma3nTextLaurelBlock( + hidden_size=hidden_size, + laurel_rank=laurel_rank, + rms_norm_eps=rms_norm_eps, + name="laurel", + dtype=self.dtype_policy, + ) + self.per_layer_input_gate = keras.layers.Dense( + hidden_size_per_layer_input, + use_bias=False, + name="per_layer_input_gate", + dtype=self.dtype_policy, + ) + 
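+ # Per-layer input path (see `call`): the corrected active AltUp stream
+ # is gated down to `hidden_size_per_layer_input`, multiplied elementwise
+ # with the per-layer embedding input, projected back up to `hidden_size`,
+ # and RMS-normalized before being added to the remaining AltUp streams.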
self.per_layer_projection = keras.layers.Dense( + hidden_size, + use_bias=False, + name="per_layer_projection", + dtype=self.dtype_policy, + ) + self.post_per_layer_input_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_per_layer_input_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + ( + hidden_states_shape, + _, + _, + per_layer_input_shape, + _, + ) = input_shape + active_prediction_shape = hidden_states_shape[1:] + self.input_layernorm.build(active_prediction_shape) + self.laurel.build(active_prediction_shape) + self.attention.build(active_prediction_shape) + self.post_attention_layernorm.build(active_prediction_shape) + self.pre_feedforward_layernorm.build(active_prediction_shape) + self.mlp.build(active_prediction_shape) + self.post_feedforward_layernorm.build(active_prediction_shape) + self.altup.build(hidden_states_shape) + self.per_layer_input_gate.build(active_prediction_shape) + self.per_layer_projection.build(per_layer_input_shape) + self.post_per_layer_input_norm.build(active_prediction_shape) + if self.hidden_activation == "gelu_approximate": + # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`. + self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) + else: + self.act_fn = keras.activations.get(self.hidden_activation) + super().build(input_shape) + + def call( + self, + inputs, + cache=None, + cache_update_index=0, + cache_update_mask=None, + training=False, + ): + ( + hidden_states, + position_embeddings_global, + position_embeddings_local, + per_layer_input, + attention_mask, + ) = inputs + predictions = self.altup.predict(hidden_states) + active_prediction = predictions[self.altup_active_idx] + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel(active_prediction_normed) + position_embeddings = ( + position_embeddings_local + if self.is_sliding + else position_embeddings_global + ) + if cache is not None: + attn, _, new_cache = self.attention( + active_prediction_normed, + position_embeddings, + attention_mask, + cache=cache, + cache_update_index=cache_update_index, + cache_update_mask=cache_update_mask, + training=training, + ) + else: + attn, _ = self.attention( + active_prediction_normed, + position_embeddings, + attention_mask, + training=training, + ) + attn = self.post_attention_layernorm(attn) + attn_gated = active_prediction + attn + attn_laurel = (attn_gated + laurel_output) / math.sqrt(2) + attn_norm = self.pre_feedforward_layernorm(attn_laurel) + attn_ffw = self.mlp(attn_norm) + attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw) + attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm + corrected_predictions = self.altup.correct( + predictions, attn_ffw_laurel_gated + ) + corrected_predictions_list = [ + corrected_predictions[i] + for i in range(corrected_predictions.shape[0]) + ] + first_prediction = corrected_predictions_list[self.altup_active_idx] + if self.altup_correct_scale: + first_prediction = self.altup.scale_corrected_output( + first_prediction + ) + first_prediction_gated = self.per_layer_input_gate(first_prediction) + first_prediction_activated = self.act_fn(first_prediction_gated) + first_prediction_multiplied = ( + first_prediction_activated * per_layer_input + ) + first_prediction_projected = self.per_layer_projection( + first_prediction_multiplied + ) + first_prediction_normed = self.post_per_layer_input_norm( + first_prediction_projected + ) + for i in range(1, len(corrected_predictions_list)): + 
corrected_predictions_list[i] = ( + corrected_predictions_list[i] + first_prediction_normed + ) + output = keras.ops.stack(corrected_predictions_list, axis=0) + if cache is not None: + return output, new_cache + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "is_sliding": self.is_sliding, + "sliding_window": self.sliding_window, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "activation_sparsity": self.activation_sparsity, + "altup_num_inputs": self.altup_num_inputs, + "altup_coef_clip": self.altup_coef_clip, + "altup_active_idx": self.altup_active_idx, + "altup_correct_scale": self.altup_correct_scale, + "laurel_rank": self.laurel_rank, + "hidden_size_per_layer_input": self.hidden_size_per_layer_input, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_text_layers.py b/keras_hub/src/models/gemma3n/gemma3n_text_layers.py new file mode 100644 index 0000000000..16cb2d5f8a --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_text_layers.py @@ -0,0 +1,430 @@ +import keras +import numpy as np + +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nTextScaledWordEmbedding(keras.layers.Layer): + """A layer that computes scaled word embeddings for Gemma3n models. + + This layer performs a standard embedding lookup and then scales the + resulting vectors by a specified factor. + + Args: + num_embeddings: int. The size of the vocabulary. + embedding_dim: int. The dimension of the embedding vectors. + embed_scale: float. The scaling factor applied to the embeddings. + """ + + def __init__( + self, + num_embeddings, + embedding_dim, + embed_scale=1.0, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.embed_scale = embed_scale + self.embedding = keras.layers.Embedding( + self.num_embeddings, + self.embedding_dim, + name="embedding", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.embedding.build(input_shape) + super().build(input_shape) + + def call(self, inputs): + return self.embedding(inputs) * self.embed_scale + + def get_config(self): + config = super().get_config() + config.update( + { + "num_embeddings": self.num_embeddings, + "embedding_dim": self.embedding_dim, + "embed_scale": self.embed_scale, + } + ) + return config + + +class Gemma3nTextRotaryEmbedding(keras.layers.Layer): + """A layer that computes rotary positional embeddings for Gemma3n models. + + This layer calculates the cosine and sine matrices for Rotary Positional + Embedding (RoPE), which are then applied to query and key tensors in the + attention mechanism to inject positional information. + + Args: + head_dim: int. The dimension of each attention head. + rope_theta: float. The base for the rotary frequency. + max_position_embeddings: int. The maximum sequence length that this + model might be used with. + rope_scaling: dict or `None`. Specifies the scaling strategy for RoPE. + base: float. The base value for the inverse frequency calculation. 
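+ Example:
+
+ A minimal sketch with toy sizes; the constructor values below are
+ illustrative and not taken from any real preset.
+
+ ```python
+ import keras
+
+ rope = Gemma3nTextRotaryEmbedding(
+     head_dim=4,
+     rope_theta=10000.0,
+     max_position_embeddings=16,
+     rope_scaling=None,
+ )
+ # `x` only supplies the target dtype for the returned tables.
+ x = keras.ops.zeros((1, 8, 1, 4))
+ position_ids = keras.ops.expand_dims(keras.ops.arange(8), 0)
+ cos, sin = rope(x, position_ids)  # Each has shape (1, 8, 4).
+ ```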
+ """ + + def __init__( + self, + head_dim, + rope_theta, + max_position_embeddings, + rope_scaling, + base=10000, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.head_dim = head_dim + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.rope_scaling = rope_scaling + self.base = base + inv_freq = 1.0 / ( + self.base + ** (np.arange(0, self.head_dim, 2, dtype="float32") / self.head_dim) + ) + self.inv_freq = keras.ops.convert_to_tensor(inv_freq) + self.attention_scaling = 1.0 + + def call(self, x, position_ids): + inv_freq_expanded = keras.ops.expand_dims( + keras.ops.expand_dims(self.inv_freq, 0), -1 + ) + inv_freq_expanded = keras.ops.repeat( + inv_freq_expanded, repeats=keras.ops.shape(position_ids)[0], axis=0 + ) + position_ids_expanded = keras.ops.expand_dims( + keras.ops.cast(position_ids, "float32"), 1 + ) + + freqs = keras.ops.transpose( + keras.ops.matmul(inv_freq_expanded, position_ids_expanded), + (0, 2, 1), + ) + emb = keras.ops.concatenate([freqs, freqs], axis=-1) + cos = keras.ops.cos(emb) * self.attention_scaling + sin = keras.ops.sin(emb) * self.attention_scaling + return keras.ops.cast(cos, x.dtype), keras.ops.cast(sin, x.dtype) + + def get_config(self): + config = super().get_config() + config.update( + { + "head_dim": self.head_dim, + "rope_theta": self.rope_theta, + "max_position_embeddings": self.max_position_embeddings, + "rope_scaling": self.rope_scaling, + "base": self.base, + } + ) + return config + + +class Gemma3nTextMLP(keras.layers.Layer): + """A Gemma3n-specific feed-forward network (MLP) layer. + + This layer implements the MLP block used in Gemma3n transformer layers, + featuring a gated linear unit (GLU) structure. It can also apply activation + sparsity using a Gaussian top-k mechanism. + + Args: + hidden_size: int. The dimension of the hidden state. + intermediate_size: int. The dimension of the intermediate layer in the + MLP. + hidden_activation: str or callable. The activation function to use. + activation_sparsity: float. The target sparsity for activations, + enabling the Gaussian top-k mechanism if greater than 0. + """ + + def __init__( + self, + hidden_size, + intermediate_size, + hidden_activation, + activation_sparsity, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.activation_sparsity = activation_sparsity + self.gate_proj = keras.layers.Dense( + intermediate_size, + use_bias=False, + name="gate_proj", + dtype=self.dtype_policy, + ) + self.up_proj = keras.layers.Dense( + intermediate_size, + use_bias=False, + name="up_proj", + dtype=self.dtype_policy, + ) + self.down_proj = keras.layers.Dense( + hidden_size, + use_bias=False, + name="down_proj", + dtype=self.dtype_policy, + ) + if hidden_activation == "gelu_approximate": + # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`. 
+ self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) + else: + self.act_fn = keras.activations.get(hidden_activation) + + def build(self, input_shape): + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + self.down_proj.build((None, self.intermediate_size)) + super().build(input_shape) + + def _gaussian_topk(self, inputs): + target_sparsity_tensor = keras.ops.convert_to_tensor( + self.activation_sparsity, dtype="float32" + ) + std_multiplier = keras.ops.erfinv( + 2 * target_sparsity_tensor - 1 + ) * keras.ops.sqrt(keras.ops.convert_to_tensor(2.0, dtype="float32")) + std_multiplier = keras.ops.cast(std_multiplier, dtype=inputs.dtype) + inputs_mean = keras.ops.mean(inputs, axis=-1, keepdims=True) + inputs_std = keras.ops.std(inputs, axis=-1, keepdims=True) + cutoff_x = inputs_mean + inputs_std * std_multiplier + return keras.ops.relu(inputs - cutoff_x) + + def call(self, hidden_states): + gate_proj = self.gate_proj(hidden_states) + if self.activation_sparsity > 0.0: + gate_proj = self._gaussian_topk(gate_proj) + activations = self.act_fn(gate_proj) + up_proj = self.up_proj(hidden_states) + down_proj = self.down_proj(activations * up_proj) + return down_proj + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "activation_sparsity": self.activation_sparsity, + } + ) + return config + + +class Gemma3nTextLaurelBlock(keras.layers.Layer): + """A Laurel block layer for the Gemma3n model. + + This layer implements a low-rank residual block which applies a + down-projection to a specified rank, followed by an up-projection. The + result is normalized and added back to the original input, forming a + residual connection. + + Args: + hidden_size: int. The dimension of the hidden state. + laurel_rank: int. The rank of the low-rank adaptation. + rms_norm_eps: float. The epsilon value for the RMS normalization layer. + """ + + def __init__( + self, hidden_size, laurel_rank, rms_norm_eps, dtype=None, **kwargs + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.laurel_rank = laurel_rank + self.rms_norm_eps = rms_norm_eps + self.linear_left = keras.layers.Dense( + laurel_rank, + use_bias=False, + name="linear_left", + dtype=self.dtype_policy, + ) + self.linear_right = keras.layers.Dense( + hidden_size, + use_bias=False, + name="linear_right", + dtype=self.dtype_policy, + ) + self.post_laurel_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_laurel_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.linear_left.build(input_shape) + self.linear_right.build((None, self.laurel_rank)) + self.post_laurel_norm.build(input_shape) + super().build(input_shape) + + def call(self, hidden_states): + laurel_hidden_states = self.linear_left(hidden_states) + laurel_hidden_states = self.linear_right(laurel_hidden_states) + normed_laurel_hidden_states = self.post_laurel_norm( + laurel_hidden_states + ) + return hidden_states + normed_laurel_hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "laurel_rank": self.laurel_rank, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config + + +class Gemma3nTextAltUp(keras.layers.Layer): + """An Alternating Update (AltUp) layer for the Gemma3n model. 
+ + This layer implements the AltUp mechanism, which combines multiple input + modalities through a predict-and-correct cycle. It uses a router to compute + modality-specific coefficients for predicting and correcting hidden states. + + Args: + hidden_size: int. The dimension of the hidden state. + altup_num_inputs: int. The number of input modalities to the AltUp + block. + altup_coef_clip: float. The clipping value for coefficients. + altup_active_idx: int. The index of the currently active input. + rms_norm_eps: float. The epsilon value for the Gemma 3n RMS + normalization layers. + altup_correct_scale: bool. If `True`, enables a learnable scaling + factor on the corrected output. + """ + + def __init__( + self, + hidden_size, + altup_num_inputs, + altup_coef_clip, + altup_active_idx, + rms_norm_eps, + altup_correct_scale, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.altup_num_inputs = altup_num_inputs + self.altup_coef_clip = altup_coef_clip + self.altup_active_idx = altup_active_idx + self.rms_norm_eps = rms_norm_eps + + self.altup_correct_scale = altup_correct_scale + self.correct_output_scale = None + self.correction_coefs = keras.layers.Dense( + self.altup_num_inputs, + use_bias=False, + name="correction_coefs", + dtype=self.dtype_policy, + ) + self.prediction_coefs = keras.layers.Dense( + self.altup_num_inputs**2, + use_bias=False, + name="prediction_coefs", + dtype=self.dtype_policy, + ) + self.modality_router = keras.layers.Dense( + self.altup_num_inputs, + use_bias=False, + name="modality_router", + dtype=self.dtype_policy, + ) + self.router_norm = Gemma3nRMSNorm( + self.hidden_size, + eps=self.rms_norm_eps, + name="router_norm", + dtype=self.dtype_policy, + ) + self.router_input_scale = self.hidden_size**-1.0 + + def build(self, input_shape): + if self.altup_correct_scale: + self.correct_output_scale = self.add_weight( + shape=(self.hidden_size,), + initializer="zeros", + trainable=True, + name="correct_output_scale", + dtype=self.dtype_policy.variable_dtype, + ) + router_input_shape = input_shape[1:] + self.router_norm.build(router_input_shape) + self.modality_router.build(router_input_shape) + coefs_input_shape = router_input_shape[:-1] + (self.altup_num_inputs,) + self.correction_coefs.build(coefs_input_shape) + self.prediction_coefs.build(coefs_input_shape) + super().build(input_shape) + + def compute_router_modalities(self, x): + router_inputs = self.router_norm(x) * self.router_input_scale + routed = self.modality_router(router_inputs) + return keras.ops.cast( + keras.ops.tanh(keras.ops.cast(routed, "float32")), x.dtype + ) + + def predict(self, hidden_states): + modalities = self.compute_router_modalities( + hidden_states[self.altup_active_idx] + ) + modalities_shape = keras.ops.shape(modalities) + reshape_shape = modalities_shape[:-1] + ( + self.altup_num_inputs, + self.altup_num_inputs, + ) + all_coefs = keras.ops.reshape( + self.prediction_coefs(modalities), + reshape_shape, + ) + all_coefs = keras.ops.transpose(all_coefs, (0, 1, 3, 2)) + predictions = keras.ops.matmul( + keras.ops.transpose(hidden_states, (1, 2, 3, 0)), all_coefs + ) + predictions = keras.ops.transpose(predictions, (3, 0, 1, 2)) + predictions += hidden_states + return predictions + + def correct(self, predictions, activated): + modalities = self.compute_router_modalities(activated) + innovation = activated - predictions[self.altup_active_idx] + innovation = keras.ops.repeat( + keras.ops.expand_dims(innovation, 0), 
self.altup_num_inputs, axis=0 + ) + all_coefs = self.correction_coefs(modalities) + 1.0 + all_coefs = keras.ops.expand_dims( + keras.ops.transpose(all_coefs, (2, 0, 1)), -1 + ) + corrected = innovation * all_coefs + corrected += predictions + return corrected + + def scale_corrected_output(self, corrected): + return corrected * self.correct_output_scale + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "altup_num_inputs": self.altup_num_inputs, + "altup_coef_clip": self.altup_coef_clip, + "altup_active_idx": self.altup_active_idx, + "rms_norm_eps": self.rms_norm_eps, + "altup_correct_scale": self.altup_correct_scale, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_text_model.py b/keras_hub/src/models/gemma3n/gemma3n_text_model.py new file mode 100644 index 0000000000..c5f1e80a5a --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_text_model.py @@ -0,0 +1,436 @@ +import math + +import keras + +from keras_hub.src.models.gemma3n.gemma3n_text_decoder import ( + Gemma3nTextDecoderBlock, +) +from keras_hub.src.models.gemma3n.gemma3n_text_layers import ( + Gemma3nTextRotaryEmbedding, +) +from keras_hub.src.models.gemma3n.gemma3n_text_layers import ( + Gemma3nTextScaledWordEmbedding, +) +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nTextModel(keras.layers.Layer): + """The core Gemma3n text model layer. + + This layer implements the transformer architecture of the Gemma3n model. + It includes token embeddings, multiple decoder blocks, and final + normalization. + + Args: + pad_token_id: int. The id for the padding token. + vocab_size: int. The size of the vocabulary. + hidden_size: int. The size of the hidden states. + num_hidden_layers: int. The number of hidden layers in the transformer. + rms_norm_eps: float. The epsilon value for the RMS normalization layers. + num_attention_heads: int. The number of attention heads. + num_key_value_heads: int. The number of key-value heads for GQA. + head_dim: int. The dimension of each attention head. + attention_bias: bool. Whether to use a bias in the attention mechanism. + attention_dropout: float. The dropout rate for the attention scores. + layer_types: list of str. The type of each layer, e.g., + "sliding_attention". + sliding_window: int. The sliding window size for sliding window + attention. + rope_theta: float. The base frequency for Rotary Positional Embeddings. + rope_scaling: float or None. The scaling factor for RoPE. + rope_local_base_freq: float. The base frequency for local RoPE. + max_position_embeddings: int. The maximum sequence length. + intermediate_size: list of int. The size of the intermediate layer in + each of the feed-forward networks. + hidden_activation: str. The activation function for the hidden layers. + activation_sparsity_pattern: list of float or None. The sparsity pattern + for activations. + altup_num_inputs: int. The number of inputs for the AltUp mechanism. + altup_coef_clip: float. The coefficient clipping value for AltUp. + altup_active_idx: int. The active index for AltUp. + altup_correct_scale: bool. Whether to correct scaling in AltUp. + laurel_rank: int. The rank for LAUREL factorization. + hidden_size_per_layer_input: int. The hidden size for per-layer inputs. + vocab_size_per_layer_input: int. The vocabulary size for per-layer + inputs. + num_kv_shared_layers: int. The number of shared key-value layers. 
+ """ + + def __init__( + self, + pad_token_id, + vocab_size, + hidden_size, + num_hidden_layers, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + head_dim, + attention_bias, + attention_dropout, + layer_types, + sliding_window, + rope_theta, + rope_scaling, + rope_local_base_freq, + max_position_embeddings, + intermediate_size, + hidden_activation, + activation_sparsity_pattern, + altup_num_inputs, + altup_coef_clip, + altup_active_idx, + altup_correct_scale, + laurel_rank, + hidden_size_per_layer_input, + vocab_size_per_layer_input, + num_kv_shared_layers, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.layer_types = layer_types + self.sliding_window = sliding_window + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.rope_local_base_freq = rope_local_base_freq + self.max_position_embeddings = max_position_embeddings + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.activation_sparsity_pattern = activation_sparsity_pattern + self.altup_num_inputs = altup_num_inputs + self.altup_coef_clip = altup_coef_clip + self.altup_active_idx = altup_active_idx + self.altup_correct_scale = altup_correct_scale + self.laurel_rank = laurel_rank + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.num_kv_shared_layers = num_kv_shared_layers + self.padding_idx = pad_token_id + self.embed_tokens = Gemma3nTextScaledWordEmbedding( + vocab_size, + hidden_size, + embed_scale=hidden_size**0.5, + name="embed_tokens", + dtype=self.dtype_policy, + ) + if activation_sparsity_pattern is None: + self.activation_sparsity_pattern = [0.0] * num_hidden_layers + self.transformer_layers = [ + Gemma3nTextDecoderBlock( + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + head_dim, + attention_bias, + attention_dropout, + layer_types[i] == "sliding_attention", + sliding_window, + intermediate_size[i], + hidden_activation, + self.activation_sparsity_pattern[i], + altup_num_inputs, + altup_coef_clip, + altup_active_idx, + altup_correct_scale, + laurel_rank, + hidden_size_per_layer_input, + name=f"decoder_block_{i}", + dtype=self.dtype_policy, + ) + for i in range(num_hidden_layers) + ] + self.final_normalization = Gemma3nRMSNorm( + hidden_size, eps=rms_norm_eps, name="norm", dtype=self.dtype_policy + ) + self.rotary_emb = Gemma3nTextRotaryEmbedding( + head_dim, + rope_theta, + max_position_embeddings, + rope_scaling, + dtype=self.dtype_policy, + name="rotary_emb", + ) + self.rotary_emb_local = Gemma3nTextRotaryEmbedding( + head_dim, + rope_local_base_freq, + max_position_embeddings, + None, + dtype=self.dtype_policy, + name="rotary_emb_local", + ) + self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding( + vocab_size_per_layer_input, + num_hidden_layers * hidden_size_per_layer_input, + embed_scale=hidden_size_per_layer_input**0.5, + name="embed_tokens_per_layer", + dtype=self.dtype_policy, + ) + self.per_layer_model_projection = keras.layers.Dense( + num_hidden_layers * hidden_size_per_layer_input, + use_bias=False, + 
name="per_layer_model_projection", + dtype=self.dtype_policy, + ) + self.per_layer_projection_norm = Gemma3nRMSNorm( + hidden_size_per_layer_input, + eps=rms_norm_eps, + name="per_layer_projection_norm", + dtype=self.dtype_policy, + ) + self.altup_projections = [ + keras.layers.Dense( + hidden_size, + use_bias=False, + name=f"altup_projection_{i}", + dtype=self.dtype_policy, + ) + for i in range(1, altup_num_inputs) + ] + self.altup_unembed_projections = [ + keras.layers.Dense( + hidden_size, + use_bias=False, + name=f"altup_unembed_projection_{i}", + dtype=self.dtype_policy, + ) + for i in range(1, altup_num_inputs) + ] + self.per_layer_projection_scale = hidden_size**-0.5 + self.per_layer_input_scale = 1.0 / math.sqrt(2.0) + + def build(self, input_shape): + if isinstance(input_shape, (list, tuple)) and isinstance( + input_shape[0], (list, tuple) + ): + input_ids_shape, _, inputs_embeds_shape, _ = input_shape + else: + input_ids_shape = input_shape + hidden_size = self.embed_tokens.embedding_dim + inputs_embeds_shape = input_ids_shape[:-1] + (hidden_size,) + self.embed_tokens.build(input_ids_shape) + self.embed_tokens_per_layer.build(input_ids_shape) + if not self.per_layer_model_projection.built: + self.per_layer_model_projection.build(inputs_embeds_shape) + per_layer_projection_norm_shape = ( + None, + None, + None, + self.hidden_size_per_layer_input, + ) + if not self.per_layer_projection_norm.built: + self.per_layer_projection_norm.build( + per_layer_projection_norm_shape + ) + for proj in self.altup_projections: + proj.build(inputs_embeds_shape) + for proj in self.altup_unembed_projections: + proj.build(inputs_embeds_shape) + decoder_hidden_states_shape = ( + self.altup_num_inputs, + ) + inputs_embeds_shape + decoder_per_layer_input_shape = input_ids_shape + ( + self.hidden_size_per_layer_input, + ) + decoder_input_shape = ( + decoder_hidden_states_shape, + None, # position_embeddings_global + None, # position_embeddings_local + decoder_per_layer_input_shape, + None, # attention_mask + ) + for layer in self.transformer_layers: + layer.build(decoder_input_shape) + self.final_normalization.build(inputs_embeds_shape) + super().build(input_shape) + + def get_per_layer_inputs(self, input_ids): + embeds = self.embed_tokens_per_layer(input_ids) + return keras.ops.reshape( + embeds, + keras.ops.shape(input_ids) + + (self.num_hidden_layers, self.hidden_size_per_layer_input), + ) + + def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs=None): + per_layer_projection = self.per_layer_model_projection(inputs_embeds) + per_layer_projection = ( + per_layer_projection * self.per_layer_projection_scale + ) + per_layer_projection = keras.ops.reshape( + per_layer_projection, + keras.ops.shape(inputs_embeds)[:-1] + + (self.num_hidden_layers, self.hidden_size_per_layer_input), + ) + per_layer_projection = self.per_layer_projection_norm( + per_layer_projection + ) + if per_layer_inputs is None: + return per_layer_projection + return ( + per_layer_projection + per_layer_inputs + ) * self.per_layer_input_scale + + def token_embedding(self, inputs, reverse=False): + """Apply or reverse the token embedding. + + Args: + inputs: If `reverse=False`, token IDs to embed. If `reverse=True`, + hidden states to convert to logits. + reverse: bool. If False, performs embedding lookup. If True, + computes logits by projecting hidden states through + the transpose of the embedding matrix. 
+ """ + if not reverse: + return self.embed_tokens(inputs) + else: + embedding_weights = self.embed_tokens.embedding.embeddings + logits = keras.ops.matmul( + inputs, keras.ops.transpose(embedding_weights) + ) + logits = logits / self.embed_tokens.embed_scale + return logits + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, (list, tuple)) and isinstance( + input_shape[0], (list, tuple) + ): + input_ids_shape = input_shape[0] + else: + input_ids_shape = input_shape + hidden_size = self.embed_tokens.embedding_dim + return input_ids_shape + (hidden_size,) + + def call( + self, + input_ids, + attention_mask, + inputs_embeds, + per_layer_inputs, + cache=None, + cache_update_index=0, + cache_update_mask=None, + ): + position_ids = keras.ops.expand_dims( + keras.ops.arange(0, keras.ops.shape(input_ids)[1]), 0 + ) + hidden_states_0 = inputs_embeds + cos_global, sin_global = self.rotary_emb(hidden_states_0, position_ids) + cos_local, sin_local = self.rotary_emb_local( + hidden_states_0, position_ids + ) + target_magnitude = keras.ops.sqrt( + keras.ops.mean(hidden_states_0**2, axis=-1, keepdims=True) + ) + epsilon = 1e-5 + temp_hidden_states = [hidden_states_0] + for proj in self.altup_projections: + altup_proj = proj(hidden_states_0) + new_magnitude = keras.ops.sqrt( + keras.ops.maximum( + keras.ops.mean(altup_proj**2, axis=-1, keepdims=True), + epsilon, + ) + ) + current_hidden_state = altup_proj * target_magnitude / new_magnitude + temp_hidden_states.append(current_hidden_state) + hidden_states = keras.ops.stack(temp_hidden_states, axis=0) + if cache is not None: + caches = [] + for i, decoder_layer in enumerate(self.transformer_layers): + per_layer_input = per_layer_inputs[:, :, i, :] + current_cache = cache[:, i, ...] + hidden_states, new_cache = decoder_layer( + ( + hidden_states, + (cos_global, sin_global), + (cos_local, sin_local), + per_layer_input, + attention_mask, + ), + cache=current_cache, + cache_update_index=cache_update_index, + cache_update_mask=cache_update_mask, + ) + caches.append(new_cache) + cache = keras.ops.stack(caches, axis=1) + else: + for i, decoder_layer in enumerate(self.transformer_layers): + per_layer_input = per_layer_inputs[:, :, i, :] + hidden_states = decoder_layer( + ( + hidden_states, + (cos_global, sin_global), + (cos_local, sin_local), + per_layer_input, + attention_mask, + ) + ) + target_magnitude = keras.ops.sqrt( + keras.ops.mean(hidden_states[0] ** 2, axis=-1, keepdims=True) + ) + temp_hidden_states = [hidden_states[0]] + for i, proj in enumerate(self.altup_unembed_projections): + altup_unemb_proj = proj(hidden_states[i + 1]) + new_magnitude = keras.ops.sqrt( + keras.ops.maximum( + keras.ops.mean(altup_unemb_proj**2, axis=-1, keepdims=True), + epsilon, + ) + ) + current_hidden_state = ( + altup_unemb_proj * target_magnitude / new_magnitude + ) + temp_hidden_states.append(current_hidden_state) + hidden_states = keras.ops.stack(temp_hidden_states) + hidden_states = keras.ops.mean(hidden_states, axis=0) + normalized = self.final_normalization(hidden_states) + if cache is not None: + return normalized, cache + return normalized + + def get_config(self): + config = super().get_config() + config.update( + { + "pad_token_id": self.pad_token_id, + "vocab_size": self.vocab_size, + "hidden_size": self.hidden_size, + "num_hidden_layers": self.num_hidden_layers, + "rms_norm_eps": self.rms_norm_eps, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + 
"attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "layer_types": self.layer_types, + "sliding_window": self.sliding_window, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "rope_local_base_freq": self.rope_local_base_freq, + "max_position_embeddings": self.max_position_embeddings, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "activation_sparsity_pattern": self.activation_sparsity_pattern, + "altup_num_inputs": self.altup_num_inputs, + "altup_coef_clip": self.altup_coef_clip, + "altup_active_idx": self.altup_active_idx, + "altup_correct_scale": self.altup_correct_scale, + "laurel_rank": self.laurel_rank, + "hidden_size_per_layer_input": self.hidden_size_per_layer_input, + "vocab_size_per_layer_input": self.vocab_size_per_layer_input, + "num_kv_shared_layers": self.num_kv_shared_layers, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_tokenizer.py b/keras_hub/src/models/gemma3n/gemma3n_tokenizer.py new file mode 100644 index 0000000000..191fde8579 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_tokenizer.py @@ -0,0 +1,95 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone +from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( + SentencePieceTokenizer, +) + + +@keras_hub_export( + [ + "keras_hub.tokenizers.Gemma3nTokenizer", + "keras_hub.models.Gemma3nTokenizer", + ] +) +class Gemma3nTokenizer(SentencePieceTokenizer): + """Gemma3n tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Gemma3n models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Gemma3n preset. + + If input is a batch of strings `(rank > 0)`, the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string `(rank == 0)`, the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_hub.models.Gemma3nTokenizer.from_preset( + "gemma3n_instruct_1b" + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. + bytes_io = io.BytesIO() + ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."]) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=ds.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=8, + model_type="WORD", + pad_id=0, + bos_id=1, + eos_id=2, + unk_id=3, + pad_piece="", + bos_piece="", + eos_piece="", + unk_piece="", + ) + tokenizer = keras_hub.models.Gemma3nTokenizer( + proto=bytes_io.getvalue(), + ) + tokenizer("The quick brown fox jumped.") + ``` + """ + + backbone_cls = Gemma3nBackbone + + def __init__(self, proto, **kwargs): + # Add special tokens. 
+ self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + self._add_special_token("", "pad_token") + # Image. + self._add_special_token("", "image_placeholder") + # Audio. + self._add_special_token("", "audio_placeholder") + # Multimodal inputs. + self._add_special_token("", "start_of_image_token") + self._add_special_token("", "end_of_image_token") + self._add_special_token("", "start_of_audio_token") + self._add_special_token("", "end_of_audio_token") + # Special tokens for conversation and masking. + self._add_special_token("", "start_of_turn_token") + self._add_special_token("", "end_of_turn_token") + self._add_special_token("", "mask_token") + self._add_special_token("[multimodal]", "multimodal_token") + super().__init__(proto=proto, **kwargs) diff --git a/keras_hub/src/models/gemma3n/gemma3n_tokenizer_test.py b/keras_hub/src/models/gemma3n/gemma3n_tokenizer_test.py new file mode 100644 index 0000000000..b2fa36b404 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_tokenizer_test.py @@ -0,0 +1,32 @@ +import os + +from keras_hub.src.models.gemma3n.gemma3n_tokenizer import Gemma3nTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class Gemma3nTokenizerTest(TestCase): + def setUp(self): + self.init_kwargs = { + # Generated using `create_gemma3n_test_proto.py`. + "proto": os.path.join( + self.get_test_data_dir(), "gemma3n_test_vocab.spm" + ) + } + self.input_data = ["the quick brown fox", "the earth is round"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=Gemma3nTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[14, 19, 15, 17], [14, 16, 18, 20]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + Gemma3nTokenizer( + # Generated using `create_no_special_token_proto.py` + proto=os.path.join( + self.get_test_data_dir(), "no_special_token_vocab.spm" + ) + ) diff --git a/keras_hub/src/models/gemma3n/gemma3n_utils.py b/keras_hub/src/models/gemma3n/gemma3n_utils.py new file mode 100644 index 0000000000..0db8706d63 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_utils.py @@ -0,0 +1,122 @@ +import keras + + +def rotate_half(x): + """Rotates half of the hidden dimensions of the input tensor. + + This function is used to implement rotary positional embeddings. It splits + the last dimension of the input tensor into two halves, negates the second + half, and then concatenates them back together. + + Args: + x: The input tensor. + + Returns: + A new tensor with the second half of the last dimension rotated. + """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return keras.ops.concatenate([-x2, x1], axis=-1) + + +def repeat_kv(hidden_states, n_rep): + """Repeats the key and value states for Grouped-Query Attention. + + This function is used in Grouped-Query Attention (GQA) to expand the key + and value states to match the number of query heads. + + Args: + hidden_states: The key or value tensor to be repeated, with a shape of + `[batch, num_key_value_heads, seq_len, head_dim]`. + n_rep: int. The number of times to repeat the key/value heads. + + Returns: + The repeated tensor with a shape of + `[batch, num_key_value_heads * n_rep, seq_len, head_dim]`. 
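+ Example (a small sketch; the shapes are arbitrary):
+
+ ```python
+ import keras
+
+ key = keras.ops.zeros((2, 1, 5, 4))  # 1 key/value head
+ expanded = repeat_kv(key, n_rep=4)   # (2, 4, 5, 4), matches 4 query heads
+ ```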
+ """ + if n_rep == 1: + return hidden_states + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + hidden_states = keras.ops.expand_dims(hidden_states, 2) + hidden_states = keras.ops.repeat(hidden_states, n_rep, axis=2) + return keras.ops.reshape( + hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim) + ) + + +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): + """Applies rotary positional embedding to the input tensor. + + Args: + x: The input tensor. + cos: The cosine part of the rotary embedding. + sin: The sine part of the rotary embedding. + unsqueeze_dim: int. The dimension to unsqueeze `cos` and `sin` before + applying the embedding. Defaults to 1. + + Returns: + The tensor with rotary positional embeddings applied. + """ + cos = keras.ops.expand_dims(cos, axis=unsqueeze_dim) + sin = keras.ops.expand_dims(sin, axis=unsqueeze_dim) + return (x * cos) + (rotate_half(x) * sin) + + +def eager_attention_forward( + query, + key, + value, + num_key_value_groups, + head_dim, + attention_mask, + dropout=0.0, + scaling=None, + softcap=None, + training=False, +): + """Forward pass for an eager attention implementation. + + Args: + query: The query tensor. + key: The key tensor. + value: The value tensor. + num_key_value_groups: int. The number of key-value groups. + head_dim: int. The dimension of each attention head. + attention_mask: The attention mask to apply. + dropout: float. The dropout rate. Defaults to 0.0. + scaling: float, optional. The scaling factor for attention scores. + If `None`, it defaults to `head_dim**-0.5`. + softcap: float, optional. A softcap value to apply to attention weights. + Defaults to `None`. + training: bool. Whether the model is in training mode. Defaults to + `False`. + """ + if scaling is None: + scaling = head_dim**-0.5 + key_states = repeat_kv(key, num_key_value_groups) + value_states = repeat_kv(value, num_key_value_groups) + attn_weights = ( + keras.ops.matmul(query, keras.ops.transpose(key_states, (0, 1, 3, 2))) + * scaling + ) + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + keras.ops.cast( + causal_mask, dtype=attn_weights.dtype + ) + attn_weights_dtype = attn_weights.dtype + attn_weights = keras.ops.softmax( + keras.ops.cast(attn_weights, "float32"), axis=-1 + ) + attn_weights = keras.ops.cast(attn_weights, attn_weights_dtype) + if training: + attn_weights = keras.layers.Dropout(dropout)( + attn_weights, training=training + ) + attn_output = keras.ops.matmul(attn_weights, value_states) + attn_output = keras.ops.transpose(attn_output, (0, 2, 1, 3)) + return attn_output, attn_weights diff --git a/keras_hub/src/models/gemma3n/rms_normalization.py b/keras_hub/src/models/gemma3n/rms_normalization.py new file mode 100644 index 0000000000..48955699d0 --- /dev/null +++ b/keras_hub/src/models/gemma3n/rms_normalization.py @@ -0,0 +1,67 @@ +import keras + + +class Gemma3nRMSNorm(keras.layers.Layer): + """The Gemma 3n specific RMS normalization layer. + + Args: + dim: int. The dimension of the input tensor. + eps: float. A small constant added to the denominator for numerical + stability. Defaults to `1e-6`. + with_scale: bool. Whether to include a learnable scaling parameter. + Defaults to `True`. 
+    """
+
+    def __init__(self, dim, eps=1e-6, with_scale=True, dtype=None, **kwargs):
+        super().__init__(dtype=dtype, **kwargs)
+        self.dim = dim
+        self.eps = eps
+        self.with_scale = with_scale
+
+    def build(self, input_shape):
+        if self.with_scale:
+            self.scale = self.add_weight(
+                shape=(self.dim,),
+                initializer="ones",
+                trainable=True,
+                name="scale",
+                dtype=self.dtype_policy.variable_dtype,
+            )
+        else:
+            self.scale = 1.0
+        super().build(input_shape)
+
+    def call(self, x):
+        norm_x = x * keras.ops.rsqrt(
+            keras.ops.mean(keras.ops.square(x), axis=-1, keepdims=True)
+            + self.eps
+        )
+        return norm_x * self.scale
+
+    def _int8_call(self, x):
+        # Compute in float32, then cast back to the input's original dtype.
+        original_dtype = x.dtype
+        x = keras.ops.cast(x, "float32")
+        norm_x = x * keras.ops.rsqrt(
+            keras.ops.mean(keras.ops.square(x), axis=-1, keepdims=True)
+            + self.eps
+        )
+        norm_x = norm_x * self.scale
+        return keras.ops.cast(norm_x, original_dtype)
+
+    def _float8_call(self, x):
+        x_calc = keras.ops.cast(x, "float32")
+        norm_x = x_calc * keras.ops.rsqrt(
+            keras.ops.mean(keras.ops.square(x_calc), axis=-1, keepdims=True)
+            + self.eps
+        )
+        return keras.ops.cast(norm_x * self.scale, x.dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "dim": self.dim,
+                "eps": self.eps,
+                "with_scale": self.with_scale,
+            }
+        )
+        return config
diff --git a/keras_hub/src/tests/mocks/mock_gemma3n_tokenizer.py b/keras_hub/src/tests/mocks/mock_gemma3n_tokenizer.py
new file mode 100644
index 0000000000..9c770662b4
--- /dev/null
+++ b/keras_hub/src/tests/mocks/mock_gemma3n_tokenizer.py
@@ -0,0 +1,159 @@
+import tensorflow as tf
+
+from keras_hub.src.tokenizers.tokenizer import Tokenizer
+from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import is_int_dtype
+from keras_hub.src.utils.tensor_utils import is_string_dtype
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+class MockGemma3nTokenizer(Tokenizer):
+    def __init__(
+        self,
+        proto=None,
+        sequence_length=None,
+        dtype="int32",
+        add_bos=False,
+        add_eos=False,
+        **kwargs,
+    ):
+        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
+            raise ValueError(
+                "Output dtype must be an integer type or a string. "
+                f"Received: dtype={dtype}"
+            )
+        super().__init__(dtype=dtype, **kwargs)
+        # A small fixed vocabulary: special tokens plus a handful of words.
+        self.vocabulary = [
+            "<pad>",
+            "<bos>",
+            "<eos>",
+            "<unk>",
+            "<image_soft_token>",
+            "<audio_soft_token>",
+            "<start_of_image>",
+            "<end_of_image>",
+            "<start_of_audio>",
+            "the",
+            "brown",
+            "earth",
+            "fox",
+            "is",
+            "quick",
+            "round",
+            "\n\n",
+            "<end_of_audio>",
+            "<start_of_turn>",
+            "<end_of_turn>",
+            "<mask>",
+            "[multimodal]",
+        ]
+        self.string_to_id = tf.lookup.StaticHashTable(
+            tf.lookup.KeyValueTensorInitializer(
+                self.vocabulary, list(range(len(self.vocabulary)))
+            ),
+            default_value=3,
+        )
+        self.id_to_string = tf.lookup.StaticHashTable(
+            tf.lookup.KeyValueTensorInitializer(
+                list(range(len(self.vocabulary))), self.vocabulary
+            ),
+            default_value="",
+        )
+        # The usual tokens.
+        self._add_special_token("<bos>", "start_token")
+        self._add_special_token("<eos>", "end_token")
+        self._add_special_token("<pad>", "pad_token")
+        # Image placeholder token.
+        self._add_special_token("<image_soft_token>", "image_placeholder")
+        # Audio placeholder token.
+        self._add_special_token("<audio_soft_token>", "audio_placeholder")
+        # Tokens used in the preprocessor for multimodal inputs.
+        self._add_special_token("<start_of_image>", "start_of_image_token")
+        self._add_special_token("<end_of_image>", "end_of_image_token")
+        self._add_special_token("<start_of_audio>", "start_of_audio_token")
+        self._add_special_token("<end_of_audio>", "end_of_audio_token")
+        # Additional special tokens for conversation and masking.
+        self._add_special_token("<start_of_turn>", "start_of_turn_token")
+        self._add_special_token("<end_of_turn>", "end_of_turn_token")
+        self._add_special_token("<mask>", "mask_token")
+        self._add_special_token("[multimodal]", "multimodal_token")
+        self.sequence_length = sequence_length
+        self.add_bos = add_bos
+        self.add_eos = add_eos
+
+    def vocabulary_size(self):
+        return len(self.vocabulary)
+
+    def get_vocabulary(self):
+        return self.vocabulary
+
+    def id_to_token(self, id):
+        return self.vocabulary[id]
+
+    def token_to_id(self, token):
+        return self.vocabulary.index(token)
+
+    @preprocessing_function
+    def tokenize(self, inputs):
+        inputs = tf.convert_to_tensor(inputs)
+        unbatched = inputs.shape.rank == 0
+        if unbatched:
+            inputs = tf.expand_dims(inputs, 0)
+        # Add spaces around special tokens for proper splitting.
+        inputs = tf.strings.regex_replace(
+            inputs, self.start_of_image_token, f" {self.start_of_image_token} "
+        )
+        inputs = tf.strings.regex_replace(
+            inputs, self.end_of_image_token, f" {self.end_of_image_token} "
+        )
+        inputs = tf.strings.regex_replace(
+            inputs, self.image_placeholder, f" {self.image_placeholder} "
+        )
+        inputs = tf.strings.regex_replace(
+            inputs, self.start_of_audio_token, f" {self.start_of_audio_token} "
+        )
+        inputs = tf.strings.regex_replace(
+            inputs, self.end_of_audio_token, f" {self.end_of_audio_token} "
+        )
+        inputs = tf.strings.regex_replace(
+            inputs, self.audio_placeholder, f" {self.audio_placeholder} "
+        )
+        # Collapse the doubled spaces introduced above.
+        inputs = tf.strings.regex_replace(inputs, "  ", " ")
+        sep_inputs = tf.strings.split(inputs, sep=" ")
+        tokens = self.string_to_id.lookup(sep_inputs)
+        if self.add_bos:
+            bos_tensor = tf.fill(
+                value=self.start_token_id,
+                dims=tokens.shape.as_list()[0:1] + [1],
+            )
+            tokens = tf.concat((bos_tensor, tokens), axis=-1)
+        if self.add_eos:
+            eos_tensor = tf.fill(
+                value=self.end_token_id, dims=tokens.shape.as_list()[0:1] + [1]
+            )
+            tokens = tf.concat((tokens, eos_tensor), axis=-1)
+        # Convert to a dense output if input was a scalar.
+        if unbatched:
+            tokens = tf.squeeze(tokens, 0)
+        return tokens
+
+    @preprocessing_function
+    def detokenize(self, inputs):
+        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
+        # tf-text sentencepiece does not handle int64.
+ inputs = tf.cast(inputs, "int32") + outputs = self.id_to_string.lookup(inputs) + outputs = tf.strings.reduce_join(outputs, axis=-1, separator=" ") + for token in [ + self.start_token, + self.end_token, + self.pad_token, + ]: + outputs = tf.strings.regex_replace(outputs, token, "") + outputs = tf.strings.strip(outputs) + if unbatched: + outputs = tf.squeeze(outputs, 0) + return outputs + + def __call__(self, inputs): + return self.tokenize(inputs) diff --git a/keras_hub/src/tests/test_data/gemma3n_test_vocab.spm b/keras_hub/src/tests/test_data/gemma3n_test_vocab.spm new file mode 100644 index 0000000000..75920d86be Binary files /dev/null and b/keras_hub/src/tests/test_data/gemma3n_test_vocab.spm differ diff --git a/tools/checkpoint_conversion/convert_gemma3n_checkpoints.py b/tools/checkpoint_conversion/convert_gemma3n_checkpoints.py new file mode 100644 index 0000000000..a3e29bdfaf --- /dev/null +++ b/tools/checkpoint_conversion/convert_gemma3n_checkpoints.py @@ -0,0 +1,876 @@ +import gc +import os +import types + +import keras +import numpy as np +import torch +from absl import app +from absl import flags +from PIL import Image +from transformers import Gemma3nForConditionalGeneration +from transformers import Gemma3nProcessor + +from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone +from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import ( + MobileAttention, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import EdgeResidual +from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import ( + UniversalInvertedResidual, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import ( + convert_arch_def_to_stackwise, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct + +PRESET_MAP = { + "gemma3n_e2b": "google/gemma-3n-E2B", +} +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) +flags.DEFINE_string( + "cache_dir", "./hf_cache", "Directory to cache Hugging Face downloads." 
+) +flags.mark_flag_as_required("preset") + + +MOBILENETV5_300M_ENC_ARCH_DEF = [ + # Stage 0: 128x128 in + [ + "er_r1_k3_s2_e4_c128", + "er_r1_k3_s1_e4_c128", + "er_r1_k3_s1_e4_c128", + ], + # Stage 1: 256x256 in + [ + "uir_r1_a3_k5_s2_e6_c256", + "uir_r1_a5_k0_s1_e4_c256", + "uir_r1_a3_k0_s1_e4_c256", + "uir_r1_a5_k0_s1_e4_c256", + "uir_r1_a3_k0_s1_e4_c256", + ], + # Stage 2: 640x640 in + [ + "uir_r1_a5_k5_s2_e6_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a0_k0_s1_e1_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + ], + # Stage 3: 1280x1280 in + [ + "uir_r1_a5_k5_s2_e6_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + ], +] +mobilenetv5_config = convert_arch_def_to_stackwise( + MOBILENETV5_300M_ENC_ARCH_DEF +) +mobilenetv5_config.update( + { + "stem_size": 64, + "num_features": 2048, + "norm_layer": "rms_norm", + "act_layer": "gelu", + "use_msfa": True, + "layer_scale_init_value": 1e-5, + } +) +MODEL_CONFIGS = {"mobilenetv5_300m_enc": mobilenetv5_config} + + +def convert_model(hf_config, dtype=None): + text_config = hf_config.text_config + vision_config = hf_config.vision_config + audio_config = hf_config.audio_config + vision_encoder_config = MODEL_CONFIGS["mobilenetv5_300m_enc"].copy() + vision_encoder_config["image_shape"] = (768, 768, 3) + if text_config.hidden_activation == "gelu_pytorch_tanh": + text_config.hidden_activation = "gelu_approximate" + gemma3n_backbone = 
Gemma3nBackbone( + text_vocab_size=text_config.vocab_size, + text_hidden_size=text_config.hidden_size, + num_hidden_layers=text_config.num_hidden_layers, + pad_token_id=0, + num_attention_heads=text_config.num_attention_heads, + num_key_value_heads=text_config.num_key_value_heads, + head_dim=text_config.head_dim, + intermediate_size=text_config.intermediate_size, + hidden_activation=text_config.hidden_activation, + layer_types=text_config.layer_types, + sliding_window=text_config.sliding_window, + rope_theta=text_config.rope_theta, + max_position_embeddings=text_config.max_position_embeddings, + vocab_size_per_layer_input=text_config.vocab_size_per_layer_input, + hidden_size_per_layer_input=text_config.hidden_size_per_layer_input, + altup_num_inputs=text_config.altup_num_inputs, + laurel_rank=text_config.laurel_rank, + attention_bias=text_config.attention_bias, + attention_dropout=text_config.attention_dropout, + rope_scaling=text_config.rope_scaling, + activation_sparsity_pattern=text_config.activation_sparsity_pattern, + altup_coef_clip=text_config.altup_coef_clip, + altup_active_idx=text_config.altup_active_idx, + altup_correct_scale=text_config.altup_correct_scale, + num_kv_shared_layers=text_config.num_kv_shared_layers, + vision_encoder_config=vision_encoder_config, + vision_hidden_size=vision_config.hidden_size, + vision_vocab_size=vision_config.vocab_size, + vision_vocab_offset=vision_config.vocab_offset, + vision_soft_tokens_per_image=hf_config.vision_soft_tokens_per_image, + image_token_id=hf_config.image_token_id, + audio_encoder_config=audio_config.to_dict(), + audio_hidden_size=audio_config.hidden_size, + audio_vocab_size=audio_config.vocab_size, + audio_vocab_offset=audio_config.vocab_offset, + audio_soft_tokens_per_image=hf_config.audio_soft_tokens_per_image, + audio_token_id=hf_config.audio_token_id, + rms_norm_eps=text_config.rms_norm_eps, + dtype=dtype, + ) + return gemma3n_backbone + + +class HfToKerasConverter: + def __init__(self, hf_model): + self.hf_state_dict = { + k: v for k, v in hf_model.state_dict().items() if "lm_head" not in k + } + + def _port_weights(self, layer_or_variable, hf_key, transpose_dims=None): + if hf_key not in self.hf_state_dict: + print(f"⚠️ Weight key not found in state_dict: {hf_key}") + return + weights = self.hf_state_dict[hf_key].cpu().float().numpy() + if transpose_dims: + weights = weights.transpose(transpose_dims) + + if hasattr(layer_or_variable, "assign"): + layer_or_variable.assign(weights) + return + + current_weights = layer_or_variable.get_weights() + if ( + not current_weights + and hasattr(layer_or_variable, "weights") + and not layer_or_variable.weights + ): + print( + f"⚠️ Keras layer {layer_or_variable.name} has no weights to " + "set. Skipping." 
+ ) + return + if len(current_weights) == 1: + layer_or_variable.set_weights([weights]) + elif len(current_weights) == 2: + bias_key = hf_key.replace(".weight", ".bias") + if bias_key in self.hf_state_dict: + bias = self.hf_state_dict[bias_key].cpu().numpy() + layer_or_variable.set_weights([weights, bias]) + else: + layer_or_variable.set_weights([weights, current_weights[1]]) + else: + print( + f"❌ Unexpected number of weights in layer " + f"{layer_or_variable.name}" + ) + + def _port_rms_norm(self, layer, hf_prefix): + key = f"{hf_prefix}.weight" + self._port_weights(layer, key) + + def _port_bn(self, layer, hf_prefix): + keys = [ + f"{hf_prefix}.weight", + f"{hf_prefix}.bias", + f"{hf_prefix}.running_mean", + f"{hf_prefix}.running_var", + ] + weights = [ + self.hf_state_dict[key].cpu().float().numpy() for key in keys + ] + layer.set_weights(weights) + + def _port_cna(self, cna_layer: ConvNormAct, hf_conv_prefix, hf_norm_prefix): + if isinstance(cna_layer.conv, keras.layers.DepthwiseConv2D): + self._port_weights( + cna_layer.conv, + f"{hf_conv_prefix}.weight", + transpose_dims=(2, 3, 0, 1), + ) + else: + self._port_weights( + cna_layer.conv, + f"{hf_conv_prefix}.weight", + transpose_dims=(2, 3, 1, 0), + ) + if f"{hf_norm_prefix}.running_mean" in self.hf_state_dict: + self._port_bn(cna_layer.norm, hf_norm_prefix) + else: + self._port_rms_norm(cna_layer.norm, hf_norm_prefix) + + def _port_attn(self, attn_layer, hf_attn_prefix): + self._port_weights( + attn_layer.query_layers[-1], + f"{hf_attn_prefix}.query.proj.weight", + (2, 3, 1, 0), + ) + if len(attn_layer.key_layers) > 1: + self._port_weights( + attn_layer.key_layers[0], + f"{hf_attn_prefix}.key.down_conv.weight", + (2, 3, 0, 1), + ) + key_norm_layer = attn_layer.key_layers[1] + if f"{hf_attn_prefix}.key.norm.running_mean" in self.hf_state_dict: + self._port_bn(key_norm_layer, f"{hf_attn_prefix}.key.norm") + else: + self._port_rms_norm( + key_norm_layer, f"{hf_attn_prefix}.key.norm" + ) + self._port_weights( + attn_layer.key_layers[-1], + f"{hf_attn_prefix}.key.proj.weight", + (2, 3, 1, 0), + ) + if len(attn_layer.value_layers) > 1: + self._port_weights( + attn_layer.value_layers[0], + f"{hf_attn_prefix}.value.down_conv.weight", + (2, 3, 0, 1), + ) + value_norm_layer = attn_layer.value_layers[1] + if ( + f"{hf_attn_prefix}.value.norm.running_mean" + in self.hf_state_dict + ): + self._port_bn(value_norm_layer, f"{hf_attn_prefix}.value.norm") + else: + self._port_rms_norm( + value_norm_layer, f"{hf_attn_prefix}.value.norm" + ) + self._port_weights( + attn_layer.value_layers[-1], + f"{hf_attn_prefix}.value.proj.weight", + (2, 3, 1, 0), + ) + self._port_weights( + attn_layer.output_proj_layers[-2], + f"{hf_attn_prefix}.output.proj.weight", + (2, 3, 1, 0), + ) + + def _port_vision_tower(self, keras_model): + print(" -> Porting vision tower (MobileNetV5)...") + backbone = keras_model.vision_encoder + hf_prefix = "model.vision_tower.timm_model" + + stem_layer = backbone.get_layer("conv_stem") + self._port_cna( + stem_layer, + f"{hf_prefix}.conv_stem.conv", + f"{hf_prefix}.conv_stem.bn", + ) + + block_layers = [ + layer + for layer in backbone.layers + if isinstance( + layer, + (EdgeResidual, UniversalInvertedResidual, MobileAttention), + ) + ] + block_counter = 0 + for stack_idx in range(len(backbone.stackwise_num_blocks)): + for block_idx_in_stage in range( + backbone.stackwise_num_blocks[stack_idx] + ): + block = block_layers[block_counter] + block_prefix = ( + f"{hf_prefix}.blocks.{stack_idx}.{block_idx_in_stage}" + ) + if isinstance(block, 
EdgeResidual): + self._port_cna( + block.conv_exp, + f"{block_prefix}.conv_exp", + f"{block_prefix}.bn1", + ) + self._port_cna( + block.conv_pwl, + f"{block_prefix}.conv_pwl", + f"{block_prefix}.bn2", + ) + elif isinstance(block, UniversalInvertedResidual): + if hasattr(block, "dw_start") and not isinstance( + block.dw_start, types.FunctionType + ): + self._port_cna( + block.dw_start, + f"{block_prefix}.dw_start.conv", + f"{block_prefix}.dw_start.bn", + ) + self._port_cna( + block.pw_exp, + f"{block_prefix}.pw_exp.conv", + f"{block_prefix}.pw_exp.bn", + ) + if hasattr(block, "dw_mid") and not isinstance( + block.dw_mid, types.FunctionType + ): + self._port_cna( + block.dw_mid, + f"{block_prefix}.dw_mid.conv", + f"{block_prefix}.dw_mid.bn", + ) + self._port_cna( + block.pw_proj, + f"{block_prefix}.pw_proj.conv", + f"{block_prefix}.pw_proj.bn", + ) + gamma_key = f"{block_prefix}.layer_scale.gamma" + if gamma_key in self.hf_state_dict: + self._port_weights(block.layer_scale, gamma_key) + elif isinstance(block, MobileAttention): + self._port_rms_norm(block.norm, f"{block_prefix}.norm") + gamma_key = f"{block_prefix}.layer_scale.gamma" + if gamma_key in self.hf_state_dict: + self._port_weights(block.layer_scale, gamma_key) + attn_prefix = f"{block_prefix}.attn" + self._port_attn(block.attn, attn_prefix) + block_counter += 1 + try: + msfa_layer = backbone.get_layer("msfa") + msfa_prefix = f"{hf_prefix}.msfa" + ffn = msfa_layer.ffn + self._port_cna( + ffn.pw_exp, + f"{msfa_prefix}.ffn.pw_exp.conv", + f"{msfa_prefix}.ffn.pw_exp.bn", + ) + self._port_cna( + ffn.pw_proj, + f"{msfa_prefix}.ffn.pw_proj.conv", + f"{msfa_prefix}.ffn.pw_proj.bn", + ) + self._port_rms_norm(msfa_layer.norm, f"{msfa_prefix}.norm") + except ValueError: + pass + + def _port_language_model(self, keras_model): + print(" -> Porting language model...") + lm = keras_model.language_model + hf_prefix = "model.language_model" + + self._port_weights( + lm.embed_tokens.embedding, f"{hf_prefix}.embed_tokens.weight" + ) + self._port_rms_norm(lm.final_normalization, f"{hf_prefix}.norm") + self._port_weights( + lm.embed_tokens_per_layer.embedding, + f"{hf_prefix}.embed_tokens_per_layer.weight", + ) + self._port_weights( + lm.per_layer_model_projection, + f"{hf_prefix}.per_layer_model_projection.weight", + transpose_dims=(1, 0), + ) + self._port_rms_norm( + lm.per_layer_projection_norm, + f"{hf_prefix}.per_layer_projection_norm", + ) + + for i, proj in enumerate(lm.altup_projections): + self._port_weights( + proj, + f"{hf_prefix}.altup_projections.{i}.weight", + transpose_dims=(1, 0), + ) + for i, proj in enumerate(lm.altup_unembed_projections): + self._port_weights( + proj, + f"{hf_prefix}.altup_unembed_projections.{i}.weight", + transpose_dims=(1, 0), + ) + + for i, layer in enumerate(lm.transformer_layers): + layer_prefix = f"{hf_prefix}.layers.{i}" + + # Attention + self._port_weights( + layer.attention.q_proj, + f"{layer_prefix}.self_attn.q_proj.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.attention.k_proj, + f"{layer_prefix}.self_attn.k_proj.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.attention.v_proj, + f"{layer_prefix}.self_attn.v_proj.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.attention.o_proj, + f"{layer_prefix}.self_attn.o_proj.weight", + transpose_dims=(1, 0), + ) + self._port_rms_norm( + layer.attention.q_norm, f"{layer_prefix}.self_attn.q_norm" + ) + self._port_rms_norm( + layer.attention.k_norm, f"{layer_prefix}.self_attn.k_norm" + ) + + # MLP + 
self._port_weights( + layer.mlp.gate_proj, + f"{layer_prefix}.mlp.gate_proj.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.mlp.up_proj, + f"{layer_prefix}.mlp.up_proj.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.mlp.down_proj, + f"{layer_prefix}.mlp.down_proj.weight", + transpose_dims=(1, 0), + ) + + # LayerNorms + self._port_rms_norm( + layer.input_layernorm, f"{layer_prefix}.input_layernorm" + ) + self._port_rms_norm( + layer.post_attention_layernorm, + f"{layer_prefix}.post_attention_layernorm", + ) + self._port_rms_norm( + layer.pre_feedforward_layernorm, + f"{layer_prefix}.pre_feedforward_layernorm", + ) + self._port_rms_norm( + layer.post_feedforward_layernorm, + f"{layer_prefix}.post_feedforward_layernorm", + ) + + # AltUp + altup_prefix = f"{layer_prefix}.altup" + self._port_weights( + layer.altup.correction_coefs, + f"{altup_prefix}.correction_coefs.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.altup.prediction_coefs, + f"{altup_prefix}.prediction_coefs.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.altup.modality_router, + f"{altup_prefix}.modality_router.weight", + transpose_dims=(1, 0), + ) + self._port_rms_norm( + layer.altup.router_norm, f"{altup_prefix}.router_norm" + ) + if layer.altup.altup_correct_scale: + self._port_weights( + layer.altup.correct_output_scale, + f"{altup_prefix}.correct_output_scale", + ) + + # Laurel + laurel_prefix = f"{layer_prefix}.laurel" + self._port_weights( + layer.laurel.linear_left, + f"{laurel_prefix}.linear_left.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.laurel.linear_right, + f"{laurel_prefix}.linear_right.weight", + transpose_dims=(1, 0), + ) + self._port_rms_norm( + layer.laurel.post_laurel_norm, + f"{laurel_prefix}.post_laurel_norm", + ) + + # Per-layer inputs + self._port_weights( + layer.per_layer_input_gate, + f"{layer_prefix}.per_layer_input_gate.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + layer.per_layer_projection, + f"{layer_prefix}.per_layer_projection.weight", + transpose_dims=(1, 0), + ) + self._port_rms_norm( + layer.post_per_layer_input_norm, + f"{layer_prefix}.post_per_layer_input_norm", + ) + + def _port_audio_tower(self, keras_model): + print(" -> Porting audio tower...") + audio_encoder = keras_model.audio_encoder + hf_prefix = "model.audio_tower" + + ssp = audio_encoder.subsample_conv_projection + ssp_prefix = f"{hf_prefix}.subsample_conv_projection" + self._port_weights( + ssp.conv_0.conv, + f"{ssp_prefix}.conv_0.conv.weight", + transpose_dims=(2, 3, 1, 0), + ) + self._port_weights( + ssp.conv_0.norm.scale, f"{ssp_prefix}.conv_0.norm.weight" + ) + self._port_weights( + ssp.conv_1.conv, + f"{ssp_prefix}.conv_1.conv.weight", + transpose_dims=(2, 3, 1, 0), + ) + self._port_weights( + ssp.conv_1.norm.scale, f"{ssp_prefix}.conv_1.norm.weight" + ) + self._port_weights( + ssp.input_proj_linear, + f"{ssp_prefix}.input_proj_linear.weight", + transpose_dims=(1, 0), + ) + + for i, block in enumerate(audio_encoder.conformer): + block_prefix = f"{hf_prefix}.conformer.{i}" + ffw_start_prefix = f"{block_prefix}.ffw_layer_start" + self._port_rms_norm( + block.ffw_layer_start.pre_layer_norm, + f"{ffw_start_prefix}.pre_layer_norm", + ) + self._port_weights( + block.ffw_layer_start.ffw_layer_1, + f"{ffw_start_prefix}.ffw_layer_1.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + block.ffw_layer_start.ffw_layer_2, + f"{ffw_start_prefix}.ffw_layer_2.weight", + transpose_dims=(1, 0), + ) + 
self._port_rms_norm( + block.ffw_layer_start.post_layer_norm, + f"{ffw_start_prefix}.post_layer_norm", + ) + + attn_prefix = f"{block_prefix}.attention" + self._port_rms_norm( + block.attention.pre_attn_norm, f"{attn_prefix}.pre_attn_norm" + ) + self._port_weights( + block.attention.attn.per_dim_scale, + f"{attn_prefix}.attn.per_dim_scale", + ) + self._port_weights( + block.attention.attn.relative_position_embedding.pos_proj, + f"{attn_prefix}.attn.relative_position_embedding.pos_proj.weight", # noqa: E501 + transpose_dims=(1, 0), + ) + self._port_weights( + block.attention.attn.q_proj, + f"{attn_prefix}.attn.q_proj.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + block.attention.attn.k_proj, + f"{attn_prefix}.attn.k_proj.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + block.attention.attn.v_proj, + f"{attn_prefix}.attn.v_proj.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + block.attention.post, + f"{attn_prefix}.post.weight", + transpose_dims=(1, 0), + ) + self._port_rms_norm( + block.attention.post_norm, f"{attn_prefix}.post_norm" + ) + + lconv_prefix = f"{block_prefix}.lconv1d" + self._port_rms_norm( + block.lconv1d.pre_layer_norm, f"{lconv_prefix}.pre_layer_norm" + ) + self._port_weights( + block.lconv1d.linear_start, + f"{lconv_prefix}.linear_start.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + block.lconv1d.depthwise_conv1d, + f"{lconv_prefix}.depthwise_conv1d.weight", + transpose_dims=(2, 0, 1), + ) + self._port_rms_norm( + block.lconv1d.conv_norm, f"{lconv_prefix}.conv_norm" + ) + self._port_weights( + block.lconv1d.linear_end, + f"{lconv_prefix}.linear_end.weight", + transpose_dims=(1, 0), + ) + + ffw_end_prefix = f"{block_prefix}.ffw_layer_end" + self._port_rms_norm( + block.ffw_layer_end.pre_layer_norm, + f"{ffw_end_prefix}.pre_layer_norm", + ) + self._port_weights( + block.ffw_layer_end.ffw_layer_1, + f"{ffw_end_prefix}.ffw_layer_1.weight", + transpose_dims=(1, 0), + ) + self._port_weights( + block.ffw_layer_end.ffw_layer_2, + f"{ffw_end_prefix}.ffw_layer_2.weight", + transpose_dims=(1, 0), + ) + self._port_rms_norm( + block.ffw_layer_end.post_layer_norm, + f"{ffw_end_prefix}.post_layer_norm", + ) + self._port_rms_norm(block.norm, f"{block_prefix}.norm") + + def _port_multimodal_embedders(self, keras_model): + print(" -> Porting multimodal embedders...") + vision_prefix = "model.embed_vision" + self._port_weights( + keras_model.embed_vision.embedding, + f"{vision_prefix}.embedding.weight", + ) + self._port_rms_norm( + keras_model.embed_vision.hard_embedding_norm, + f"{vision_prefix}.hard_embedding_norm", + ) + self._port_rms_norm( + keras_model.embed_vision.soft_embedding_norm, + f"{vision_prefix}.soft_embedding_norm", + ) + self._port_weights( + keras_model.embed_vision.embedding_projection, + f"{vision_prefix}.embedding_projection.weight", + transpose_dims=(1, 0), + ) + + audio_prefix = "model.embed_audio" + self._port_weights( + keras_model.embed_audio.embedding, + f"{audio_prefix}.embedding.weight", + ) + self._port_rms_norm( + keras_model.embed_audio.hard_embedding_norm, + f"{audio_prefix}.hard_embedding_norm", + ) + self._port_rms_norm( + keras_model.embed_audio.soft_embedding_norm, + f"{audio_prefix}.soft_embedding_norm", + ) + self._port_weights( + keras_model.embed_audio.embedding_projection, + f"{audio_prefix}.embedding_projection.weight", + transpose_dims=(1, 0), + ) + + def convert(self, keras_model: Gemma3nBackbone): + print("🔶 Starting weight conversion...") + self._port_vision_tower(keras_model) + 
self._port_language_model(keras_model) + self._port_audio_tower(keras_model) + self._port_multimodal_embedders(keras_model) + print("✅ Full backbone weights converted.") + + +def validate_output(keras_model, hf_model, hf_processor): + print("🔶 Validating model outputs...") + image_size = hf_processor.image_processor.size + image = Image.new("RGB", (image_size["width"], image_size["height"])) + sampling_rate = hf_processor.feature_extractor.sampling_rate + audio_data = np.zeros(int(sampling_rate * 2.0)) + text = f"A cat sat on a mat{hf_processor.image_token}\n{hf_processor.audio_token}" # noqa: E501 + hf_inputs = hf_processor( + text=text, + images=image, + audio=[audio_data], + return_tensors="pt", + padding="longest", + ) + print(" -> Running HF model forward pass...") + with torch.no_grad(): + hf_output = hf_model.model(**hf_inputs).last_hidden_state + hf_output = hf_output.detach().cpu().float().numpy() + print(f" -> HF model output shape: {hf_output.shape}") + keras_inputs = {k: v.numpy() for k, v in hf_inputs.items()} + backbone_keras_inputs = {} + backbone_keras_inputs["token_ids"] = keras_inputs.pop("input_ids") + backbone_keras_inputs["padding_mask"] = keras_inputs.pop( + "attention_mask" + ).astype(bool) + # Images. + pixel_values = keras_inputs.pop("pixel_values") + pixel_values_transposed = np.transpose(pixel_values, (0, 2, 3, 1)) + if pixel_values_transposed.ndim == 4: + pixel_values_transposed = np.expand_dims( + pixel_values_transposed, axis=1 + ) + backbone_keras_inputs["images"] = pixel_values_transposed + # Audio. + input_features = keras_inputs.pop("input_features") + input_features_mask = keras_inputs.pop("input_features_mask") + if input_features.ndim == 3: + input_features = np.expand_dims(input_features, axis=1) + if input_features_mask.ndim == 2: + input_features_mask = np.expand_dims(input_features_mask, axis=1) + backbone_keras_inputs["input_features"] = input_features + backbone_keras_inputs["input_features_mask"] = input_features_mask + print(" -> Running Keras model forward pass...") + keras_output = keras_model.predict(backbone_keras_inputs) + print(f" -> Keras model output shape: {keras_output.shape}") + mean_diff = np.mean(np.abs(keras_output - hf_output)) + print(f"🔶 Mean absolute difference: {mean_diff}") + + +def main(_): + preset = FLAGS.preset + hf_model_name = PRESET_MAP[preset] + cache_dir = FLAGS.cache_dir + save_path = preset + model_cache_path = os.path.join(cache_dir, f"{preset}_model") + processor_cache_path = os.path.join(cache_dir, f"{preset}_processor") + hf_model = None + hf_processor = None + if os.path.exists(model_cache_path) and os.path.exists( + processor_cache_path + ): + print( + " -> Loading cached Hugging Face model and processor from " + f"{cache_dir}" + ) + try: + hf_model = Gemma3nForConditionalGeneration.from_pretrained( + model_cache_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + hf_processor = Gemma3nProcessor.from_pretrained( + processor_cache_path + ) + except Exception as e: + print(f"⚠️ Failed to load from cache: {e}. 
Downloading again...")
+            hf_model = None
+            hf_processor = None
+    if hf_model is None or hf_processor is None:
+        print(f" -> Downloading Hugging Face model: {hf_model_name}")
+        hf_model = Gemma3nForConditionalGeneration.from_pretrained(
+            hf_model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
+        )
+        hf_processor = Gemma3nProcessor.from_pretrained(hf_model_name)
+        print(f"💾 Saving model and processor to cache: {cache_dir}")
+        os.makedirs(cache_dir, exist_ok=True)
+        hf_model.save_pretrained(model_cache_path)
+        hf_processor.save_pretrained(processor_cache_path)
+    hf_model.eval()
+    print("-> Creating Keras model from HF config.")
+    keras_model = convert_model(hf_model.config, dtype="bfloat16")
+    print("-> Converting weights from HF to Keras.")
+    converter = HfToKerasConverter(hf_model)
+    converter.convert(keras_model)
+    print("\n-> Validating output consistency.")
+    validate_output(keras_model, hf_model, hf_processor)
+    print(f"💾 Saving Keras preset to ./{save_path}")
+    keras_model.save_to_preset(f"./{save_path}")
+    print("🎉 Conversion complete.")
+    del hf_model
+    gc.collect()
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/tools/sentencepiece_testing/create_gemma3n_test_proto.py b/tools/sentencepiece_testing/create_gemma3n_test_proto.py
new file mode 100644
index 0000000000..ea3e317873
--- /dev/null
+++ b/tools/sentencepiece_testing/create_gemma3n_test_proto.py
@@ -0,0 +1,34 @@
+from tools.sentencepiece_testing.utils import train_sentencepiece
+
+
+def main():
+    train_sentencepiece(
+        ["the quick brown fox", "the earth is round"],
+        "gemma3n_test_vocab.spm",
+        vocab_size=21,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+        control_symbols=[
+            "<image_soft_token>",
+            "<audio_soft_token>",
+            "<start_of_image>",
+            "<end_of_image>",
+            "<start_of_audio>",
+            "<end_of_audio>",
+            "<start_of_turn>",
+            "<end_of_turn>",
+            "<mask>",
+            "[multimodal]",
+        ],
+    )
+
+
+if __name__ == "__main__":
+    main()
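For reviewers who want to exercise the new tokenizer locally, here is a minimal sketch. It assumes a repository checkout with the bundled `gemma3n_test_vocab.spm` test asset in place and that the proto defines the special tokens registered in `Gemma3nTokenizer.__init__`; it is illustrative only, not part of the change.

```python
import os

from keras_hub.src.models.gemma3n.gemma3n_tokenizer import Gemma3nTokenizer

# Assumed path to the test vocabulary shipped with this PR.
proto_path = os.path.join(
    "keras_hub", "src", "tests", "test_data", "gemma3n_test_vocab.spm"
)
tokenizer = Gemma3nTokenizer(proto=proto_path)

# Tokenize and detokenize the same strings the unit test uses.
token_ids = tokenizer(["the quick brown fox", "the earth is round"])
print(token_ids)
print(tokenizer.detokenize(token_ids))

# Special token ids registered in `__init__` are exposed as attributes.
print(tokenizer.start_token_id, tokenizer.end_token_id, tokenizer.pad_token_id)
```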
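The GQA helpers in `gemma3n_utils.py` are easiest to sanity-check with dummy tensors. A small sketch of the shape contract, using made-up dimensions (8 query heads sharing 2 key/value heads):

```python
import keras

from keras_hub.src.models.gemma3n.gemma3n_utils import eager_attention_forward
from keras_hub.src.models.gemma3n.gemma3n_utils import repeat_kv

batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 32
groups = num_heads // num_kv_heads  # 4 query heads per KV head.

query = keras.random.normal((batch, num_heads, seq_len, head_dim))
key = keras.random.normal((batch, num_kv_heads, seq_len, head_dim))
value = keras.random.normal((batch, num_kv_heads, seq_len, head_dim))

# `repeat_kv` tiles the KV heads so they line up with the query heads.
print(repeat_kv(key, groups).shape)  # (2, 8, 16, 32)

# `eager_attention_forward` returns the attention output transposed to
# `[batch, seq_len, num_heads, head_dim]`, plus the attention weights.
output, weights = eager_attention_forward(
    query, key, value, groups, head_dim, attention_mask=None
)
print(output.shape)   # (2, 16, 8, 32)
print(weights.shape)  # (2, 8, 16, 16)
```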
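And a quick sketch of the standalone `Gemma3nRMSNorm` layer, again with arbitrary shapes, to show the `with_scale` switch and config round-trip:

```python
import keras

from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm

x = keras.random.normal((2, 10, 64))

# With a learnable scale (the default), the layer owns one `(dim,)` weight.
norm = Gemma3nRMSNorm(dim=64, eps=1e-6)
print(norm(x).shape)  # (2, 10, 64)

# `with_scale=False` skips the weight and uses a fixed scale of 1.0.
norm_no_scale = Gemma3nRMSNorm(dim=64, with_scale=False)
print(norm_no_scale(x).shape)  # (2, 10, 64)

# The layer serializes through `get_config`, as usual for Keras layers.
print(Gemma3nRMSNorm.from_config(norm.get_config()).dim)  # 64
```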