# [WIP][models] Add support for Google's Gemma3 models (#61)
# Gemma 3 Implementation - Changes Documentation

## Overview
This document details all changes made to add Google Gemma 3 model support to GPULlama3.java.

**Date**: November 1, 2025
**Model**: Google Gemma 3 (1B, 4B, 12B, 27B variants)
**Status**: Implementation complete, debugging in progress

---
## Architecture Details

### Gemma 3 Unique Features

1. **Sandwich Normalization**: 4 normalization layers per block (vs. 2 in standard transformers)
   - `attn_norm` → Attention → `post_attention_norm` → Residual
   - `ffn_norm` → FFN → `post_ffw_norm` → Residual

2. **Q/K Normalization**: per-head normalization of the query and key vectors within attention

3. **Embedding Scaling**: embeddings are multiplied by √dim for numerical stability

4. **Byte-Level Tokenization**: the first 256 tokens (type 3) are raw bytes, stored as `<unused0>` to `<unused255>` in the vocabulary

5. **SentencePiece Tokenizer**: uses the ▁ character (U+2581) to represent spaces

---
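The embedding scaling and the sandwich-normalized block wiring described above can be sketched with plain float arrays. This is a minimal illustration under assumptions, not the project's actual API: `rmsNorm` and `sandwichSublayer` are hypothetical names, and an identity function stands in for the real attention/FFN sublayers.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

public class SandwichBlockSketch {
    // RMSNorm: x[i] * weight[i] / sqrt(mean(x^2) + eps)
    static float[] rmsNorm(float[] x, float[] weight, float eps) {
        float ss = 0f;
        for (float v : x) ss += v * v;
        float scale = (float) (1.0 / Math.sqrt(ss / x.length + eps));
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) out[i] = x[i] * scale * weight[i];
        return out;
    }

    // One sandwich-normalized sublayer: pre-norm -> sublayer -> post-norm -> residual add.
    // Gemma 3 applies this twice per block (attention, then FFN), giving 4 norms per block.
    static float[] sandwichSublayer(float[] x, float[] preW, float[] postW,
                                    UnaryOperator<float[]> sublayer, float eps) {
        float[] residual = x.clone();          // save residual
        float[] h = rmsNorm(x, preW, eps);     // attn_norm / ffn_norm
        h = sublayer.apply(h);                 // attention or FFN (identity here)
        h = rmsNorm(h, postW, eps);            // post_attention_norm / post_ffw_norm
        for (int i = 0; i < h.length; i++) h[i] += residual[i];
        return h;
    }

    public static void main(String[] args) {
        int dim = 4;
        float[] x = {1f, 2f, 3f, 4f};
        // Embedding scaling: multiply the embedding by sqrt(dim)
        float scale = (float) Math.sqrt(dim);
        for (int i = 0; i < dim; i++) x[i] *= scale;

        float[] ones = new float[dim];
        Arrays.fill(ones, 1f);
        float[] out = sandwichSublayer(x, ones, ones, v -> v, 1e-6f);
        System.out.println(Arrays.toString(out));
    }
}
```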
## Files Created

### 1. Model Configuration
**File**: `src/main/java/org/beehive/gpullama3/model/gemma3/Gemma3Configuration.java`
```java
public record Gemma3Configuration(
        int dim, int hiddenDim, int numberOfLayers, int numberOfHeads,
        int numberOfKeyValueHeads, int numberOfHeadsKey, int numberOfHeadsValue,
        int vocabularySize, int contextLengthModel, int contextLength,
        boolean sharedWeights, float rmsNormEps, float ropeTheta
) implements Configuration { }
```
- Compatible with the Qwen3 structure (includes the `numberOfHeadsKey`/`numberOfHeadsValue` fields)
- Supports a 128K context window
### 2. Model State
**File**: `src/main/java/org/beehive/gpullama3/inference/state/Gemma3State.java`
- Manages the KV cache and inference buffers
- Extends the base `State` class
### 3. Main Model Class
**File**: `src/main/java/org/beehive/gpullama3/model/gemma3/Gemma3.java`
```java
@Override
public void forward(State state, int token, int position) {
    if (plan == null) {
        InferenceCore.forwardJavaGemma3(this, state, token, position);
    } else {
        InferenceCore.forwardTornadoVM(this, state, token, position, tornadoVMPlan());
    }
}
```
- Routes to the Gemma3-specific CPU inference path or to the Qwen3 GPU planner
### 4. Tokenizer Implementation
**File**: `src/main/java/org/beehive/gpullama3/tokenizer/impl/Gemma3Tokenizer.java`

**Key Features**:
- Loads token types from metadata (`tokenizer.ggml.token_type`)
- Distinguishes between byte tokens (type 3) and regular tokens (type 6)
- Special token detection excludes the `<unusedNN>` and `<0xHH>` patterns
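The exclusion rule in the last bullet can be sketched as a standalone predicate. The method name and the exact pattern set are assumptions for illustration, not the PR's actual code:

```java
public class SpecialTokenSketch {
    // A token like <start_of_turn> counts as "special", but <unused17>
    // (a byte-token alias) and <0x0A> (a hex byte token) must not.
    static boolean looksSpecial(String tokenString) {
        if (!tokenString.startsWith("<") || !tokenString.endsWith(">")) return false;
        if (tokenString.matches("<unused\\d+>")) return false;        // byte-token aliases
        if (tokenString.matches("<0x[0-9a-fA-F]{2}>")) return false;  // hex byte tokens
        return true;
    }

    public static void main(String[] args) {
        System.out.println(looksSpecial("<start_of_turn>")); // true
        System.out.println(looksSpecial("<unused42>"));      // false
        System.out.println(looksSpecial("<0x0A>"));          // false
        System.out.println(looksSpecial("hello"));           // false
    }
}
```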
**Critical Decoder Logic**:
```java
@Override
public String decode(List<Integer> tokens) {
    StringBuilder sb = new StringBuilder();
    for (int token : tokens) {
        // Type 3: byte tokens (IDs 0-255) - decode as raw bytes
        if (tokenTypes != null && tokenTypes[token] == 3) {
            sb.append((char) token);
            continue;
        }

        String tokenString = vocabulary.get(token);

        // Hex byte tokens like <0x12>
        if (tokenString.matches("<0x[0-9a-fA-F]{2}>")) {
            String code = tokenString.substring(3, tokenString.length() - 1);
            int byteValue = Integer.parseInt(code, 16);
            tokenString = Character.toString(byteValue);
        } else if (isSpecialToken(token)) {
            continue; // skip special tokens
        } else {
            // SentencePiece: ▁ -> space
            tokenString = tokenString.replace('▁', ' ');
        }
        sb.append(tokenString);
    }
    return sb.toString();
}
```
### 5. Chat Format
**File**: `src/main/java/org/beehive/gpullama3/model/format/Gemma3ChatFormat.java`

**Template Format**:
```
<bos><start_of_turn>user
{user_message}<end_of_turn>
<start_of_turn>model
{model_message}<end_of_turn>
```

**Stop Tokens**: `<end_of_turn>`, `<eos>`
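Applying the template above to a single user turn can be sketched as follows; `userTurn` is a hypothetical helper, not the PR's actual `Gemma3ChatFormat` API:

```java
public class Gemma3PromptSketch {
    // Build a generation prompt for one user message per the template above:
    // the turn ends after "<start_of_turn>model\n" so the model fills in its reply.
    static String userTurn(String userMessage) {
        return "<bos><start_of_turn>user\n"
             + userMessage + "<end_of_turn>\n"
             + "<start_of_turn>model\n";
    }

    public static void main(String[] args) {
        System.out.print(userTurn("Tell me a joke"));
    }
}
```

Generation would then run until one of the stop tokens (`<end_of_turn>`, `<eos>`) is produced.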
### 6. Weight Classes

#### CPU Weights Base Class
**File**: `src/main/java/org/beehive/gpullama3/inference/weights/standard/StandardWeightsWithQKNorm.java`
```java
public abstract class StandardWeightsWithQKNorm extends StandardWeights {
    public final FloatTensor[] attnKNorm, attnQNorm;
    // constructor initializing the final fields omitted for brevity
}
```
#### Gemma3 CPU Weights
**File**: `src/main/java/org/beehive/gpullama3/inference/weights/standard/Gemma3StandardWeights.java`
```java
public class Gemma3StandardWeights extends StandardWeightsWithQKNorm {
    public final FloatTensor[] postAttentionNorm; // post-attention normalization
    public final FloatTensor[] postFFNNorm;       // post-FFN normalization
    // constructor initializing the final fields omitted for brevity
}
```
#### Gemma3 GPU Weights
**File**: `src/main/java/org/beehive/gpullama3/inference/weights/tornado/Gemma3TornadoWeights.java`
```java
public class Gemma3TornadoWeights extends FP16Weights {
    public FloatArray[] rms_att_KNormLayered;
    public FloatArray[] rms_att_QNormLayered;
    public FloatArray[] postAttentionNormLayered;
    public FloatArray[] postFFNNormLayered;
}
```
### 7. Model Loader
**File**: `src/main/java/org/beehive/gpullama3/model/loader/Gemma3ModelLoader.java`

**Metadata Prefix Detection**:
```java
// Tries: gemma3. -> gemma2. -> gemma. -> llama.
if (metadata.containsKey("gemma3.embedding_length")) {
    prefix = "gemma3.";
} else if (metadata.containsKey("gemma2.embedding_length")) {
    prefix = "gemma2.";
} // ... remaining fallbacks to "gemma." and "llama." elided
```
**Tensor Loading** (4 norm layers per block):
```java
loadArrayOfQuantized(config.numberOfLayers(),
        i -> tensorEntries.get("blk." + i + ".attn_norm.weight"));
loadArrayOfQuantized(config.numberOfLayers(),
        i -> tensorEntries.get("blk." + i + ".post_attention_norm.weight"));
loadArrayOfQuantized(config.numberOfLayers(),
        i -> tensorEntries.get("blk." + i + ".ffn_norm.weight"));
loadArrayOfQuantized(config.numberOfLayers(),
        i -> tensorEntries.get("blk." + i + ".post_ffw_norm.weight"));
```

---
## Files Modified

### 1. Model Type Enum
**File**: `src/main/java/org/beehive/gpullama3/model/ModelType.java`

**Added**:
```java
GEMMA_3 {
    @Override
    public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength,
                           boolean loadWeights, boolean useTornadovm) {
        return new Gemma3ModelLoader(fileChannel, gguf, contextLength,
                loadWeights, useTornadovm).loadModel();
    }
}
```
### 2. Model Detection
**File**: `src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java`

**Added**:
```java
else if (lowerName.contains("gemma")) {
    return ModelType.GEMMA_3;
}
```
### 3. Inference Core
**File**: `src/main/java/org/beehive/gpullama3/inference/InferenceCore.java`

**Added Method**: `forwardJavaGemma3()` (~150 lines)

**Key Implementation Details**:
```java
// Embedding scaling
float embeddingScale = (float) Math.sqrt(dim);
for (int i = 0; i < dim; i++) {
    state.x.setFloat(i, state.x.getFloat(i) * embeddingScale);
}

for (int l = 0; l < config.numberOfLayers(); l++) {
    final int curLayer = l;

    // ATTENTION BLOCK with sandwich normalization
    state.x.copyTo(0, state.xb2, 0, dim); // save residual
    rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], ...);

    // ... QKV matmuls, Q/K norm, RoPE, attention ...

    weights.wo[l].matmul(state.xb, state.x, ...);
    rmsnorm(state.x, state.x, weights.postAttentionNorm[curLayer], ...); // POST-NORM
    state.x.addInPlace(state.xb2); // residual

    // FFN BLOCK with sandwich normalization
    state.x.copyTo(0, state.xb2, 0, dim); // save residual
    rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], ...);

    // ... FFN computation ...

    rmsnorm(state.x, state.x, weights.postFFNNorm[curLayer], ...); // POST-NORM
    state.x.addInPlace(state.xb2); // residual
}
```
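The Q/K normalization step elided in the snippet above is applied per head rather than over the whole query/key vector. A sketch, assuming `q` is laid out as `numberOfHeads × headSize` and a per-head norm weight of length `headSize` (names are hypothetical, not the PR's actual code):

```java
public class QKNormSketch {
    // Normalize each head's slice of q independently (RMSNorm per head),
    // then scale by the learned norm weight.
    static void rmsNormPerHead(float[] q, int nHeads, int headSize,
                               float[] weight, float eps) {
        for (int h = 0; h < nHeads; h++) {
            int off = h * headSize;
            float ss = 0f;
            for (int i = 0; i < headSize; i++) ss += q[off + i] * q[off + i];
            float scale = (float) (1.0 / Math.sqrt(ss / headSize + eps));
            for (int i = 0; i < headSize; i++)
                q[off + i] = q[off + i] * scale * weight[i];
        }
    }

    public static void main(String[] args) {
        float[] q = {3f, 4f, 6f, 8f}; // 2 heads, headSize 2
        rmsNormPerHead(q, 2, 2, new float[]{1f, 1f}, 0f);
        // each head slice now has unit RMS
        System.out.println(java.util.Arrays.toString(q));
    }
}
```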
### 4. TornadoVM Planner
**File**: `src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java`

**Modified**:
```java
case QWEN_3, GEMMA_3 -> createQWEN3Planner(state, model);
```
Routes Gemma 3 to the Qwen3 planner (both architectures use Q/K normalization).
### 5. Configuration Interface
**File**: `src/main/java/org/beehive/gpullama3/model/Configuration.java`

**Added**:
```java
int numberOfHeadsValue(); // for Gemma3/Qwen3 compatibility
```
### 6. Other Configuration Classes
**Files**:
- `LlamaConfiguration.java`
- `MistralConfiguration.java`
- `Phi3Configuration.java`

**Added**: implementations of the `numberOfHeadsValue()` method

---
## Known Issues

### Issue 1: Immediate Stop Token Generation
**Symptom**: the model generates `<end_of_turn>` (token 106) as its first token
**Status**: under investigation
**Possible Causes**:
1. Incorrect normalization implementation
2. Missing Gemma-specific initialization
3. Weight loading mismatch
4. Chat template formatting issue
### Issue 2: GGUF Compatibility
**Tested Models**:
- ❌ User-provided GGUF files (corrupted vocabulary)
- ❌ `ggml-org/gemma-3-4b-it-GGUF` (same stop-token issue)

**Next Steps**:
- Debug the embedding scaling factor
- Verify the RMSNorm epsilon values
- Check the attention mask implementation
- Compare with the llama.cpp implementation

---
## Testing

### Test Command
```bash
./llama-tornado --model gemma-3-4b-it-Q8_0.gguf --prompt "Tell me a joke"
```

### Expected Output Format
```
<bos><start_of_turn>user
Tell me a joke<end_of_turn>
<start_of_turn>model
[Model response]<end_of_turn>
```

### Performance
- **CPU**: ~6-9 tok/s at FP16/Q8_0 (4B model)
- **GPU**: not yet tested

---
## References

1. **Gemma 3 Architecture**: https://github.com/ggml-org/llama.cpp/blob/master/docs/multimodal/gemma3.md
2. **HuggingFace Model**: https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF
3. **Google Blog**: Gemma 3 uses sandwich normalization and Q/K norm
4. **SentencePiece Tokenizer**: byte-level encoding with spaces represented by the ▁ character

---
## Build and Run

### Compile
```bash
make
```

### Run CPU Inference
```bash
./llama-tornado --model gemma-3-4b-it-Q8_0.gguf --prompt "Hello"
```

### Run GPU Inference (TornadoVM)
```bash
./llama-tornado --model gemma-3-4b-it-Q8_0.gguf --prompt "Hello" --gpu --gpu-memory 8GB
```

---
## Contributors
- Initial implementation: Claude (Anthropic)
- Architecture research: based on llama.cpp and Graphcore blog posts

---
## Helper Script: GGUF Metadata Inspector

The PR also adds a small Python script for inspecting GGUF file metadata (useful when debugging the vocabulary issues above). Cleaned-up version: the bare `except:` is narrowed to `UnicodeDecodeError`, and non-string scalar values are skipped by size so the reader stays aligned with the file. The scalar-size table and the `__main__` guard are completions of the truncated original, based on the GGUF specification.

```python
import struct
import sys

# Fixed byte sizes for GGUF scalar value types (per the GGUF spec);
# type 8 is string and type 9 is array (both variable-length).
SCALAR_SIZES = {0: 1, 1: 1, 2: 2, 3: 2, 4: 4, 5: 4, 6: 4, 7: 1, 10: 8, 11: 8, 12: 8}


def read_gguf_metadata(filepath):
    with open(filepath, 'rb') as f:
        # Read header
        magic = f.read(4)
        if magic != b'GGUF':
            print('Not a GGUF file')
            return

        version = struct.unpack('<I', f.read(4))[0]
        tensor_count = struct.unpack('<Q', f.read(8))[0]
        metadata_kv_count = struct.unpack('<Q', f.read(8))[0]

        print(f'GGUF version: {version}')
        print(f'Tensor count: {tensor_count}')
        print(f'Metadata entries: {metadata_kv_count}')
        print('')

        # Read metadata key-value pairs
        for _ in range(metadata_kv_count):
            # Read key length and key
            key_len = struct.unpack('<Q', f.read(8))[0]
            key = f.read(key_len).decode('utf-8')

            # Read value type
            value_type = struct.unpack('<I', f.read(4))[0]

            # Read value based on type
            if value_type == 8:  # STRING
                str_len = struct.unpack('<Q', f.read(8))[0]
                try:
                    value = f.read(str_len).decode('utf-8')
                    if 'name' in key.lower() or 'model' in key.lower() or 'arch' in key.lower():
                        print(f'{key}: {value}')
                except UnicodeDecodeError:
                    pass  # skip strings that are not valid UTF-8
            elif value_type in SCALAR_SIZES:
                f.read(SCALAR_SIZES[value_type])  # skip fixed-size scalar value
            else:
                # Arrays (type 9) need per-element parsing; stop rather than
                # desynchronize the read position.
                print(f'{key}: <unsupported value type {value_type}, stopping>')
                return


if __name__ == '__main__':
    read_gguf_metadata(sys.argv[1])
```
### Review Comments (Copilot, Nov 2, 2025)

- The bare `except:` should be `except UnicodeDecodeError:` so that unrelated errors are not silently swallowed.
- The `import sys` statement is duplicated on line 43, when it was already imported on line 2. Remove the duplicate import.
- Variable `tensor_count` is not used.