From c7509b533a19845841399842f20559d2ffef8e00 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Mon, 20 Oct 2025 10:19:33 -0700 Subject: [PATCH 1/4] Add Gemma3 Conversion script to port weights from HF directly --- .../src/utils/transformers/convert_gemma3.py | 335 ++++++++++++++++++ .../src/utils/transformers/preset_loader.py | 3 + .../convert_gemma3_hf_checkpoints.py | 165 +++++++++ 3 files changed, 503 insertions(+) create mode 100644 keras_hub/src/utils/transformers/convert_gemma3.py create mode 100644 tools/checkpoint_conversion/convert_gemma3_hf_checkpoints.py diff --git a/keras_hub/src/utils/transformers/convert_gemma3.py b/keras_hub/src/utils/transformers/convert_gemma3.py new file mode 100644 index 0000000000..799252002e --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_gemma3.py @@ -0,0 +1,335 @@ +import numpy as np + +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( + Gemma3VisionEncoder, +) +from keras_hub.src.utils.preset_utils import get_file + +backbone_cls = Gemma3Backbone + + +def load_image_converter_config(transformers_config): + if "vision_config" in transformers_config: + image_size = transformers_config["vision_config"].get("image_size", 224) + return { + "image_size": (image_size, image_size), + "scale": 1 / 127.5, + "offset": -1.0, + } + else: + return None + + +def convert_backbone_config(transformers_config): + if transformers_config["model_type"] == "gemma3_text": + image_size = None + vision_encoder = None + transformer_config = transformers_config + else: + image_size = transformers_config["vision_config"].get("image_size", 224) + vision_encoder_config = { + "image_size": image_size, + "patch_size": transformers_config["vision_config"].get( + "patch_size", 16 + ), + "num_heads": transformers_config["vision_config"].get( + "num_attention_heads", 12 + ), + "hidden_dim": transformers_config["vision_config"].get( + "hidden_size", 768 + ), + "num_layers": transformers_config["vision_config"].get( + "num_hidden_layers", 12 + ), + "intermediate_dim": transformers_config["vision_config"].get( + "intermediate_size", 3072 + ), + "output_dim": 2560, + "pool_size": 4, + "layer_norm_epsilon": transformers_config["vision_config"].get( + "layer_norm_eps", 1e-6 + ), + } + vision_encoder = Gemma3VisionEncoder(**vision_encoder_config) + transformer_config = transformers_config["text_config"] + + return { + "vocabulary_size": transformer_config.get( + "vocab_size", 262144 if vision_encoder is None else 262208 + ), + "image_size": image_size, + "num_layers": transformer_config.get("num_hidden_layers", 26), + "num_query_heads": transformer_config.get("num_attention_heads", 8), + "num_key_value_heads": transformer_config.get("num_key_value_heads", 4), + "hidden_dim": transformer_config.get("hidden_size", 2304), + "intermediate_dim": transformer_config.get("intermediate_size", 9216), + "head_dim": transformer_config.get("head_dim", 256), + "use_post_ffw_norm": True, + "use_post_attention_norm": True, + "attention_logit_softcap": transformer_config.get( + "attn_logit_softcap", None + ), + "final_logit_softcap": transformer_config.get( + "final_logit_softcap", None + ), + "use_sliding_window_attention": True, + "query_head_dim_normalize": True, + "sliding_window_size": transformer_config.get("sliding_window", 4096), + "local_rope_scaling_factor": 1.0, + "global_rope_scaling_factor": ( + transformer_config.get("rope_scaling") or {} + ).get("factor", 1.0), + "layer_norm_epsilon": 
transformer_config.get("rms_norm_eps", 1e-6), + "use_bidirectional_attention": transformer_config.get( + "use_bidirectional_attention", False + ), + "vision_encoder": vision_encoder, + } + + +def convert_weights(backbone, loader, transformers_config): + if transformers_config["model_type"] == "gemma3_text": + prefix = "model" + else: + prefix = "language_model.model" + + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key=f"{prefix}.embed_tokens.weight", + ) + + def transpose(x, shape): + return np.transpose(x) + + vision_encoder = backbone.vision_encoder + if vision_encoder is not None: + image_encoder = vision_encoder.get_layer("image_encoder") + + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.patch_embedding.kernel, + hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.patch_embedding.bias, + hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.bias", + ) + + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.position_embedding.embeddings, + hf_weight_key="vision_tower.vision_model.embeddings.position_embedding.weight", + ) + + for i in range(image_encoder.num_layers): + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_1.gamma, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_1.beta, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[ + i + ].attn.query_proj.kernel, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.query_proj.bias, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.key_proj.kernel, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.key_proj.bias, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[ + i + ].attn.value_proj.kernel, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.value_proj.bias, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.out_proj.kernel, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.out_proj.bias, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias", + ) + + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_2.gamma, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_2.beta, + 
hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_1.kernel, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_1.bias, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_2.kernel, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_2.bias, + hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias", + ) + + loader.port_weight( + keras_variable=image_encoder.encoder_layer_norm.gamma, + hf_weight_key="vision_tower.vision_model.post_layernorm.weight", + ) + loader.port_weight( + keras_variable=image_encoder.encoder_layer_norm.beta, + hf_weight_key="vision_tower.vision_model.post_layernorm.bias", + ) + + loader.port_weight( + keras_variable=vision_encoder.get_layer( + "vision_output_encoder" + ).vision_soft_embedding_norm.scale, + hf_weight_key="multi_modal_projector.mm_soft_emb_norm.weight", + ) + + loader.port_weight( + keras_variable=vision_encoder.get_layer( + "vision_output_encoder" + ).vision_input_projection.kernel, + hf_weight_key="multi_modal_projector.mm_input_projection_weight", + ) + + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"decoder_block_{i}") + + loader.port_weight( + keras_variable=decoder_layer.pre_attention_norm.scale, + hf_weight_key=f"{prefix}.layers.{i}.input_layernorm.weight", + ) + loader.port_weight( + keras_variable=decoder_layer.post_attention_norm.scale, + hf_weight_key=f"{prefix}.layers.{i}.post_attention_layernorm.weight", + ) + loader.port_weight( + keras_variable=decoder_layer.pre_ffw_norm.scale, + hf_weight_key=f"{prefix}.layers.{i}.pre_feedforward_layernorm.weight", + ) + loader.port_weight( + keras_variable=decoder_layer.post_ffw_norm.scale, + hf_weight_key=f"{prefix}.layers.{i}.post_feedforward_layernorm.weight", + ) + + # Attention layers + + ## Query + loader.port_weight( + keras_variable=decoder_layer.attention.query_dense.kernel, + hf_weight_key=f"{prefix}.layers.{i}.self_attn.q_proj.weight", + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[0], keras_shape[2], keras_shape[1]), + ), + axes=(0, 2, 1), + ), + ) + loader.port_weight( + keras_variable=decoder_layer.attention.query_norm.scale, + hf_weight_key=f"{prefix}.layers.{i}.self_attn.q_norm.weight", + ) + ## Key + loader.port_weight( + keras_variable=decoder_layer.attention.key_dense.kernel, + hf_weight_key=f"{prefix}.layers.{i}.self_attn.k_proj.weight", + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[0], keras_shape[2], keras_shape[1]), + ), + axes=(0, 2, 1), + ), + ) + loader.port_weight( + keras_variable=decoder_layer.attention.key_norm.scale, + hf_weight_key=f"{prefix}.layers.{i}.self_attn.k_norm.weight", + ) + ## Value + loader.port_weight( + keras_variable=decoder_layer.attention.value_dense.kernel, + hf_weight_key=f"{prefix}.layers.{i}.self_attn.v_proj.weight", + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[0], keras_shape[2], keras_shape[1]), + ), + axes=(0, 2, 1), + ), + ) + ## Output + loader.port_weight( + 
keras_variable=decoder_layer.attention.output_dense.kernel, + hf_weight_key=f"{prefix}.layers.{i}.self_attn.o_proj.weight", + # rearrange_patterns="c (a b) -> a b c", + # rearrange_dims={"a": backbone.num_query_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[2], keras_shape[0], keras_shape[1]), + ), + axes=(1, 2, 0), + ), + ) + + # MLP layers + loader.port_weight( + keras_variable=decoder_layer.gating_ffw.kernel, + hf_weight_key=f"{prefix}.layers.{i}.mlp.gate_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=decoder_layer.gating_ffw_2.kernel, + hf_weight_key=f"{prefix}.layers.{i}.mlp.up_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=decoder_layer.ffw_linear.kernel, + hf_weight_key=f"{prefix}.layers.{i}.mlp.down_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + # Final normalization layer + loader.port_weight( + keras_variable=backbone.get_layer("final_normalization").scale, + hf_weight_key=f"{prefix}.norm.weight", + ) + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 73f6a27717..c527e39655 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -11,6 +11,7 @@ from keras_hub.src.utils.transformers import convert_distilbert from keras_hub.src.utils.transformers import convert_esm from keras_hub.src.utils.transformers import convert_gemma +from keras_hub.src.utils.transformers import convert_gemma3 from keras_hub.src.utils.transformers import convert_gpt2 from keras_hub.src.utils.transformers import convert_llama3 from keras_hub.src.utils.transformers import convert_mistral @@ -46,6 +47,8 @@ def __init__(self, preset, config): self.converter = convert_esm elif model_type in ("gemma", "gemma2"): self.converter = convert_gemma + elif model_type in ("gemma3", "gemma3_text"): + self.converter = convert_gemma3 elif model_type == "gpt2": self.converter = convert_gpt2 elif model_type == "llama": diff --git a/tools/checkpoint_conversion/convert_gemma3_hf_checkpoints.py b/tools/checkpoint_conversion/convert_gemma3_hf_checkpoints.py new file mode 100644 index 0000000000..0a16b7e93a --- /dev/null +++ b/tools/checkpoint_conversion/convert_gemma3_hf_checkpoints.py @@ -0,0 +1,165 @@ +import os +import random +import traceback + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Hide any CUDA devices + +import numpy as np +import torch +from absl import app +from absl import flags + +random.seed(123) +torch.manual_seed(123) +device = torch.device("cpu") +# Force PyTorch to use CPU +torch.set_default_device(device) + +from keras import ops # noqa: E402 +from transformers import AutoModelForCausalLM # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +import keras_hub # noqa: E402 + +PRESET_MAP = { + "gemma3_instruct_1b": "google/gemma-3-1b-it", + "gemma3_instruct_4b": "google/gemma-3-4b-it", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) + + +def test_model( + 
keras_hub_model, keras_hub_preprocessor, hf_model, hf_model_tokenizer +): + # First, test that the number of parameters match + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + assert keras_hub_params == hf_params + + # Test the outputs of both the models + hf_inputs = hf_model_tokenizer(["What is Keras?"], return_tensors="pt").to( + device + ) + hf_outputs = hf_model(**hf_inputs) + hf_output_logits = hf_outputs.logits.detach().cpu().float().numpy() + + keras_hub_inputs = keras_hub_preprocessor.generate_preprocess( + ["What is Keras?"], sequence_length=6 + ) + + keras_hub_output = keras_hub_model(keras_hub_inputs) + keras_hub_logits = keras_hub_model.token_embedding( + keras_hub_output, reverse=True + ) + keras_hub_logits = ops.convert_to_numpy(keras_hub_logits) + + try: + np.testing.assert_allclose( + keras_hub_logits, hf_output_logits, atol=1e-2 + ) + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + keras_hub_preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_output = keras_hub_preprocessor.generate_preprocess( + ["What is Keras?"], sequence_length=6 + ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output["token_ids"]) + + np.testing.assert_equal(keras_hub_output, hf_output) + + +def validate_output( + keras_model, + hf_model, + hf_tokenizer, +): + input_str = "What is Keras?" + length = 32 + + # KerasHub + keras_output = keras_model.generate([input_str], max_length=length) + keras_output = keras_output[0] + print("🔶 KerasHub output:", keras_output) + + hf_inputs = hf_tokenizer([input_str], return_tensors="pt") + outputs = hf_model.generate( + **hf_inputs, + max_length=length, + do_sample=False, + num_beams=1, + pad_token_id=hf_tokenizer.pad_token_id, + ) + print("🔶 Huggingface generated token ids:", outputs[0]) + hf_generated_text = hf_tokenizer.batch_decode( + outputs, skip_special_tokens=True + )[0] + print("🔶 Huggingface output:", hf_generated_text) + + +def main(_): + # === Get the preset name === + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. 
Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + + # === Load the Huggingface model === + hf_model = AutoModelForCausalLM.from_pretrained( + hf_preset, + device_map=device, + ) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") + hf_model.eval() + + keras_hub_backbone = keras_hub.models.Gemma3Backbone.from_preset( + f"hf://{hf_preset}" + ) + keras_hub_tokenizer = keras_hub.models.Gemma3Tokenizer.from_preset( + f"hf://{hf_preset}" + ) + keras_hub_preprocessor = ( + keras_hub.models.Gemma3CausalLMPreprocessor.from_preset( + f"hf://{hf_preset}" + ) + ) + + print("\n-> Huggingface model and tokenizer loaded") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) + test_model( + keras_hub_backbone, keras_hub_preprocessor, hf_model, hf_tokenizer + ) + print("\n-> Tests passed!") + + gemma3_lm = keras_hub.models.Gemma3CausalLM( + backbone=keras_hub_backbone, + preprocessor=keras_hub_preprocessor, + sampler="greedy", + ) + + validate_output(gemma3_lm, hf_model, hf_tokenizer) + gemma3_lm.save_to_preset(f"./{preset}") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) From b1b18845accb2532a9c72e0b6b653caa630c3195 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Mon, 20 Oct 2025 12:23:08 -0700 Subject: [PATCH 2/4] add load_image_converter changes --- keras_hub/src/utils/transformers/preset_loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index c527e39655..cd132466c7 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -115,5 +115,9 @@ def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs): return self.converter.convert_tokenizer(cls, self.preset, **kwargs) def load_image_converter(self, cls, **kwargs): + if hasattr(self.converter, "load_image_converter_config"): + config = self.converter.load_image_converter_config(self.config) + if config is not None: + return cls(**{**config, **kwargs}) # TODO: set image size for pali gemma checkpoints. return None From 936aff1974fa54d494c464c47a186bb42dac3968 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 23 Oct 2025 08:33:18 -0700 Subject: [PATCH 3/4] Enabling conditions for text-only models --- .../gemma3/gemma3_causal_lm_preprocessor.py | 11 ++++++-- .../src/models/gemma3/gemma3_tokenizer.py | 28 +++++++++++++------ .../src/utils/transformers/convert_gemma.py | 16 ++++++++++- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py index a60d095a2d..0296a5cc0a 100644 --- a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +++ b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py @@ -283,9 +283,14 @@ def __init__( # is `None`. 
         self.text_only_model = self.image_converter is None
 
-        self.image_placeholder = self.tokenizer.image_placeholder
-        self.start_of_image_token = self.tokenizer.start_of_image_token
-        self.end_of_image_token = self.tokenizer.end_of_image_token
+        if self.text_only_model:
+            self.image_placeholder = None
+            self.start_of_image_token = None
+            self.end_of_image_token = None
+        else:
+            self.image_placeholder = self.tokenizer.image_placeholder
+            self.start_of_image_token = self.tokenizer.start_of_image_token
+            self.end_of_image_token = self.tokenizer.end_of_image_token
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
diff --git a/keras_hub/src/models/gemma3/gemma3_tokenizer.py b/keras_hub/src/models/gemma3/gemma3_tokenizer.py
index 0dfad50c08..4904c0e20d 100644
--- a/keras_hub/src/models/gemma3/gemma3_tokenizer.py
+++ b/keras_hub/src/models/gemma3/gemma3_tokenizer.py
@@ -77,20 +77,32 @@ class Gemma3Tokenizer(SentencePieceTokenizer):
 
     backbone_cls = Gemma3Backbone
 
-    def __init__(self, proto, **kwargs):
+    def __init__(self, proto, is_vision_model=True, **kwargs):
         # Add special tokens.
+        self.is_vision_model = is_vision_model
 
         # The usual tokens.
         self._add_special_token("<bos>", "start_token")
         self._add_special_token("<eos>", "end_token")
         self._add_special_token("<pad>", "pad_token")
 
-        # Image placeholder token.
-        self._add_special_token("<img>", "image_placeholder")
-
-        # Some tokens which are used in the preprocessor. We need to keep them
-        # here so that the preprocessor works with `tf.data`.
-        self._add_special_token("<start_of_image>", "start_of_image_token")
-        self._add_special_token("<end_of_image>", "end_of_image_token")
+        if is_vision_model:
+            # Image placeholder token.
+            self._add_special_token("<img>", "image_placeholder")
+            # Some tokens which are used in the preprocessor.
+            # We need to keep them
+            # here so that the preprocessor works with `tf.data`.
+            self._add_special_token("<start_of_image>", "start_of_image_token")
+            self._add_special_token("<end_of_image>", "end_of_image_token")
+        else:
+            # For text-only models, set the image token IDs to -1.
+            self.start_of_image_token_id = -1
+            self.image_placeholder_token_id = -1
+            self.end_of_image_token_id = -1
 
         super().__init__(proto=proto, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"is_vision_model": self.is_vision_model})
+        return config
diff --git a/keras_hub/src/utils/transformers/convert_gemma.py b/keras_hub/src/utils/transformers/convert_gemma.py
index d6ef141128..4c235c044c 100644
--- a/keras_hub/src/utils/transformers/convert_gemma.py
+++ b/keras_hub/src/utils/transformers/convert_gemma.py
@@ -1,4 +1,5 @@
 import numpy as np
+from sentencepiece import SentencePieceProcessor
 
 from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.utils.preset_utils import get_file
@@ -157,4 +158,17 @@ def convert_weights(backbone, loader, transformers_config):
 
 
 def convert_tokenizer(cls, preset, **kwargs):
-    return cls(get_file(preset, "tokenizer.model"), **kwargs)
+    proto = get_file(preset, "tokenizer.model")
+    sp = SentencePieceProcessor()
+    if isinstance(proto, bytes):
+        sp.LoadFromSerializedProto(proto)
+    else:
+        sp.load(proto)
+
+    is_vision_model = (
+        sp.PieceToId("<img>") != sp.unk_id()
+        and sp.PieceToId("<start_of_image>") != sp.unk_id()
+        and sp.PieceToId("<end_of_image>") != sp.unk_id()
+    )
+
+    return cls(proto, is_vision_model=is_vision_model, **kwargs)

From f88f56ec4aaf39b1c4c28f4bdee660ce52cab72a Mon Sep 17 00:00:00 2001
From: Laxma Reddy Patlolla
Date: Thu, 23 Oct 2025 08:41:39 -0700
Subject: [PATCH 4/4] Move tokenizer logic to gemma3 file

---
 .../src/utils/transformers/convert_gemma.py  | 16 +---------------
 .../src/utils/transformers/convert_gemma3.py | 16 +++++++++++++++-
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/keras_hub/src/utils/transformers/convert_gemma.py b/keras_hub/src/utils/transformers/convert_gemma.py
index 4c235c044c..d6ef141128 100644
--- a/keras_hub/src/utils/transformers/convert_gemma.py
+++ b/keras_hub/src/utils/transformers/convert_gemma.py
@@ -1,5 +1,4 @@
 import numpy as np
-from sentencepiece import SentencePieceProcessor
 
 from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.utils.preset_utils import get_file
@@ -158,17 +157,4 @@ def convert_weights(backbone, loader, transformers_config):
 
 
 def convert_tokenizer(cls, preset, **kwargs):
-    proto = get_file(preset, "tokenizer.model")
-    sp = SentencePieceProcessor()
-    if isinstance(proto, bytes):
-        sp.LoadFromSerializedProto(proto)
-    else:
-        sp.load(proto)
-
-    is_vision_model = (
-        sp.PieceToId("<img>") != sp.unk_id()
-        and sp.PieceToId("<start_of_image>") != sp.unk_id()
-        and sp.PieceToId("<end_of_image>") != sp.unk_id()
-    )
-
-    return cls(proto, is_vision_model=is_vision_model, **kwargs)
+    return cls(get_file(preset, "tokenizer.model"), **kwargs)
diff --git a/keras_hub/src/utils/transformers/convert_gemma3.py b/keras_hub/src/utils/transformers/convert_gemma3.py
index 799252002e..4b7fb59bad 100644
--- a/keras_hub/src/utils/transformers/convert_gemma3.py
+++ b/keras_hub/src/utils/transformers/convert_gemma3.py
@@ -1,4 +1,5 @@
 import numpy as np
+from sentencepiece import SentencePieceProcessor
 
 from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
 from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
     Gemma3VisionEncoder,
 )
 from keras_hub.src.utils.preset_utils import get_file
@@ -332,4 +333,17 @@ def transpose(x, shape):
 
 
 def convert_tokenizer(cls, preset, **kwargs):
-    return cls(get_file(preset, "tokenizer.model"), **kwargs)
+    proto = get_file(preset, "tokenizer.model")
+    sp = SentencePieceProcessor()
+    if isinstance(proto, bytes):
+        sp.LoadFromSerializedProto(proto)
+    else:
+        sp.load(proto)
+
+    is_vision_model = (
+        sp.PieceToId("<img>") != sp.unk_id()
+        and sp.PieceToId("<start_of_image>") != sp.unk_id()
+        and sp.PieceToId("<end_of_image>") != sp.unk_id()
+    )
+
+    return cls(proto, is_vision_model=is_vision_model, **kwargs)
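
With these converters registered, a Gemma3 checkpoint on the Hugging Face Hub can be loaded directly through the hf:// scheme; the tool above does the same and additionally checks parameter counts, logits, and generation before saving a local preset (run as python tools/checkpoint_conversion/convert_gemma3_hf_checkpoints.py --preset gemma3_instruct_1b). A minimal usage sketch of the direct path, assuming Hub access and an accepted Gemma license, mirroring what the conversion script does:

    import keras_hub

    # Build the KerasHub model straight from the Hugging Face checkpoint.
    backbone = keras_hub.models.Gemma3Backbone.from_preset(
        "hf://google/gemma-3-1b-it"
    )
    preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
        "hf://google/gemma-3-1b-it"
    )
    gemma3_lm = keras_hub.models.Gemma3CausalLM(
        backbone=backbone, preprocessor=preprocessor
    )
    print(gemma3_lm.generate(["What is Keras?"], max_length=32))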
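
The is_vision_model flag added to Gemma3Tokenizer can also be passed when constructing the tokenizer by hand, for example with a SentencePiece proto that lacks the image special tokens. A small sketch; the proto path is a placeholder:

    import keras_hub

    # Hypothetical text-only proto: <img>, <start_of_image>, and <end_of_image>
    # are not registered, matching the text-only branch above.
    tokenizer = keras_hub.models.Gemma3Tokenizer(
        proto="path/to/gemma3_text_tokenizer.model",
        is_vision_model=False,
    )
    token_ids = tokenizer("What is Keras?")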