Gemma3 text keras hf checkpoint conversion #2433
base: master
Conversation
Summary of Changes
Hello @kharshith-k, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the Gemma3 checkpoint conversion tool by introducing the capability to export Keras models into the Hugging Face Safetensors format.
Code Review
This pull request adds a valuable feature to convert Keras Gemma3 models to the Hugging Face Safetensors format, enhancing interoperability. The implementation is comprehensive, covering configuration conversion, weight porting, and a validation step. I've provided a few suggestions to improve code clarity, maintainability, and adherence to the repository's style guide, primarily by improving docstrings, refactoring duplicated code, and ensuring deterministic validation.
Thanks for the PR! The export to safetensors should be made available here: https://github.com/keras-team/keras-hub/tree/master/keras_hub/src/utils/transformers/export.
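For reference, once the exporter lives in that module it is reachable through the task-level export_to_transformers method (also suggested in the review below). A minimal sketch; the preset name and output directory are illustrative assumptions, not from this PR:

import keras_hub

# Minimal sketch of the library-level export path suggested above.
# The preset name and output directory are illustrative assumptions.
causal_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
causal_lm.export_to_transformers("./gemma3_hf_export")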
/gemini review
Code Review
This pull request introduces functionality to export Gemma3 models from KerasHub to the Hugging Face format. It adds the necessary export logic, corresponding tests, and integrates this into the checkpoint conversion script. The overall approach is good, with solid testing. However, there's a critical issue in tools/checkpoint_conversion/convert_gemma3_checkpoints.py where the export logic is duplicated instead of reusing the newly added library functions. This violates the DRY principle and the repository's style guide on backend-agnostic code. Additionally, there are some areas for improvement in the core export logic in keras_hub/src/utils/transformers/export/gemma3.py concerning code duplication and incorrect fallback logic for normalization layers.
# Imports this snippet relies on (declared at the top of the script):
import json
import os
import shutil

import numpy as np
import torch
import transformers
from safetensors.torch import save_file


def convert_to_hf_config(keras_config):
    """Convert a Keras Gemma3 config to a Hugging Face `Gemma3TextConfig`.

    Args:
        keras_config: A Keras Gemma3 config object from the backbone.

    Returns:
        A `transformers.Gemma3TextConfig` instance.
    """
    hf_config = transformers.Gemma3TextConfig(
        vocab_size=keras_config.vocabulary_size,
        num_hidden_layers=keras_config.num_layers,
        num_attention_heads=keras_config.num_query_heads,
        num_key_value_heads=keras_config.num_key_value_heads,
        hidden_size=keras_config.hidden_dim,
        intermediate_size=keras_config.intermediate_dim,
        head_dim=keras_config.head_dim,
        max_position_embeddings=32768,
    )
    return hf_config


def export_to_hf(backbone, keras_tokenizer, path):
    """Convert a Keras Gemma3 model to Hugging Face format and save it.

    Args:
        backbone: A `keras_hub.models.Gemma3Backbone` instance.
        keras_tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
        path: str. The path to save the Hugging Face model to.
    """
    hf_config = convert_to_hf_config(backbone)
    weights_dict = {}

    def to_torch(weight):
        # `torch.from_numpy` cannot handle bfloat16 arrays, so convert
        # bfloat16 weights to float32 first, then back to bfloat16 in torch.
        if hasattr(weight.dtype, "name") and "bfloat16" in str(weight.dtype):
            weight = np.array(weight, dtype=np.float32)
        return torch.from_numpy(weight).to(torch.bfloat16)

    # Token embeddings.
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = to_torch(token_embedding)

    for i in range(backbone.num_layers):
        block = backbone.get_layer(f"decoder_block_{i}")

        # Attention projections. The Keras einsum kernels are permuted and
        # reshaped into the 2D (out_features, in_features) layout that the
        # Hugging Face `nn.Linear` weights expect.
        q_kernel = block.attention.query_dense.get_weights()[0]
        q_kernel = (
            torch.from_numpy(np.array(q_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = q_kernel

        k_kernel = block.attention.key_dense.get_weights()[0]
        k_kernel = (
            torch.from_numpy(np.array(k_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = k_kernel

        v_kernel = block.attention.value_dense.get_weights()[0]
        v_kernel = (
            torch.from_numpy(np.array(v_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = v_kernel

        # The output projection maps the per-head dimensions back to
        # hidden_dim, so it uses a different permutation and no transpose.
        o_kernel = block.attention.output_dense.get_weights()[0]
        o_kernel = (
            torch.from_numpy(np.array(o_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(2, 0, 1)
            .reshape(backbone.hidden_dim, -1)
        )
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = o_kernel

        # Query/key norm scales are copied over directly.
        q_norm = block.attention.query_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.self_attn.q_norm.weight"] = to_torch(
            q_norm
        )

        k_norm = block.attention.key_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.self_attn.k_norm.weight"] = to_torch(
            k_norm
        )

        # MLP kernels only need a transpose to match `nn.Linear`.
        gate_kernel = block.gating_ffw.get_weights()[0]
        gate_kernel = (
            torch.from_numpy(np.array(gate_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        up_kernel = block.gating_ffw_2.get_weights()[0]
        up_kernel = (
            torch.from_numpy(np.array(up_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

        down_kernel = block.ffw_linear.get_weights()[0]
        down_kernel = (
            torch.from_numpy(np.array(down_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

        # Pre/post attention and feedforward layer norms.
        input_layer_norm = block.pre_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = to_torch(
            input_layer_norm
        )

        post_attn_norm = block.post_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            to_torch(post_attn_norm)
        )

        pre_feedforward_layernorm_weight = block.pre_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.pre_feedforward_layernorm.weight"] = (
            to_torch(pre_feedforward_layernorm_weight)
        )

        post_feedforward_layernorm_weight = block.post_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_feedforward_layernorm.weight"] = (
            to_torch(post_feedforward_layernorm_weight)
        )

    # Final norm and the tied lm_head (shares the embedding weights).
    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = to_torch(final_norm)
    weights_dict["lm_head.weight"] = weights_dict[
        "model.embed_tokens.weight"
    ].clone()

    # Write the config, safetensors weights, and tokenizer assets.
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "config.json"), "w") as f:
        json.dump(hf_config.to_dict(), f)
    weights_dict = {k: v.contiguous() for k, v in weights_dict.items()}
    save_file(weights_dict, os.path.join(path, "model.safetensors"))
    keras_tokenizer.save_assets(path)
    # KerasHub saves the sentencepiece model as `vocabulary.spm`; Hugging
    # Face expects `tokenizer.model`.
    vocab_spm = os.path.join(path, "vocabulary.spm")
    tokenizer_model = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm):
        shutil.move(vocab_spm, tokenizer_model)
    print("Export complete! Files saved in:", path)
The functions convert_to_hf_config and export_to_hf duplicate the Hugging Face export logic that is already being added in keras_hub/src/utils/transformers/export/. This introduces significant code duplication and makes future maintenance difficult.
This implementation also uses torch and numpy directly for tensor manipulations, which violates the repository's style guide principle of being backend-agnostic.[1]
Please remove these duplicated functions and instead use the export_to_transformers method available on the Keras model. The logic in the main function at line 780 should be updated to call this method. For example:
# In main()
# ...
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor(tokenizer=keras_tokenizer)
causal_lm = keras_hub.models.Gemma3CausalLM(
backbone=keras_model,
preprocessor=preprocessor,
)
causal_lm.export_to_transformers(export_dir)
# ...

Style Guide References
Footnotes
[1] All code must be backend-agnostic. The duplicated code uses torch-specific operations, violating this principle.
Here's the Colab Gist link for Gemma3-Text model conversion from Keras to the safetensors format.