Conversation

@mikepapadim (Member) commented Nov 2, 2025

./llama-tornado --model gemma-3-1b-it-f16.gguf --prompt "who are you" --max-tokens 30 --top-p 0.9

mikepapadim and others added 5 commits November 1, 2025 11:12
Implements Google Gemma 3 model architecture with:
- Q/K normalization (per-head query/key normalization)
- Sandwich normalization (4 norm layers per block: pre/post attention and FFN)
- Embedding scaling by √dim
- SentencePiece tokenization with byte-level fallback
- Custom chat format: <bos><start_of_turn>user\n{message}<end_of_turn>\n
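
A minimal sketch of the embedding scaling step listed above, assuming a row-major token-embedding table and float activations (names are illustrative, not the PR's actual classes):

    // Sketch: look up the token embedding and scale it by sqrt(dim) before the first block.
    // Assumes `tokenEmbeddingTable` is row-major [vocabSize x dim]; names are illustrative.
    static void embedAndScale(float[] tokenEmbeddingTable, int token, int dim, float[] x) {
        float scale = (float) Math.sqrt(dim);   // Gemma scales input embeddings by sqrt(model dim)
        System.arraycopy(tokenEmbeddingTable, token * dim, x, 0, dim);
        for (int i = 0; i < dim; i++) {
            x[i] *= scale;
        }
    }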

New classes:
- Gemma3Configuration: Model hyperparameters and architecture config
- Gemma3State: Inference state with specialized buffer allocations
- Gemma3: Main model class with forward pass orchestration
- Gemma3Tokenizer: SentencePiece tokenizer with special token handling
- Gemma3ChatFormat: Chat template implementation
- Gemma3StandardWeights: Weight storage for CPU inference
- Gemma3ModelLoader: GGUF file loader with metadata parsing

Architecture notes:
- Unusual dimension bottleneck: model_dim=1152, attention_dim=1024
- Handles Q/K dimension mismatch (nEmbdHeadK=256 vs actualHeadDim=288)
- Weight matrices stored transposed in GGUF format

Status: Model loads and runs at ~25 tok/s on CPU
Known issues: Output quality needs debugging (tokenizer/forward pass)

🤖 Generated with Claude Code (https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Adds CPU inference implementation for Gemma3 with:
- Embedding scaling by √dim at input
- Pre-attention RMSNorm
- Q/K/V projection with dimension handling (1152→1024→1152)
- Per-head Q/K normalization using nEmbdHeadK=256
- RoPE positional encoding
- Multi-head attention with GQA support
- Post-attention RMSNorm (sandwich norm)
- Residual connections around attention block
- Pre-FFN RMSNorm
- SwiGLU FFN activation
- Post-FFN RMSNorm (sandwich norm)
- Residual connections around FFN block
- Final RMSNorm and classification

Key implementation details:
- Handles dimension mismatch: queries use actualHeadDim=288, K/V use nEmbdHeadK=256
- Attention output accumulated to actualHeadDim-spaced xb buffer
- wo matrix projects from 1024-dim attention space back to 1152-dim model space
- Reuses Qwen3 token generation logic (similar Q/K norm architecture)
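
A condensed sketch of the per-layer ordering listed above (pre-norm, sub-block, post-norm, residual, repeated for attention and FFN). The attention and FFN sub-blocks are passed in as placeholders, and rmsNorm, weight names, and eps are schematic rather than the PR's InferenceCore code:

    import java.util.function.UnaryOperator;

    // Schematic Gemma 3 "sandwich" block: each sub-block is wrapped in a pre- and post-RMSNorm
    // before being added back to the residual stream.
    final class GemmaBlockSketch {

        static float[] rmsNorm(float[] x, float[] w, float eps) {
            float ss = 0f;
            for (float v : x) ss += v * v;
            float inv = (float) (1.0 / Math.sqrt(ss / x.length + eps));
            float[] out = new float[x.length];
            for (int i = 0; i < x.length; i++) out[i] = x[i] * inv * w[i];
            return out;
        }

        // One layer: pre-norm -> attention -> post-norm -> residual, then the same around the FFN.
        static float[] layer(float[] x, UnaryOperator<float[]> attention, UnaryOperator<float[]> ffn,
                             float[] preAttnW, float[] postAttnW, float[] preFfnW, float[] postFfnW, float eps) {
            float[] h = rmsNorm(x, preAttnW, eps);
            h = attention.apply(h);          // Q/K norm, RoPE, GQA attention happen inside
            h = rmsNorm(h, postAttnW, eps);
            for (int i = 0; i < x.length; i++) x[i] += h[i];   // residual around attention

            float[] f = rmsNorm(x, preFfnW, eps);
            f = ffn.apply(f);                // SwiGLU FFN
            f = rmsNorm(f, postFfnW, eps);
            for (int i = 0; i < x.length; i++) x[i] += f[i];   // residual around FFN
            return x;
        }
    }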

🤖 Generated with Claude Code (https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Changes:
- Add GEMMA_3 enum to ModelType with loader instantiation
- Update model detection in ModelLoader to recognize "gemma" models
- Add GEMMA_3 case to TornadoVMMasterPlan switch (uses Qwen3 planner)

Model detection:
- Checks for "gemma" in model name (case-insensitive)
- Supports gemma3/gemma2/gemma metadata prefixes
- Falls back to llama metadata if needed

GPU support:
- GEMMA_3 currently throws UnsupportedOperationException for GPU mode
- Shares Qwen3 TornadoVM planner for Q/K norm support (when implemented)
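
A minimal sketch of the case-insensitive name check described above; the metadata values and fallback mirror the commit text, while the method shape is illustrative rather than the PR's ModelLoader API:

    // Sketch: route to GEMMA_3 when the GGUF model name or architecture mentions "gemma",
    // otherwise fall back to llama-style handling.
    static String detectModelFamily(String modelName, String architecture) {
        String name = modelName == null ? "" : modelName.toLowerCase();
        String arch = architecture == null ? "" : architecture.toLowerCase();
        if (name.contains("gemma") || arch.startsWith("gemma")) {   // matches gemma, gemma2, gemma3
            return "GEMMA_3";
        }
        return "LLAMA";   // fall back to llama metadata
    }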

🤖 Generated with Claude Code (https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Documentation:
- ADDING_NEW_MODELS.md: Guide for adding new model architectures
- GEMMA3_CHANGES.md: Notes on Gemma3 architecture and implementation challenges

Tools:
- check_gguf.py: Python script to inspect GGUF model metadata and tensor shapes
  (useful for debugging dimension mismatches)
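
For reference, the fixed-size GGUF header that such a script inspects can be read with a few little-endian reads; a minimal Java sketch of the same idea, assuming a GGUF v2+ file (layout per the GGUF spec: 4-byte magic, uint32 version, uint64 tensor count, uint64 metadata KV count):

    import java.io.IOException;
    import java.io.RandomAccessFile;
    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.nio.charset.StandardCharsets;

    // Sketch: print the GGUF header fields that are most useful when debugging dimension mismatches.
    final class GgufHeaderSketch {
        public static void main(String[] args) throws IOException {
            try (RandomAccessFile f = new RandomAccessFile(args[0], "r")) {
                byte[] head = new byte[4 + 4 + 8 + 8];
                f.readFully(head);
                ByteBuffer buf = ByteBuffer.wrap(head).order(ByteOrder.LITTLE_ENDIAN);
                byte[] magic = new byte[4];
                buf.get(magic);                                   // expected to be the ASCII bytes "GGUF"
                long version = Integer.toUnsignedLong(buf.getInt());
                long tensorCount = buf.getLong();
                long metadataKvCount = buf.getLong();
                System.out.printf("magic=%s version=%d tensors=%d metadata KVs=%d%n",
                        new String(magic, StandardCharsets.US_ASCII), version, tensorCount, metadataKvCount);
            }
        }
    }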

🤖 Generated with Claude Code (https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Copilot AI left a comment

Pull Request Overview

This PR adds comprehensive support for Google Gemma 3 language models with Q8_0 quantization, including CPU and GPU inference paths. The implementation includes sandwich normalization (4 norm layers per block), Q/K normalization, embedding scaling, and a SentencePiece tokenizer with byte-level encoding.

Key Changes

  • Added Q8_0 quantization support with dedicated kernel implementations for matrix-vector operations
  • Implemented Google Gemma 3 model architecture with sandwich normalization and Q/K normalization
  • Created model-specific classes (configuration, state, weights, tokenizer, chat format) for multiple models
  • Refactored weight hierarchy to support both FP16 and Q8_0 quantized weights
  • Added extensive debugging output and repetition penalty mechanism to inference engine
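
As a rough illustration of the Q8_0 matrix-vector path mentioned in the first bullet above, assuming the standard GGUF Q8_0 layout (blocks of 32 int8 values, one float16 scale per block) with the scales pre-decoded into a float array for clarity; this is a plain-Java sketch, not the actual TornadoVM kernel:

    // Sketch: dot product of one Q8_0-quantized weight row with a float activation vector.
    // Assumes x.length is a multiple of the 32-element block size.
    static float q8_0DotRow(byte[] rowQuants, float[] rowScales, float[] x) {
        final int blockSize = 32;                       // Q8_0 groups 32 weights per scale
        float sum = 0f;
        for (int block = 0; block * blockSize < x.length; block++) {
            int base = block * blockSize;
            float blockSum = 0f;
            for (int i = 0; i < blockSize; i++) {
                blockSum += rowQuants[base + i] * x[base + i];   // int8 weight * activation
            }
            sum += rowScales[block] * blockSum;         // dequantize once per block
        }
        return sum;
    }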

Reviewed Changes

Copilot reviewed 37 out of 37 changed files in this pull request and generated 34 comments.

Summary per file:
- TransformerComputeKernelsLayered.java: Added Q8_0 quantized matrix-vector multiplication kernels
- TornadoVMQ8_0LayerPlanner.java: New GPU task planner for Q8_0 quantized models
- TornadoVMMasterPlan.java: Updated planner selection logic for Q8_0 and Gemma3 models
- TornadoVMLayerPlanner.java: Refactored to implement generic interface and use FP16Weights
- TornadoVMGenericLayerPlanner.java: New interface for layer planners
- Q8_0Weights.java, FP16Weights.java: Separated weight classes by quantization type
- Gemma3*.java: Complete Gemma 3 model implementation (configuration, state, weights, tokenizer, chat format)
- InferenceEngine.java: Added repetition penalty and fixed position increment bug
- InferenceCore.java: Added forwardJavaGemma3 with extensive debugging output
- ModelLoader.java: Added Q8_0 tensor loading and fixed GGUF padding calculation
Comments suppressed due to low confidence (3)

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java:1
  • Corrected comment clarification for bug fix related to position increment.
  package org.beehive.gpullama3.inference;

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java:445

Comment on lines +260 to +270
System.err.printf(">>> Position %d (promptIndex=%d): Selected token %d\n", position, promptIndex, nextToken);

if (position < 5) {
// Also log top-5 tokens for position 0 to understand if it's selecting the right token
if (position == 0 || position == 1) {
// Find top 5 logits - check first few values for debugging
System.err.printf(" Top logits at position %d (first 10): ", position);
for (int t = 0; t < 10; t++) {
System.err.printf("[%d]=%.2f ", t, state.logits.getFloat(t));
}
System.err.println();

Copilot AI Nov 2, 2025

Debug logging statements should be removed or placed behind a debug flag before merging to production. These System.err.printf statements (lines 260-272) will clutter the output in production use.

Suggested change
-    System.err.printf(">>> Position %d (promptIndex=%d): Selected token %d\n", position, promptIndex, nextToken);
-    if (position < 5) {
-        // Also log top-5 tokens for position 0 to understand if it's selecting the right token
-        if (position == 0 || position == 1) {
-            // Find top 5 logits - check first few values for debugging
-            System.err.printf(" Top logits at position %d (first 10): ", position);
-            for (int t = 0; t < 10; t++) {
-                System.err.printf("[%d]=%.2f ", t, state.logits.getFloat(t));
-            }
-            System.err.println();
+    if (debug) {
+        System.err.printf(">>> Position %d (promptIndex=%d): Selected token %d\n", position, promptIndex, nextToken);
+        if (position < 5) {
+            // Also log top-5 tokens for position 0 to understand if it's selecting the right token
+            if (position == 0 || position == 1) {
+                // Find top 5 logits - check first few values for debugging
+                System.err.printf(" Top logits at position %d (first 10): ", position);
+                for (int t = 0; t < 10; t++) {
+                    System.err.printf("[%d]=%.2f ", t, state.logits.getFloat(t));
+                }
+                System.err.println();
+    }

Comment on lines +482 to +484
    if (position < 5) {
        System.err.printf("\n>>> forwardJavaGemma3: position=%d, token=%d\n", position, token);
    }

Copilot AI Nov 2, 2025

Extensive debug logging throughout forwardJavaGemma3 (lines 481-1300+) should be removed or controlled by a debug flag. This includes dozens of System.err.print statements that will impact production performance and output readability.

Comment on lines +785 to +806
    // FIX: Apply aggressive K normalization
    // Issue: Different token embeddings cause K magnitudes to vary 24% across positions
    // Result: Position 1+ attention heavily weighted to position 0 (96.6% vs 3.4%)
    // This causes the model to repeat position 0's context instead of generating new tokens
    // Solution: Normalize K to fixed magnitude to stabilize attention across positions
    float k_norm_sq = 0;
    for (int i = 0; i < nEmbdHeadK; i++) {
        float k_val = state.k.getFloat(i);
        k_norm_sq += k_val * k_val;
    }
    float k_norm = (float) Math.sqrt(k_norm_sq);

    // Apply very aggressive scaling - force to much smaller value
    // to ensure softmax doesn't get dominated by position 0
    float target_k_norm = 10.0f; // Aggressively smaller target
    if (k_norm > 0.01f) {
        float k_scale = target_k_norm / k_norm;
        for (int i = 0; i < nEmbdHeadK; i++) {
            state.k.setFloat(i, state.k.getFloat(i) * k_scale);
        }
    }

Copilot AI Nov 2, 2025

This aggressive K normalization (lines 786-806) appears to be a workaround hack rather than a proper fix. Forcing K vectors to a fixed magnitude of 10.0 is not part of the Gemma 3 specification and may cause incorrect model behavior. This should be removed and the root cause (likely incorrect normalization implementation) should be fixed instead.

Suggested change
-    // FIX: Apply aggressive K normalization
-    // Issue: Different token embeddings cause K magnitudes to vary 24% across positions
-    // Result: Position 1+ attention heavily weighted to position 0 (96.6% vs 3.4%)
-    // This causes the model to repeat position 0's context instead of generating new tokens
-    // Solution: Normalize K to fixed magnitude to stabilize attention across positions
-    float k_norm_sq = 0;
-    for (int i = 0; i < nEmbdHeadK; i++) {
-        float k_val = state.k.getFloat(i);
-        k_norm_sq += k_val * k_val;
-    }
-    float k_norm = (float) Math.sqrt(k_norm_sq);
-    // Apply very aggressive scaling - force to much smaller value
-    // to ensure softmax doesn't get dominated by position 0
-    float target_k_norm = 10.0f; // Aggressively smaller target
-    if (k_norm > 0.01f) {
-        float k_scale = target_k_norm / k_norm;
-        for (int i = 0; i < nEmbdHeadK; i++) {
-            state.k.setFloat(i, state.k.getFloat(i) * k_scale);
-        }
-    }
+    // K normalization workaround removed: ensure K is normalized according to model specification.
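
For comparison, Gemma 3's Q/K normalization is a per-head RMSNorm with a learned weight rather than a rescale to a fixed magnitude. A minimal sketch of that shape of fix, assuming the current key head occupies k[offset .. offset+nEmbdHeadK) and that the attn_k_norm weights are available (field and parameter names are assumptions, not the PR's code):

    // Sketch: per-head RMSNorm of the key vector, the spec-style alternative to the
    // fixed-magnitude rescale flagged above. kNormWeight is the learned attn_k_norm weight.
    static void rmsNormKeyHead(float[] k, int offset, int nEmbdHeadK, float[] kNormWeight, float eps) {
        float ss = 0f;
        for (int i = 0; i < nEmbdHeadK; i++) {
            float v = k[offset + i];
            ss += v * v;
        }
        float inv = (float) (1.0 / Math.sqrt(ss / nEmbdHeadK + eps));
        for (int i = 0; i < nEmbdHeadK; i++) {
            k[offset + i] = k[offset + i] * inv * kNormWeight[i];
        }
    }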

Comment on lines 158 to 159
// TODO: FIX THIS
boolean isNvidia = platformName.contains("nvidia") || platformName.contains("cuda") || platformName.contains("ptx");

Copilot AI Nov 2, 2025

The TODO comment on line 158 indicates this code needs fixing. The logic for detecting NVIDIA platforms should be reviewed and the TODO should be addressed or removed before merging.

Comment on lines +105 to +106
// FIX: Use -1 for contextLength to use the model's default context length
Model model = AOT.tryUsePreLoaded(options.modelPath(), -1);

Copilot AI Nov 2, 2025

Using -1 as a magic number for 'use default context length' is not clear. Consider a named constant such as DEFAULT_CONTEXT_LENGTH, or an Optional-based pattern, for better code readability.

* Index of the current layer (0-based)
* @return The configured task graph with appropriate data transfer operations
*/
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {

Copilot AI Nov 2, 2025

This method overrides TornadoVMQ8_0LayerPlanner<Phi3State,Phi3Configuration,Phi3TornadoWeightsQ8_0>.configureLayerDataTransfers; it is advisable to add an Override annotation.

* @throws UnsupportedOperationException If weights.weightType is not Q8_0 or Q4_0
*/
// @formatter:on
protected TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {

Copilot AI Nov 2, 2025

super(state, model);
}

public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {

Copilot AI Nov 2, 2025

This method overrides TornadoVMQ8_0LayerPlanner<Phi3State,Phi3Configuration,Phi3TornadoWeightsQ8_0>.setupTornadoForwardPlanLayered; it is advisable to add an Override annotation.

return

version = struct.unpack('<I', f.read(4))[0]
tensor_count = struct.unpack('<Q', f.read(8))[0]

Copilot AI Nov 2, 2025

Variable tensor_count is not used.

value = f.read(str_len).decode('utf-8')
if 'name' in key.lower() or 'model' in key.lower() or 'arch' in key.lower():
print(f'{key}: {value}')
except:

Copilot AI Nov 2, 2025

Except block directly handles BaseException.

Suggested change
-except:
+except UnicodeDecodeError:
