diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f186c2167d7d0..e7f09e6d6ca2b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4043,6 +4043,36 @@ def set_vocab(self): super().set_vocab() +@ModelBase.register("Qwen3NextForCausalLM") +class Qwen3NextModel(Qwen3MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3NEXT + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["linear_conv_kernel_dim"])) + self.gguf_writer.add_ssm_state_size(self.find_hparam(["linear_key_head_dim"])) + self.gguf_writer.add_ssm_group_count(self.find_hparam(["linear_num_key_heads"])) + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["linear_num_value_heads"])) + self.gguf_writer.add_ssm_inner_size(self.find_hparam(['linear_value_head_dim']) * self.find_hparam(['linear_num_value_heads'])) + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25))) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("mtp"): + return [] # ignore MTP layers for now + if name.endswith(".A_log"): + data_torch = -torch.exp(data_torch) + elif name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + elif "conv1d" in name: + data_torch = data_torch.squeeze() + elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"): + data_torch = data_torch + 1 + + yield from Qwen2MoeModel.modify_tensors(self, data_torch, name, bid) + + @ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") class Qwen3VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index f5f567d4ffa12..529e9987b0197 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -4,6 +4,11 @@ set -e # First try command line argument, then environment variable, then file CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}" + +if [ -z "$MODEL_TESTING_PROMPT"]; then + MODEL_TESTING_PROMPT="Hello, my name is" +fi # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then @@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then fi echo $CONVERTED_MODEL +echo $MODEL_TESTING_PROMPT cmake --build ../../build --target llama-logits -j8 -../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is" +../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT" diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index 7fb55e9af1f52..6d054dc38ec6a 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -185,8 +185,12 @@ def fn(_m, input, output): # of using AutoModelForCausalLM. print(f"Model class: {model.__class__.__name__}") -prompt = "Hello, my name is" -input_ids = tokenizer(prompt, return_tensors="pt").input_ids +device = next(model.parameters()).device +if os.getenv("MODEL_TESTING_PROMPT"): + prompt = os.getenv("MODEL_TESTING_PROMPT") +else: + prompt = "Hello, my name is" +input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) print(f"Input tokens: {input_ids}") print(f"Input text: {repr(prompt)}") diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 2311cdabe3ba4..47dd8e3ad132f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -475,6 +475,7 @@ extern "C" { GGML_OP_COS, GGML_OP_SUM, GGML_OP_SUM_ROWS, + GGML_OP_CUMSUM, GGML_OP_MEAN, GGML_OP_ARGMAX, GGML_OP_COUNT_EQUAL, @@ -530,6 +531,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, + GGML_OP_TRI, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, @@ -542,6 +544,7 @@ extern "C" { GGML_OP_RWKV_WKV6, GGML_OP_GATED_LINEAR_ATTN, GGML_OP_RWKV_WKV7, + GGML_OP_SOLVE_TRI, GGML_OP_UNARY, @@ -576,6 +579,8 @@ extern "C" { GGML_UNARY_OP_HARDSWISH, GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, + GGML_UNARY_OP_EXPM1, + GGML_UNARY_OP_SOFTPLUS, GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_XIELU, GGML_UNARY_OP_FLOOR, @@ -620,6 +625,13 @@ extern "C" { GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; + enum ggml_tri_type { + GGML_TRI_TYPE_UPPER_DIAG = 0, + GGML_TRI_TYPE_UPPER = 1, + GGML_TRI_TYPE_LOWER_DIAG = 2, + GGML_TRI_TYPE_LOWER = 3 + }; + struct ggml_init_params { // memory pool size_t mem_size; // bytes @@ -957,6 +969,22 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_expm1( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_expm1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_softplus( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_softplus_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sin( struct ggml_context * ctx, struct ggml_tensor * a); @@ -983,6 +1011,10 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a); + // mean along rows GGML_API struct ggml_tensor * ggml_mean( struct ggml_context * ctx, @@ -2186,6 +2218,17 @@ extern "C" { int shift2, int shift3); + // Make matrix into a triangular one (upper, upper + diagonal, lower or lower + diagonal) with constant value + GGML_API struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype); + + GGML_API struct ggml_tensor * ggml_tri_keep( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type tritype); // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 // timesteps: [N,] @@ -2355,6 +2398,11 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * state); + GGML_API struct ggml_tensor * ggml_solve_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * x); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b5466dd703d1d..bebaa3de3f1d7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_sum_rows(params, tensor); } break; + case GGML_OP_CUMSUM: + { + ggml_compute_forward_cumsum(params, tensor); + } break; case GGML_OP_MEAN: { ggml_compute_forward_mean(params, tensor); @@ -1943,6 +1947,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_leaky_relu(params, tensor); } break; + case GGML_OP_TRI: + { + ggml_compute_forward_tri(params, tensor); + } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_compute_forward_flash_attn_ext(params, tensor); @@ -1998,6 +2006,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rwkv_wkv7(params, tensor); } break; + case GGML_OP_SOLVE_TRI: + { + ggml_compute_forward_solve_tri(params, tensor); + } break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2153,10 +2165,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_ARGMAX: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: { n_tasks = 1; } break; case GGML_OP_COUNT_EQUAL: + case GGML_OP_SOLVE_TRI: { n_tasks = n_threads; } break; @@ -2179,6 +2194,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_EXPM1: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index f66d36ff62c03..27fc81b1f4cd4 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9,6 +9,7 @@ #include #include +#include // ggml_compute_forward_dup @@ -1394,6 +1395,57 @@ void ggml_compute_forward_sum( } } +// ggml_compute_forward_cumsum + +static void ggml_compute_forward_cumsum_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_f32(ne00, dst_row, src_row); + } + } + } +} + +void ggml_compute_forward_cumsum( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cumsum_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_sum_rows static void ggml_compute_forward_sum_rows_f32( @@ -2140,6 +2192,50 @@ static void ggml_compute_forward_gelu( } } +// ggml_compute_tri + +static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + float c = ggml_get_op_params_f32(dst, 1); + bool keep_org_val = isnan(c); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[0] == src0->ne[1]); + + GGML_TENSOR_UNARY_OP_LOCALS + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + float * src_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + ggml_vec_tri_f32(ne0, i01, dst_ptr, src_ptr, keep_org_val, c, ttype); + } + +} + +void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_tri_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_gelu_erf static void ggml_compute_forward_gelu_erf_f32( @@ -8721,7 +8817,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = ggml_softplus(dt[h]); + const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]); const float dA = expf(dt_soft_plus * A[h]); const int g = h / (nh / ng); // repeat_interleave @@ -8818,7 +8914,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = ggml_softplus(dt[h]); + const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]); const int g = h / (nh / ng); // repeat_interleave // dim @@ -9101,6 +9197,14 @@ void ggml_compute_forward_unary( { ggml_compute_forward_xielu(params, dst); } break; + case GGML_UNARY_OP_EXPM1: + { + ggml_compute_forward_expm1(params, dst); + } break; + case GGML_UNARY_OP_SOFTPLUS: + { + ggml_compute_forward_softplus(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -9697,8 +9801,79 @@ void ggml_compute_forward_gla( } } -// ggml_compute_forward_rwkv_wkv7 +static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular) + const struct ggml_tensor * src1 = dst->src[1]; // B (RHS) + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + // --- Dimension validation --- + GGML_ASSERT(ne00 == ne01); // A must be square + GGML_ASSERT(ne0 == ne00); // solution rows == A rows + GGML_ASSERT(ne1 == ne11); // solution cols == B cols + // Batch dimensions must match + GGML_ASSERT(ne02 == ne12 && ne12 == ne2); + GGML_ASSERT(ne03 == ne13 && ne13 == ne3); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t n = ne00; // system size (A is n x n) + const int64_t k = ne11; // number of RHS columns + const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit + + // chunks per thread + const int64_t dr = (nr + nth - 1)/nth; + + // chunk range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // here be pointers + const float * A = (const float *) src0->data; // [n, n, B1, B2] + const float * B = (const float *) src1->data; // [n, k, B1, B2] + float * X = (float *) dst->data; // [n, k, B1, B2] + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*k); + const int64_t i02 = (ir - i03*ne02*k)/k; + const int64_t i01 = (ir - i03*ne02*k - i02*k); + + const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float); + const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float); + float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float); + + for (int64_t i00 = 0; i00 < n; ++i00) { + float sum = 0.0f; + for (int64_t t = 0; t < i00; ++t) { + sum += A_batch[i00 * n + t] * X_batch[t * k + i01]; + } + float diag = A_batch[i00 * n + i00]; + GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix"); + X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag; + } + } +} + +void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + return ggml_compute_forward_solve_tri_f32(params, dst); + } else { + GGML_ABORT("fatal error"); + } + +} + +// ggml_compute_forward_rwkv_wkv7 static void ggml_compute_forward_rwkv_wkv7_f32( const ggml_compute_params * params, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 9824a03b45833..5709de7abd57e 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -85,6 +86,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_back( const struct ggml_compute_params * params, @@ -100,6 +102,7 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index a047537b34f78..1d9873ad0f230 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -73,6 +73,14 @@ static inline float op_log(float x) { return logf(x); } +static inline float op_expm1(float x) { + return expf(x) - 1.0f; +} + +static inline float op_softplus(float x) { + return (x > 20.0f) ? x : logf(1.0f + expf(x)); +} + static inline float op_floor(float x) { return floorf(x); } @@ -290,6 +298,14 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * unary_op(params, dst); } +void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + +void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index fa45d9f0e636f..bcad5a3af1a98 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 65c7dfb6b9a49..59c350ef5088b 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1416,6 +1416,39 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #endif } +// Applies a triangular mask to the input vector 'src' and writes the result to 'dst'. +// Parameters: +// n - number of elements +// r - current row index +// dst - output array +// src - input array +// keep_org_val - if true, keep original value where mask applies; otherwise use constant 'c' +// c - constant value to use when not keeping original value +// type - type of triangular mask (lower, upper, etc.) +inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, enum ggml_tri_type type) { + for (int i = 0; i < n; ++i) { + bool cmp = false; + switch (type) { + case GGML_TRI_TYPE_LOWER: cmp = i < r; break; + case GGML_TRI_TYPE_LOWER_DIAG: cmp = i <= r; break; + case GGML_TRI_TYPE_UPPER: cmp = i > r; break; + case GGML_TRI_TYPE_UPPER_DIAG: + default: cmp = i >= r; break; + } + dst[i] = cmp ? (keep_org_val ? src[i] : c) : 0.0f; + } +} + +inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + if (i == 0) { + y[i] = x[i]; + } else { + y[i] = y[i - 1] + x[i]; + } + } +} + inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 61a8f1df87de1..d8cce9b187bf7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2499,6 +2499,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_XIELU: ggml_cuda_op_xielu(ctx, dst); break; + case GGML_UNARY_OP_EXPM1: + ggml_cuda_op_expm1(ctx, dst); + break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cuda_op_softplus(ctx, dst); + break; default: return false; } @@ -3768,6 +3774,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_EXPM1: + case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: return ggml_is_contiguous(op->src[0]); default: diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 5f0d3a6726aef..b19f9d61c311a 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -81,6 +81,14 @@ static __device__ __forceinline__ float op_log(float x) { return logf(x); } +static __device__ __forceinline__ float op_expm1(float x) { + return expf(x) - 1.0f; +} + +static __device__ __forceinline__ float op_softplus(float x) { + return (x > 20.0f) ? x : logf(1.0f + expf(x)); +} + static __device__ __forceinline__ float op_elu(float x) { return (x > 0.f) ? x : expm1f(x); } @@ -201,6 +209,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } + +void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} /* gated ops */ template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 6c738cefecfd2..72932c0dbc28d 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -61,6 +61,10 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index ec37a25337b64..fe57d4c582b5f 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } -static inline float ggml_softplus(float input) { +static inline float ggml_compute_softplus_f32(float input) { return (input > 20.0f) ? input : logf(1 + expf(input)); } // diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9be35c1be8456..882efd90a9ae7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "COS", "SUM", "SUM_ROWS", + "CUMSUM", "MEAN", "ARGMAX", "COUNT_EQUAL", @@ -990,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", + "TRI", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", @@ -1002,6 +1004,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RWKV_WKV6", "GATED_LINEAR_ATTN", "RWKV_WKV7", + "TRI_SOLVE", "UNARY", @@ -1019,7 +1022,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 93"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1039,6 +1042,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cos(x)", "Σx", "Σx_k", + "cumsum(x)", "Σx/n", "argmax(x)", "count_equal(x)", @@ -1094,6 +1098,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "leaky_relu(x)", + "tri(x)", "flash_attn_ext(x)", "flash_attn_back(x)", @@ -1106,6 +1111,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rwkv_wkv6(k, v, r, tf, td, s)", "gated_linear_attn(k, v, q, gate, s)", "rwkv_wkv7(r, w, k, v, a, b, s)", + "A X = B, A triangular, solve X", "unary(x)", @@ -1123,7 +1129,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 93"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1142,6 +1148,8 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSWISH", "HARDSIGMOID", "EXP", + "EXPM1", + "SOFTPLUS", "GELU_ERF", "XIELU", "FLOOR", @@ -1150,7 +1158,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "TRUNC", }; -static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20"); +static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2258,6 +2266,30 @@ struct ggml_tensor * ggml_log_inplace( return ggml_log_impl(ctx, a, true); } +struct ggml_tensor * ggml_expm1( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_EXPM1); +} + +struct ggml_tensor * ggml_expm1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXPM1); +} + +struct ggml_tensor * ggml_softplus( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS); +} + +struct ggml_tensor * ggml_softplus_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS); +} + // ggml_sin static struct ggml_tensor * ggml_sin_impl( @@ -2341,6 +2373,20 @@ struct ggml_tensor * ggml_sum_rows( return result; } +// ggml_cumsum + +struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne); + + result->op = GGML_OP_CUMSUM; + result->src[0] = a; + + return result; +} + // ggml_mean struct ggml_tensor * ggml_mean( @@ -2668,8 +2714,8 @@ struct ggml_tensor * ggml_xielu( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU); - ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n)); - ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p)); + ggml_set_op_params_f32(result, 1, beta + ggml_compute_softplus_f32(alpha_n)); + ggml_set_op_params_f32(result, 2, ggml_compute_softplus_f32(alpha_p)); ggml_set_op_params_f32(result, 3, beta); ggml_set_op_params_f32(result, 4, eps); @@ -3516,7 +3562,7 @@ struct ggml_tensor * ggml_reshape_4d( int64_t ne2, int64_t ne3) { GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); @@ -5028,6 +5074,33 @@ struct ggml_tensor * ggml_timestep_embedding( return result; } +// ggml_tri + +struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype) { + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, tritype); + ggml_set_op_params_f32(result, 1, constant); + + result->op = GGML_OP_TRI; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_tri_keep( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type tritype) { + + return ggml_tri(ctx, a, nan(""), tritype); +} + // ggml_argsort struct ggml_tensor * ggml_argsort( @@ -5882,6 +5955,34 @@ struct ggml_tensor * ggml_opt_step_sgd( return result; } +// solve_tri + +struct ggml_tensor * ggml_solve_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + + // A must be square and lower diagonal + GGML_ASSERT(a->ne[0] == a->ne[1]); + // B must have same outer dimension as A + GGML_ASSERT(a->ne[1] == b->ne[1]); + + // B must be broadcastable to A + GGML_ASSERT(a->ne[2] % b->ne[2] == 0); + GGML_ASSERT(a->ne[3] % b->ne[3] == 0); + + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(b)); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], a->ne[2], a->ne[3]); + + result->op = GGML_OP_SOLVE_TRI; + result->src[0] = a; + result->src[1] = b; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { @@ -6449,16 +6550,41 @@ static void ggml_compute_backward( ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0)); } } break; - case GGML_UNARY_OP_EXP: { - if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad)); + case GGML_UNARY_OP_EXP: + { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad)); + } } - } break; - default: { - fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", - __func__, ggml_unary_op_name(ggml_get_unary_op(tensor))); - GGML_ABORT("fatal error"); - } //break; + break; + case GGML_UNARY_OP_EXPM1: + { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_exp(ctx, src0))); + } + } + break; + case GGML_UNARY_OP_SOFTPLUS: + { + if (src0_needs_grads) { + // gradient of softplus: sigmoid(x) = 1 / (1 + exp(-x)) + struct ggml_tensor * neg_src0 = ggml_neg(ctx, src0); + struct ggml_tensor * exp_neg = ggml_exp(ctx, neg_src0); + struct ggml_tensor * ones = + ggml_exp(ctx, ggml_new_tensor_4d(ctx, src0->type, src0->ne[0], src0->ne[1], src0->ne[2], + src0->ne[3])); + struct ggml_tensor * one_plus_exp = ggml_add(ctx, ones, exp_neg); + struct ggml_tensor * sigmoid = ggml_div(ctx, ones, one_plus_exp); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, sigmoid)); + } + } + break; + default: + { + fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", __func__, + ggml_unary_op_name(ggml_get_unary_op(tensor))); + GGML_ABORT("fatal error"); + } //break; } } break; case GGML_OP_CROSS_ENTROPY_LOSS: { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0d5afa01edf84..d5d375e71ee13 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -352,6 +352,7 @@ class MODEL_ARCH(IntEnum): QWEN2VL = auto() QWEN3 = auto() QWEN3MOE = auto() + QWEN3NEXT = auto() QWEN3VL = auto() QWEN3VLMOE = auto() PHI2 = auto() @@ -461,6 +462,7 @@ class MODEL_TENSOR(IntEnum): ATTN_NORM_2 = auto() ATTN_OUT_NORM = auto() ATTN_POST_NORM = auto() + ATTN_GATE = auto() ATTN_ROT_EMBD = auto() ATTN_SINKS = auto() FFN_GATE_INP = auto() @@ -513,6 +515,7 @@ class MODEL_TENSOR(IntEnum): SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() + SSM_BETA_ALPHA = auto() TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() TIME_MIX_W2 = auto() @@ -718,6 +721,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.QWEN2VL: "qwen2vl", MODEL_ARCH.QWEN3: "qwen3", MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.QWEN3NEXT: "qwen3next", MODEL_ARCH.QWEN3VL: "qwen3vl", MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe", MODEL_ARCH.PHI2: "phi2", @@ -830,6 +834,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm", + MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate", MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", @@ -878,6 +883,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", @@ -1581,6 +1587,35 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.QWEN3NEXT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_BETA_ALPHA, + MODEL_TENSOR.SSM_OUT + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index cef5acec7581f..454262a93f046 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -665,10 +665,11 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", # mamba-hf - "backbone.layers.{bid}.mixer.in_proj", # mamba - "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid - "model.layers.layers.{bid}.mixer.in_proj", # plamo2 + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.in_proj", # plamo2 + "model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next ), MODEL_TENSOR.SSM_CONV1D: ( @@ -676,6 +677,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.conv1d", # mamba "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.conv1d", # plamo2 + "model.layers.{bid}.linear_attn.conv1d", # qwen3next ), MODEL_TENSOR.SSM_X: ( @@ -690,6 +692,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.dt_proj", # mamba "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 + "model.layers.{bid}.linear_attn.dt_proj", # qwen3next ), MODEL_TENSOR.SSM_DT_NORM: ( @@ -702,6 +705,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.A_log", # mamba "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.A_log", # plamo2 + "model.layers.{bid}.linear_attn.A_log", # qwen3next ), MODEL_TENSOR.SSM_B_NORM: ( @@ -724,17 +728,23 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_NORM: ( - "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid - "backbone.layers.{bid}.mixer.norm", # mamba2 + "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid + "model.layers.{bid}.linear_attn.norm", # qwen3next + "backbone.layers.{bid}.mixer.norm", # mamba2 ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", # mamba-hf "backbone.layers.{bid}.mixer.out_proj", # mamba "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.linear_attn.out_proj", # qwen3next "model.layers.layers.{bid}.mixer.out_proj", # plamo2 ), + MODEL_TENSOR.SSM_BETA_ALPHA: ( + "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next + ), + MODEL_TENSOR.TIME_MIX_W0: ( "model.layers.{bid}.attention.w0", # rwkv7 ), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7c7953b83dda8..f60a1b752b9b3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -34,6 +34,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3MOE, "qwen3moe" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, + { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -784,6 +785,38 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_QWEN3NEXT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_QWEN3VL, { @@ -2489,6 +2522,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2661,6 +2695,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_QWEN3NEXT: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 3f893a2dc6916..d8aad62bf04cf 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -36,6 +36,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN3, LLM_ARCH_QWEN3MOE, + LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, LLM_ARCH_PHI2, @@ -365,6 +366,7 @@ enum llm_tensor { LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index f6192a36e0ee5..f008b7e477fc7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1367,7 +1367,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { - return std::max(1024u, 8u*model.n_tensors()); + return std::max(8192, 128u*model.n_tensors()); } llm_graph_result * llama_context::get_gf_res_reserve() const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 539fecb3f7817..814ccc65f3e78 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -6,7 +6,7 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 +#define LLAMA_MAX_EXPERTS 512 // Qwen3-Next enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 276e1697d466c..dff2b5c0714b4 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -1164,3 +1164,7 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { int32_t llama_memory_recurrent_context::s_copy(int i) const { return mem->cells[i + mem->head].src0; } + +bool llama_memory_recurrent_context::has_previous_state() const { + return mem->cells[mem->head].pos >= 0; +} diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d7391248..279229f25e40a 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -163,6 +163,7 @@ class llama_memory_recurrent_context : public llama_memory_context_i { ggml_tensor * get_s_l(int32_t il) const; int32_t s_copy(int i) const; + bool has_previous_state() const; private: const llama_memory_status status; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index aa3a65f87a542..62c083aec4d51 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -813,6 +813,7 @@ struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx } struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required) { + LLAMA_LOG_DEBUG("%s: loading tensor %s as view\n", __func__, name.c_str()); const struct ggml_tensor * cur = check_tensor_dims(name, ne, required); if (cur == NULL) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index db53652575278..7986bea1e151a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -118,6 +118,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; @@ -1909,6 +1910,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { // For Granite MoE Shared ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; + case LLM_ARCH_QWEN3NEXT: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + } + + switch (hparams.n_layer) { + case 80: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_CHAMELEON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2550,6 +2574,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_QWEN3NEXT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + // Calculate projection sizes + const int64_t qkvz_projection_size = key_dim * 2 + value_dim * 2; + const int64_t ba_projection_size = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_projection_size }, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); + } + } + break; case LLM_ARCH_LLADA: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); @@ -6530,7 +6623,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H) { + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_QWEN3NEXT) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -20287,6 +20381,742 @@ struct llm_build_cogvlm : public llm_graph_context { } }; +struct llm_build_qwen3next : public llm_graph_context_mamba { + llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + //GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "model.embed_tokens", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * causal_mask = ggml_tri(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f, GGML_TRI_TYPE_LOWER); + ggml_tensor * identity = ggml_diag(ctx0, ggml_scale_bias_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 0.0f, 1.0f)); + + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + cur = build_q3n_norm(inpL, model.layers[il].attn_norm, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, causal_mask, identity, il); + } else { + // Full attention layer + cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_q3n_norm(cur, model.layers[il].attn_post_norm, il); + cb(attn_post_norm, "attn_post_norm", il); + + // FFN layer (MoE or dense) - without residual connection + cur = build_layer_ffn(attn_post_norm, model, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_moe", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_q3n_norm(cur, model.output_norm, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * delta_net_unified( + ggml_context * ctx, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + bool use_qk_l2norm, + float eps_norm, + int il) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + if (use_qk_l2norm) { + q = ggml_l2_norm(ctx, q, eps_norm); + k = ggml_l2_norm(ctx, k, eps_norm); + } + + float scale = 1.0f / sqrtf(S_v); + q = ggml_scale(ctx, q, scale); + + beta = ggml_sigmoid(ctx, beta); + + ggml_tensor * causal_diag_mask = ggml_add(ctx, causal_mask, identity); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx, ggml_permute(ctx, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx, ggml_permute(ctx, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx, ggml_permute(ctx, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx, ggml_permute(ctx, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx, ggml_permute(ctx, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + ggml_tensor * v_beta = ggml_mul(ctx, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx, k, beta); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx, g); + + cb(k_beta, "k_beta", il); + cb(v_beta, "v_beta", il); + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v, + n_seqs); // [chunk_size, 1, n_tokens, n_seqs] + ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v, + n_seqs); // [1, chunk_size, n_tokens, n_seqs] + + // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs] + // ggml_tensor * gcs_i_broadcast = + // ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, + // n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] + // Don't need this, this one will get auto-broadcast + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx, gcs_j, n_tokens, n_tokens, H_v, + n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] + + ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i); + + // Apply lower triangular mask to ensure attention is causal (only past tokens influence current) + decay_mask = ggml_mul(ctx, decay_mask, causal_diag_mask); + // Apply exponential to get the decay mask values + decay_mask = ggml_exp(ctx, decay_mask); + // Apply lower triangular mask again to ensure only lower triangular values remain + decay_mask = ggml_mul(ctx, decay_mask, causal_diag_mask); + + cb(decay_mask, "decay_mask", il); + + // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, k, k_beta); + + cb(kmulkbeta, "kmulkbeta", il); + + ggml_tensor * k_decay = ggml_mul(ctx, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx, ggml_mul(ctx, k_decay, causal_mask)); + + cb(attn, "attn_pre_rec", il); + + // for i in range(1, chunk_size): + // row = attn[..., i, :i].clone() + // sub = attn[..., :i, :i].clone() + // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + // + // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) + ggml_tensor * attn_lower = ggml_mul(ctx, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx, ggml_repeat(ctx, identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn); + attn = ggml_mul(ctx, lin_solve, causal_mask); + attn = ggml_add(ctx, attn, identity); + + // value = attn @ v_beta + v = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx0, v_beta)), attn); + + cb(v, "value_beta", il); + + // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + ggml_tensor * g_cumsum_t = ggml_cont(ctx, ggml_transpose(ctx, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx, g_cumsum_t); + + cb(gexp, "g_cum_exp", il); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx, k_beta, gexp); + + cb(kbeta_gexp, "kbeta_gexp", il); + + ggml_tensor * k_cumdecay = + ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, attn, ggml_cont(ctx, ggml_transpose(ctx, kbeta_gexp))))); + + cb(k_cumdecay, "k_cumdecay", il); + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = ggml_mul_mat(ctx, k, q); + attn = ggml_mul(ctx, attn, decay_mask); + attn = ggml_mul(ctx, attn, ggml_add(ctx, identity, causal_mask)); + + cb(attn, "attn_decay_key", il); + + ggml_tensor * state_t = ggml_cont(ctx, ggml_transpose(ctx, state)); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx, state_t, k_cumdecay); + + cb(v_prime, "v_prime", il); + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx, ggml_repeat(ctx, v, v_prime), v_prime); + + ggml_tensor * v_new_t = ggml_cont(ctx, ggml_transpose(ctx, v_new)); + + cb(v_new, "v_new", il); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx, q, gexp); + ggml_tensor * attn_inter = ggml_mul_mat(ctx, state_t, q_g_exp); + + cb(attn_inter, "attn_inter", il); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx, v_new_t, attn); + + cb(v_attn, "v_attn", il); + + ggml_tensor * core_attn_out = ggml_add(ctx, attn_inter, v_attn); + + cb(core_attn_out, "core_attn_out", il); + + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + ggml_tensor * g_cum_last = ggml_cont(ctx, ggml_view_4d(ctx, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], g_cumsum_t->nb[1], + g_cumsum_t->nb[2], g_cumsum_t->nb[3], g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1))); + + cb(g_cum_last, "g_cum_last", il); + + ggml_tensor * gexp_last = ggml_reshape_4d(ctx, ggml_exp(ctx, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); + + cb(g_cum_last, "gexp_last", il); + + ggml_tensor * g_cum_last_3d = ggml_reshape_3d(ctx, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); + + cb(g_cum_last, "g_cum_last_3d", il); + + ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]); + + cb(g_cum_last, "g_cumsum_3d", il); + + ggml_tensor * g_diff = ggml_neg(ctx, ggml_sub(ctx, g_cumsum_3d, g_cum_last_3d)); + + cb(g_cum_last, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx, g_diff); + + cb(g_cum_last, "g_diff_exp", il); + + ggml_tensor * key_gdiff = ggml_mul(ctx, k, ggml_reshape_4d(ctx, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], g_diff_exp->ne[2] * g_diff_exp->ne[3])); + + cb(g_cum_last, "key_gdiff", il); + + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx, v_new_t, + ggml_cont(ctx, ggml_transpose(ctx, key_gdiff))); + + cb(kgdmulvnew, "kgdmulvnew", il); + + ggml_tensor * new_state = ggml_add(ctx, ggml_mul(ctx, state, gexp_last), kgdmulvnew); + + cb(new_state, "new_state", il); + + // flatten output + ggml_tensor * flat_output = ggml_cont_1d(ctx, ggml_permute(ctx, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + ggml_tensor * flat_state = ggml_cont_1d(ctx, new_state, S_v * S_v * H_v * n_seqs); + + return ggml_concat(ctx, flat_output, flat_state, 0); + } + + ggml_tensor * build_q3n_norm(struct ggml_tensor * input, struct ggml_tensor * weights, int layer) { + // ggml_tensor * input_norm = ggml_scale_bias(ctx0, weights, 1.0f, 1.0f); + // EDIT: we moved the shifting part to the conversion, so we just call normal build_norm + return build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + } + + ggml_tensor * build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + return ggml_mul(ctx0, normalized, gated_silu); + } + + ggml_tensor * build_qwen3next_attention_layer(ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + struct ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur_full, "Qcur_full", il); + Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); + // Split Q projection into query and gate + // The split should be along dimension 0 (the feature dimension) + struct ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], + Qcur_full->nb[2], Qcur_full->nb[3], 0); + struct ggml_tensor * gate = + ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], + Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); + cb(Qcur, "Qcur", il); + cb(gate, "gate", il); + + // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_q3n_norm(Qcur, model.layers[il].attn_q_norm, il); + cb(Qcur, "Qcur_normed", il); + + struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_q3n_norm(Kcur, model.layers[il].attn_k_norm, il); + cb(Kcur, "Kcur_normed", il); + + // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply RoPE + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = + hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + struct ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; + } + + ggml_tensor * build_qwen3next_linear_attn_layer( + llm_graph_input_rs * inp, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur); + cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + + ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur); + cb(mixed_ba, "linear_attn_mixed_ba", il); + + int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); + ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] + int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; + ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Split mixed_ba into b and a (beta and alpha parameters) + int64_t split_sizes_ba[2] = { + num_v_heads / num_k_heads, // beta size + num_v_heads / num_k_heads // alpha size + }; + + ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0); + cb(b, "b", il); + + ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], + split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); + cb(a, "a", il); + + // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] + ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); + ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); + + GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba)); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + // Split mixed_qkvz into query, key, value, z + int64_t split_sizes_qkvz[4] = { + head_k_dim, // query size + head_k_dim, // key size + head_v_dim * num_v_heads / num_k_heads, // value size + head_v_dim * num_v_heads / num_k_heads // z size + }; + + ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); + cb(query, "q", il); + + ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + split_sizes_qkvz[0] * sizeof(float)); + cb(key, "k", il); + + ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)); + cb(value, "v", il); + + ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); + cb(z, "z", il); + + GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + + ggml_nelements(z) == + ggml_nelements(mixed_qkvz)); + + // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions + // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(query_flat, "query_flat", il); + + // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(key_flat, "key_flat", il); + + // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs] + ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(value_flat, "value_flat", il); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] + ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); + qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); + cb(qkv_mixed, "qkv_mixed", il); + + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], conv_input->nb[2], + (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + // Apply SSM convolution + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper)); + cb(conv_output_proper, "conv_output_pre_silu", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs); + cb(conv_qkv_mix, "conv_qkv_mix", il); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, + conv_qkv_mix->nb[1], 0); + cb(q_conv, "q_conv", il); + ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, + conv_qkv_mix->nb[1], head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(k_conv, "k_conv", il); + ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, + conv_qkv_mix->nb[1], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(v_conv, "v_conv", il); + + // Unsqueeze them + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + cb(state, "state_predelta", il); + + // if head keys and value keys are different, repeat to force tensors into matching shapes + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + int64_t repeat_factor = num_v_heads / num_k_heads; + + // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back + ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + + // Repeat along the third dimension (the new dimension with size 1) + ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + + // Reshape back to merge the head and repeat dimensions + // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs] + // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] + q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); + k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between delta_net and delta_net_recurrent based on generation mode + + // if (use_precomputed_states) { + // // Use delta_net_recurrent for single token generation + // attn_out = ggml_delta_net_recurrent(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps); + // } else { + // // Use regular delta_net for prompt processing + // // attn_out = ggml_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps); + ggml_tensor * attn_out = delta_net_unified(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, true, hparams.f_norm_rms_eps, il); + //} + cb(attn_out, "attn_out", il); + + // The tensors were concatenated 1d, so we need to extract them 1d as well + const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs; + ggml_tensor * attn_out_1d = + ggml_view_1d(ctx0, attn_out, output_flat_size, 0); + cb(attn_out_1d, "attn_out_1d", il); + + ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + cb(attn_out_final, "attn_out_reshaped", il); + + // Extract the state part (second part of the concatenated tensor) + // State starts after n_tokens elements along dimension 1 + const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs; + + ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out)); + cb(state_1d, "state_1d", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, state_1d, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out)); + + // Reshape both attn_out_final and z to 2D tensors for normalization + // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * attn_out_2d_final = ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; + } + + ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) { + + // Check if this is an MoE layer + if (model.layers[il].ffn_gate_inp != nullptr) { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert, + n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(moe_out, "ffn_moe_out", il); + + // Add shared experts if present - following Qwen3Next reference implementation + if (model.layers[il].ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + // Apply shared expert gating as in the reference implementation + // The shared expert has its own gate that is sigmoided + // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) + ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, "shared_expert_gate", il); + + // Apply sigmoid to the gate + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + + // The gate needs to be broadcast to match the dimensions of ffn_shexp + // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1] + // We need to repeat the gate along the feature dimension + shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp); + cb(shared_gate, "shared_expert_gate_broadcast", il); + + // Apply the gate to the shared expert output + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } else { + // Dense FFN branch (not currently used I believe) + cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + return cur; + } +}; + + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { llama_memory_i * res; @@ -20815,6 +21645,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_QWEN3NEXT: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_APERTUS: { llm = std::make_unique(*this, params); @@ -21010,6 +21844,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_LLADA_MOE: case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: diff --git a/src/llama-model.h b/src/llama-model.h index 71ff148e07dae..f3733285e9e8d 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -112,6 +112,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_230B_A10B, // Minimax M2 @@ -307,6 +308,9 @@ struct llama_layer { struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; + // qwen3next + struct ggml_tensor * ssm_beta_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index a56b2626ae1c5..4b6545256c336 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -667,6 +667,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::map mapped; int blk_id = 0; int pruned_attention_w = 0; + int linear_layers = 0; // make a list of weights std::vector tensors; @@ -684,6 +685,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } else if (remapped_name != it.first) { ggml_set_name(it.second.tensor, remapped_name.c_str()); LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor)); + } else if (it.first.find("ssm_conv") != std::string::npos) { + linear_layers++; } tensors.push_back(&it.second); } @@ -732,7 +735,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // for each decoder block, there are 2 attention layers n_attn_layer += 2 * model.hparams.dec_n_layer; } - GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected"); + GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w - linear_layers) && "n_attention_wv is unexpected"); } size_t total_size_org = 0; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 04fa1b62d3b4d..5746309e6746d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3809,6 +3809,150 @@ struct test_cos : public test_case { } }; +// GGML_OP_EXPM1 +struct test_expm1 : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_expm1(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 3, 3, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_expm1(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // Use small values to avoid overflow in expm1 + init_tensor_uniform(t, -2.0f, 2.0f); + } + } + + bool grad_precise() override { + return true; + } +}; + +// GGML_OP_SOFTPLUS +struct test_softplus : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_softplus(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 3, 3, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_softplus(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // Use values around the threshold (20) to test both branches of softplus + init_tensor_uniform(t, -25.0f, 25.0f); + } + } + + bool grad_precise() override { + return true; + } +}; + +// GGML_OP_EXPM1_INPLACE +struct test_expm1_inplace : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_expm1_inplace(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 3, 3, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_expm1_inplace(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // Use small values to avoid overflow in expm1 + init_tensor_uniform(t, -2.0f, 2.0f); + } + } + + bool grad_precise() override { + return true; + } +}; + +// GGML_OP_SOFTPLUS_INPLACE +struct test_softplus_inplace : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_softplus_inplace(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 3, 3, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_softplus_inplace(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // Use values around the threshold (20) to test both branches of softplus + init_tensor_uniform(t, -25.0f, 25.0f); + } + } + + bool grad_precise() override { + return true; + } +}; + // GGML_OP_CLAMP struct test_clamp : public test_case { const ggml_type type; @@ -7006,6 +7150,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sqr (type)); test_cases.emplace_back(new test_sqrt (type)); test_cases.emplace_back(new test_log (type)); + test_cases.emplace_back(new test_expm1 (type)); + test_cases.emplace_back(new test_softplus (type)); + test_cases.emplace_back(new test_expm1_inplace (type)); + test_cases.emplace_back(new test_softplus_inplace (type)); test_cases.emplace_back(new test_sin (type)); test_cases.emplace_back(new test_cos (type)); test_cases.emplace_back(new test_clamp (type)); @@ -7017,6 +7165,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_log (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_expm1 (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_softplus (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_expm1_inplace (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_softplus_inplace (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_sin (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3})); diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 498e00e3a5e58..4735f939f6f07 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -14,6 +14,11 @@ #include #include #include +#include + +// Forward declarations for internal cache access +struct llama_memory_hybrid; +struct llama_memory_recurrent; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include