@@ -864,6 +864,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 26: type = LLM_TYPE_1B; break;
+                    case 34: type = LLM_TYPE_4B; break;
+                    case 48: type = LLM_TYPE_12B; break;
+                    case 62: type = LLM_TYPE_27B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / sqrtf(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / sqrtf(float(hparams.n_embd_head_k));
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
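Note the size-dependent scale: the 27B model divides queries by sqrt(n_embd / n_head) (Gemma's query_pre_attn_scalar convention), while the smaller sizes use the usual 1/sqrt(head size). A minimal standalone sketch of the difference, assuming the published 27B shape (n_embd = 5376, n_head = 32, n_embd_head_k = 128); these concrete numbers are an assumption and are not read from this diff:

// Standalone sketch; the 27B shape below is assumed, not taken from the diff.
#include <math.h>
#include <stdio.h>

int main() {
    const float n_embd        = 5376.0f; // assumed 27B hidden size
    const float n_head        = 32.0f;   // assumed 27B query-head count
    const float n_embd_head_k = 128.0f;  // assumed per-head K/Q size

    const float scale_27b   = 1.0f / sqrtf(n_embd / n_head);  // 1/sqrt(168) ~= 0.0772
    const float scale_other = 1.0f / sqrtf(n_embd_head_k);    // 1/sqrt(128) ~= 0.0884

    printf("27B attention scale:   %.4f\n", scale_27b);
    printf("other attention scale: %.4f\n", scale_other);
    return 0;
}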
@@ -2454,6 +2471,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
                         layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GEMMA3:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm    = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM,    "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM,    "weight", i), {n_embd_head_k}, 0);
+
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
                         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
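The {n_embd_head_k}-sized attn_q_norm / attn_k_norm tensors are the per-head QK-norm weights that Gemma 3 adds over Gemma 2. A sketch of how such per-head norms are typically applied in a llama.cpp graph builder; the builder itself is not part of this hunk, and ctx0, Qcur/Kcur, n_head_kv, and the reshape shapes are assumptions about the surrounding code, not its confirmed implementation:

// Sketch only: RMS-normalize each attention head of Q/K, then apply the
// learned per-dimension gain, before RoPE.
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens);
Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);   // normalize over the head dimension
Qcur = ggml_mul(ctx0, Qcur, model.layers[il].attn_q_norm);  // broadcast the {n_embd_head_k} gain

Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
Kcur = ggml_mul(ctx0, Kcur, model.layers[il].attn_k_norm);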
@@ -3650,6 +3696,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
         LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
         LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
+        LLAMA_LOG_INFO("%s: f_attn_scale     = %.1e\n",   __func__, hparams.f_attention_scale);
         LLAMA_LOG_INFO("%s: n_ff             = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
         LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
@@ -3923,6 +3970,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX: