[WIP][models] Add support for Google's Gemma3 models #61
base: main
Conversation
Implements Google Gemma 3 model architecture with:
- Q/K normalization (per-head query/key normalization)
- Sandwich normalization (4 norm layers per block: pre/post attention and FFN)
- Embedding scaling by √dim
- SentencePiece tokenization with byte-level fallback
- Custom chat format: <bos><start_of_turn>user\n{message}<end_of_turn>\n
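For illustration, here is a minimal sketch of how a single user turn could be rendered in this chat format. The class and method names are made up for the example and are not the PR's actual Gemma3ChatFormat API, and the closing `<start_of_turn>model\n` header follows the usual Gemma convention rather than anything stated above.

```java
// Minimal sketch of the Gemma 3 chat format shown above. Illustrative only;
// the PR's Gemma3ChatFormat presumably works with token ids rather than plain strings.
public final class Gemma3ChatFormatSketch {
    /** Renders one user turn; <bos> is shown as text here, but is a single token id in practice. */
    public static String renderUserTurn(String message) {
        return "<bos><start_of_turn>user\n" + message + "<end_of_turn>\n"
                + "<start_of_turn>model\n"; // opens the model's reply turn (standard Gemma convention)
    }

    public static void main(String[] args) {
        System.out.println(renderUserTurn("who are you"));
    }
}
```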
New classes:
- Gemma3Configuration: Model hyperparameters and architecture config
- Gemma3State: Inference state with specialized buffer allocations
- Gemma3: Main model class with forward pass orchestration
- Gemma3Tokenizer: SentencePiece tokenizer with special token handling
- Gemma3ChatFormat: Chat template implementation
- Gemma3StandardWeights: Weight storage for CPU inference
- Gemma3ModelLoader: GGUF file loader with metadata parsing
Architecture notes:
- Unusual dimension bottleneck: model_dim=1152, attention_dim=1024
- Handles the Q/K head-dimension mismatch (nEmbdHeadK=256 vs actualHeadDim=288)
- Weight matrices stored transposed in GGUF format
Status: Model loads and runs at ~25 tok/s on CPU
Known issues: Output quality needs debugging (tokenizer/forward pass)
🤖 Generated with Claude Code (https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
Adds CPU inference implementation for Gemma3 with:
- Embedding scaling by √dim at input
- Pre-attention RMSNorm
- Q/K/V projection with dimension handling (1152→1024→1152)
- Per-head Q/K normalization using nEmbdHeadK=256
- RoPE positional encoding
- Multi-head attention with GQA support
- Post-attention RMSNorm (sandwich norm)
- Residual connections around attention block
- Pre-FFN RMSNorm
- SwiGLU FFN activation
- Post-FFN RMSNorm (sandwich norm)
- Residual connections around FFN block
- Final RMSNorm and classification
Key implementation details:
- Handles dimension mismatch: queries use actualHeadDim=288, K/V use nEmbdHeadK=256
- Attention output accumulated to actualHeadDim-spaced xb buffer
- wo matrix projects from 1024-dim attention space back to 1152-dim model space
- Reuses Qwen3 token generation logic (similar Q/K norm architecture)
🤖 Generated with Claude Code (https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
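To make the per-block ordering above concrete, here is a schematic, self-contained sketch of one block with sandwich normalization. The helper names, weight names, and eps value are assumptions, the attention and FFN bodies are passed in as placeholders, and plain RMSNorm scaling is used even though Gemma's norm layers conventionally apply a (1 + weight) factor (whether that offset is already baked into the GGUF tensors depends on the conversion script).

```java
import java.util.function.UnaryOperator;

// Schematic sketch of one Gemma 3 block with sandwich normalization, as described above.
// Only the norm/residual ordering is the point; attention and FFN are passed in as placeholders.
final class Gemma3BlockSketch {
    static float[] rmsnorm(float[] x, float[] weight, float eps) {
        double ss = 0;
        for (float v : x) ss += (double) v * v;
        float scale = (float) (1.0 / Math.sqrt(ss / x.length + eps));
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) out[i] = x[i] * scale * weight[i];
        return out;
    }

    static float[] add(float[] a, float[] b) {
        float[] out = new float[a.length];
        for (int i = 0; i < a.length; i++) out[i] = a[i] + b[i];
        return out;
    }

    /** x has model_dim=1152; attn maps 1152 -> 1152 internally via the 1024-dim attention space. */
    static float[] block(float[] x,
                         UnaryOperator<float[]> attn, UnaryOperator<float[]> ffn,
                         float[] preAttnW, float[] postAttnW, float[] preFfnW, float[] postFfnW) {
        float eps = 1e-6f;                        // assumed epsilon
        float[] h = rmsnorm(x, preAttnW, eps);    // pre-attention norm
        h = attn.apply(h);                        // QKV (1152->1024), per-head Q/K norm, RoPE, GQA, wo (1024->1152)
        h = rmsnorm(h, postAttnW, eps);           // post-attention norm (sandwich)
        x = add(x, h);                            // residual around attention

        float[] f = rmsnorm(x, preFfnW, eps);     // pre-FFN norm
        f = ffn.apply(f);                         // SwiGLU FFN
        f = rmsnorm(f, postFfnW, eps);            // post-FFN norm (sandwich)
        return add(x, f);                         // residual around FFN
    }
}
```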
Changes:
- Add GEMMA_3 enum to ModelType with loader instantiation
- Update model detection in ModelLoader to recognize "gemma" models
- Add GEMMA_3 case to TornadoVMMasterPlan switch (uses Qwen3 planner)
Model detection:
- Checks for "gemma" in model name (case-insensitive)
- Supports gemma3/gemma2/gemma metadata prefixes
- Falls back to llama metadata if needed
GPU support:
- GEMMA_3 currently throws UnsupportedOperationException for GPU mode
- Shares Qwen3 TornadoVM planner for Q/K norm support (when implemented)
🤖 Generated with Claude Code (https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
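A sketch of the detection rule and metadata-prefix fallback described above, under the assumption that the standard GGUF keys (general.name, general.architecture, and per-architecture prefixes such as gemma3.embedding_length) are what ModelLoader inspects; the class and method names are illustrative, not the project's actual code.

```java
import java.util.Locale;
import java.util.Map;

// Illustrative sketch of the detection rule and metadata fallback described above;
// not the actual ModelLoader implementation. Key names follow the usual GGUF conventions.
final class GemmaDetectionSketch {
    /** Case-insensitive "gemma" check against the GGUF general.name / general.architecture keys. */
    static boolean isGemma(Map<String, Object> metadata) {
        String name = String.valueOf(metadata.getOrDefault("general.name", "")).toLowerCase(Locale.ROOT);
        String arch = String.valueOf(metadata.getOrDefault("general.architecture", "")).toLowerCase(Locale.ROOT);
        return name.contains("gemma") || arch.startsWith("gemma");
    }

    /** Looks a config key up under the gemma3/gemma2/gemma prefixes, falling back to llama.* */
    static Object lookup(Map<String, Object> metadata, String keySuffix) {
        for (String prefix : new String[] { "gemma3", "gemma2", "gemma", "llama" }) {
            Object value = metadata.get(prefix + "." + keySuffix);
            if (value != null) {
                return value;
            }
        }
        return null;
    }
}
```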
Documentation:
- ADDING_NEW_MODELS.md: Guide for adding new model architectures
- GEMMA3_CHANGES.md: Notes on Gemma3 architecture and implementation challenges
Tools:
- check_gguf.py: Python script to inspect GGUF model metadata and tensor shapes (useful for debugging dimension mismatches)
🤖 Generated with Claude Code (https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
Pull Request Overview
This PR adds comprehensive support for Google Gemma 3 language models with Q8_0 quantization, including CPU and GPU inference paths. The implementation includes sandwich normalization (4 norm layers per block), Q/K normalization, embedding scaling, and a SentencePiece tokenizer with byte-level encoding.
Key Changes
- Added Q8_0 quantization support with dedicated kernel implementations for matrix-vector operations (a dequantizing dot-product sketch follows this list)
- Implemented Google Gemma 3 model architecture with sandwich normalization and Q/K normalization
- Created model-specific classes (configuration, state, weights, tokenizer, chat format) for multiple models
- Refactored weight hierarchy to support both FP16 and Q8_0 quantized weights
- Added extensive debugging output and repetition penalty mechanism to inference engine
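To give a feel for the Q8_0 item above, here is a plain-Java sketch of a Q8_0 row-times-vector dot product, assuming the standard GGUF Q8_0 layout (blocks of 32 int8 quants, each with one fp16 scale, decoded to float beforehand here). The actual kernels in TransformerComputeKernelsLayered.java target TornadoVM device buffers and will differ in structure.

```java
// Sketch of a Q8_0 dot product assuming the standard GGUF Q8_0 block layout:
// each block stores one fp16 scale and 32 int8 quantized values.
// Illustrative only; the PR's TornadoVM kernels operate on device buffers.
final class Q8_0DotSketch {
    static final int BLOCK_SIZE = 32;

    /** quants: int8 values for one weight row; scales: one decoded scale per block; x: dense input vector. */
    static float dotQ8_0Row(byte[] quants, float[] scales, float[] x) {
        float sum = 0f;
        for (int block = 0; block < scales.length; block++) {
            float blockSum = 0f;
            int base = block * BLOCK_SIZE;
            for (int j = 0; j < BLOCK_SIZE; j++) {
                blockSum += quants[base + j] * x[base + j]; // int8 quant times fp32 activation
            }
            sum += scales[block] * blockSum; // dequantize once per block
        }
        return sum;
    }
}
```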
Reviewed Changes
Copilot reviewed 37 out of 37 changed files in this pull request and generated 34 comments.
Summary per file:
| File | Description |
|---|---|
| TransformerComputeKernelsLayered.java | Added Q8_0 quantized matrix-vector multiplication kernels |
| TornadoVMQ8_0LayerPlanner.java | New GPU task planner for Q8_0 quantized models |
| TornadoVMMasterPlan.java | Updated planner selection logic for Q8_0 and Gemma3 models |
| TornadoVMLayerPlanner.java | Refactored to implement generic interface and use FP16Weights |
| TornadoVMGenericLayerPlanner.java | New interface for layer planners |
| Q8_0Weights.java, FP16Weights.java | Separated weight classes by quantization type |
| Gemma3*.java | Complete Gemma 3 model implementation (configuration, state, weights, tokenizer, chat format) |
| InferenceEngine.java | Added repetition penalty and fixed position increment bug |
| InferenceCore.java | Added forwardJavaGemma3 with extensive debugging output |
| ModelLoader.java | Added Q8_0 tensor loading and fixed GGUF padding calculation |
Comments suppressed due to low confidence (3)
- src/main/java/org/beehive/gpullama3/inference/InferenceCore.java:1: Corrected comment clarification for the bug fix related to position increment.
- src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java:445: This method overrides TornadoVMGenericLayerPlanner.setupTornadoForwardPlanLayeredNonNvidia; it is advisable to add an @Override annotation.
- src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java:70: This method overrides TornadoVMGenericLayerPlanner.setupTornadoForwardPlanLayered; it is advisable to add an @Override annotation.
```java
System.err.printf(">>> Position %d (promptIndex=%d): Selected token %d\n", position, promptIndex, nextToken);

if (position < 5) {
    // Also log top-5 tokens for position 0 to understand if it's selecting the right token
    if (position == 0 || position == 1) {
        // Find top 5 logits - check first few values for debugging
        System.err.printf(" Top logits at position %d (first 10): ", position);
        for (int t = 0; t < 10; t++) {
            System.err.printf("[%d]=%.2f ", t, state.logits.getFloat(t));
        }
        System.err.println();
```
Copilot AI (Nov 2, 2025):
Debug logging statements should be removed or placed behind a debug flag before merging to production. These System.err.printf statements (lines 260-272) will clutter the output in production use.
Suggested change (wrap the logging in a debug flag):

```java
if (debug) {
    System.err.printf(">>> Position %d (promptIndex=%d): Selected token %d\n", position, promptIndex, nextToken);
    if (position < 5) {
        // Also log top-5 tokens for position 0 to understand if it's selecting the right token
        if (position == 0 || position == 1) {
            // Find top 5 logits - check first few values for debugging
            System.err.printf(" Top logits at position %d (first 10): ", position);
            for (int t = 0; t < 10; t++) {
                System.err.printf("[%d]=%.2f ", t, state.logits.getFloat(t));
            }
            System.err.println();
        }
```
```java
if (position < 5) {
    System.err.printf("\n>>> forwardJavaGemma3: position=%d, token=%d\n", position, token);
}
```
Copilot AI (Nov 2, 2025):
Extensive debug logging throughout forwardJavaGemma3 (lines 481-1300+) should be removed or controlled by a debug flag. This includes dozens of System.err.print statements that will impact production performance and output readability.
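For illustration, one common way to gate such logging; the system property name here is hypothetical, and the project may prefer a flag on its Options class instead.

```java
// Hypothetical debug gate; the property name "gpullama3.debug" is made up for the example.
static final boolean DEBUG = Boolean.getBoolean("gpullama3.debug");

// ... then around the existing logging:
if (DEBUG && position < 5) {
    System.err.printf("\n>>> forwardJavaGemma3: position=%d, token=%d\n", position, token);
}
```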
```java
// FIX: Apply aggressive K normalization
// Issue: Different token embeddings cause K magnitudes to vary 24% across positions
// Result: Position 1+ attention heavily weighted to position 0 (96.6% vs 3.4%)
// This causes the model to repeat position 0's context instead of generating new tokens
// Solution: Normalize K to fixed magnitude to stabilize attention across positions
float k_norm_sq = 0;
for (int i = 0; i < nEmbdHeadK; i++) {
    float k_val = state.k.getFloat(i);
    k_norm_sq += k_val * k_val;
}
float k_norm = (float) Math.sqrt(k_norm_sq);

// Apply very aggressive scaling - force to much smaller value
// to ensure softmax doesn't get dominated by position 0
float target_k_norm = 10.0f; // Aggressively smaller target
if (k_norm > 0.01f) {
    float k_scale = target_k_norm / k_norm;
    for (int i = 0; i < nEmbdHeadK; i++) {
        state.k.setFloat(i, state.k.getFloat(i) * k_scale);
    }
}
```
Copilot AI (Nov 2, 2025):
This aggressive K normalization (lines 786-806) appears to be a workaround hack rather than a proper fix. Forcing K vectors to a fixed magnitude of 10.0 is not part of the Gemma 3 specification and may cause incorrect model behavior. This should be removed and the root cause (likely incorrect normalization implementation) should be fixed instead.
Suggested change (remove the workaround):

```java
// K normalization workaround removed: ensure K is normalized according to model specification.
```
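For reference, Gemma 3's per-head Q/K normalization is an RMSNorm over the head dimension with a learned weight (the attn_q_norm / attn_k_norm tensors), applied per head before RoPE, rather than rescaling K to a fixed magnitude. A sketch under those assumptions follows; the buffer layout and weight names are illustrative, and whether a +1 offset must be added to the learned weight depends on how the GGUF conversion stored the norm tensors.

```java
// Sketch of per-head K RMSNorm as Gemma 3 specifies it (learned weight, applied per head
// over nEmbdHeadK values, before RoPE). Buffer layout and names are illustrative.
// Note: Gemma norm layers are conventionally (1 + weight); some GGUF conversions bake the
// +1 into the stored tensor, so check before adding it again.
static void rmsNormKHead(float[] k, int headOffset, int nEmbdHeadK, float[] kNormWeight, float eps) {
    double ss = 0;
    for (int i = 0; i < nEmbdHeadK; i++) {
        float v = k[headOffset + i];
        ss += (double) v * v;
    }
    float scale = (float) (1.0 / Math.sqrt(ss / nEmbdHeadK + eps));
    for (int i = 0; i < nEmbdHeadK; i++) {
        k[headOffset + i] = k[headOffset + i] * scale * kNormWeight[i];
    }
}
```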
```java
// TODO: FIX THIS
boolean isNvidia = platformName.contains("nvidia") || platformName.contains("cuda") || platformName.contains("ptx");
```
Copilot AI (Nov 2, 2025):
The TODO comment on line 158 indicates this code needs fixing. The logic for detecting NVIDIA platforms should be reviewed and the TODO should be addressed or removed before merging.
```java
// FIX: Use -1 for contextLength to use the model's default context length
Model model = AOT.tryUsePreLoaded(options.modelPath(), -1);
```
Copilot AI (Nov 2, 2025):
Using -1 as a magic number for 'use default context length' is not clear. Consider using a named constant like DEFAULT_CONTEXT_LENGTH or Optional.empty() pattern for better code readability.
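A minimal illustration of the suggestion; the constant name is hypothetical.

```java
// Hypothetical named constant replacing the bare -1 magic number.
public static final int USE_MODEL_DEFAULT_CONTEXT_LENGTH = -1;

Model model = AOT.tryUsePreLoaded(options.modelPath(), USE_MODEL_DEFAULT_CONTEXT_LENGTH);
```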
```java
 *            Index of the current layer (0-based)
 * @return The configured task graph with appropriate data transfer operations
 */
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
```
Copilot AI (Nov 2, 2025):
This method overrides TornadoVMQ8_0LayerPlanner<Phi3State,Phi3Configuration,Phi3TornadoWeightsQ8_0>.configureLayerDataTransfers; it is advisable to add an Override annotation.
```java
 * @throws UnsupportedOperationException If weights.weightType is not Q8_0 or Q4_0
 */
// @formatter:on
protected TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
```
Copilot AI (Nov 2, 2025):
This method overrides TornadoVMQ8_0LayerPlanner<Phi3State,Phi3Configuration,Phi3TornadoWeightsQ8_0>.configureQuantizedMatrixVectorFinalWeight; it is advisable to add an Override annotation.
```java
    super(state, model);
}

public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
```
Copilot AI (Nov 2, 2025):
This method overrides TornadoVMQ8_0LayerPlanner<Phi3State,Phi3Configuration,Phi3TornadoWeightsQ8_0>.setupTornadoForwardPlanLayered; it is advisable to add an Override annotation.
```python
        return

    version = struct.unpack('<I', f.read(4))[0]
    tensor_count = struct.unpack('<Q', f.read(8))[0]
```
Copilot AI (Nov 2, 2025):
Variable tensor_count is not used.
```python
            value = f.read(str_len).decode('utf-8')
            if 'name' in key.lower() or 'model' in key.lower() or 'arch' in key.lower():
                print(f'{key}: {value}')
        except:
```
Copilot AI (Nov 2, 2025):
Except block directly handles BaseException.
Suggested change:

```python
        except UnicodeDecodeError:
```
```
./llama-tornado --model gemma-3-1b-it-f16.gguf --prompt "who are you" --max-tokens 30 --top-p 0.9
```