From 245f39157611918962ded35a425fb7501f898f9a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Oct 2025 19:33:30 +0300 Subject: [PATCH 01/64] graph : reuse hybrid graphs --- src/llama-graph.cpp | 41 ++++++++++++++++++++++++++++++++++--- src/llama-graph.h | 10 +++++++-- src/llama-memory-hybrid.cpp | 2 +- 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a24853c63ada4..85ceadd077e13 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -436,8 +436,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - inp_attn->set_input(ubatch); - inp_rs->set_input(ubatch); + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + return res; } // @@ -1848,7 +1883,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); - auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } diff --git a/src/llama-graph.h b/src/llama-graph.h index dc84b7942893a..4a810a4e9c267 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -360,22 +360,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( + const llama_cparams & cparams, std::unique_ptr inp_attn, - std::unique_ptr inp_rs, - const llama_memory_hybrid_context * mctx) : + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : inp_attn(std::move(inp_attn)), inp_rs(std::move(inp_rs)), + cparams(cparams), mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + std::unique_ptr inp_attn; std::unique_ptr inp_rs; llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + const llama_cparams cparams; + const llama_memory_hybrid_context * mctx; }; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index dfb8439e01bdf..a1b45e4a3cce3 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), - ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } From 638e2c23985e52fe1a326e0038c232ba084856d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Oct 2025 19:36:17 +0300 Subject: [PATCH 02/64] graph : reuse recurrent graphs --- src/llama-graph.cpp | 15 +++++++++++++++ src/llama-graph.h | 2 ++ 2 files changed, 17 insertions(+) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 85ceadd077e13..7f0c974f1760b 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -251,6 +251,21 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= s_copy->ne[0] == mctx->get_n_rs(); + + res &= s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; + + return res; +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); diff --git a/src/llama-graph.h b/src/llama-graph.h index 4a810a4e9c267..394e884323bc1 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -224,6 +224,8 @@ class llm_graph_input_rs : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * s_copy; // I32 [n_rs] // views of s_copy, computed once per graph From 0b9c1ae3d8c6b272fcce749d8e0cabe6e6d7de68 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Oct 2025 21:44:08 +0300 Subject: [PATCH 03/64] metal : fix mul-mm condition + fix mul-mv permuted kernels --- ggml/src/ggml-metal/ggml-metal-ops.cpp | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 66 +++++++++++++++----------- src/llama-model.cpp | 8 ++-- 3 files changed, 44 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 1137e210773af..5f9370449bb2d 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1546,9 +1546,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { !ggml_is_transposed(op->src[1]) && // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - props_dev->has_simdgroup_mm && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { + //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 45d91def88bf2..ddc285042d284 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7487,7 +7487,7 @@ kernel void kernel_mul_mv_iq1_m_f32( kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -7500,13 +7500,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7517,6 +7516,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK4_NL; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7524,24 +7526,25 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * yb = y + ix*QK4_NL + it*8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + // [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; row++) { - device const block_iq4_nl & xb = x[row*nb + ib]; + for (short row = 0; row < NR0; row++) { + device const block_iq4_nl & xb = x[row*ns01 + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7572,7 +7575,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7594,7 +7597,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -7607,12 +7610,11 @@ void kernel_mul_mv_iq4_xs_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7623,6 +7625,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_K; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/16; // 0 or 1 const short it = tiisg%16; // 0...15 const short ib = it/2; @@ -7632,7 +7637,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -7641,15 +7646,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; - for (int ibl = ix; ibl < nb; ibl += 2) { + // [TAG_MUL_MV_WEIRD] + for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; + for (short row = 0; row < NR0; ++row) { + device const block_iq4_xs & xb = x[row*ns01 + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7679,7 +7685,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7701,7 +7707,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -7714,13 +7720,12 @@ void kernel_mul_mv_mxfp4_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_MXFP4; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7731,6 +7736,9 @@ void kernel_mul_mv_mxfp4_f32_impl( device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_MXFP4; + const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7738,20 +7746,22 @@ void kernel_mul_mv_mxfp4_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK_MXFP4 + it * 8; + device const float * yb = y + ix*QK_MXFP4 + it*8; + + // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster + // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + device const float4 * y4 = (device const float4 *) yb; - for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; -#pragma unroll(nr0) - for (short row = 0; row < nr0; row++) { - device const block_mxfp4 & xb = x[row*nb + ib]; + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_mxfp4 & xb = x[row*ns01 + ib]; device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); @@ -7769,7 +7779,7 @@ void kernel_mul_mv_mxfp4_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a5fe5b749c355..36d495d6cfeab 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -16313,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il) { + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { // For Granite architectures - scale residual if (hparams.f_residual_scale) { From 1f02d9337a51332fbd88aad2e276800c12f3f5c0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Oct 2025 10:44:41 +0300 Subject: [PATCH 04/64] graph : fix reuse check for recurrent inputs --- src/llama-graph.cpp | 11 ++++++++++- src/llama-graph.h | 4 ++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7f0c974f1760b..a946e5d8654a7 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { res &= s_copy_main->ne[0] == params.ubatch.n_seqs; res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; + res &= head == mctx->get_head(); + res &= rs_z == mctx->get_rs_z(); + return res; } @@ -487,6 +490,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + return res; } @@ -1827,6 +1833,9 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + inp->head = mctx_cur->get_head(); + inp->rs_z = mctx_cur->get_rs_z(); + return inp; } @@ -1895,7 +1904,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); - auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); diff --git a/src/llama-graph.h b/src/llama-graph.h index 394e884323bc1..a596461bb9928 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -234,6 +234,10 @@ class llm_graph_input_rs : public llm_graph_input_i { ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs] const llama_memory_recurrent_context * mctx; + + // used in view offsets, need to match for valid graph reuse + uint32_t head; + int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { From 00f115fe810815d4a22a6dee0acc346131e970e1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Oct 2025 10:57:35 +0300 Subject: [PATCH 05/64] memory : move the recurrent state into the memory context --- src/llama-graph.cpp | 13 ++++++++----- src/llama-graph.h | 8 ++++---- src/llama-memory-recurrent.cpp | 17 ++++++++++------- src/llama-memory-recurrent.h | 6 ++++-- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a946e5d8654a7..0d5a9a6c06cc4 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : + mctx(mctx), + head(mctx->get_head()), + rs_z(mctx->get_rs_z()) { +} + void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -263,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { res &= s_copy_main->ne[0] == params.ubatch.n_seqs; res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; - res &= head == mctx->get_head(); - res &= rs_z == mctx->get_rs_z(); + res &= this->head == mctx->get_head(); + res &= this->rs_z == mctx->get_rs_z(); return res; } @@ -1833,9 +1839,6 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); - inp->head = mctx_cur->get_head(); - inp->rs_z = mctx_cur->get_rs_z(); - return inp; } diff --git a/src/llama-graph.h b/src/llama-graph.h index a596461bb9928..90aadcd9caa09 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {} + llm_graph_input_rs(const llama_memory_recurrent_context * mctx); virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; @@ -235,9 +235,9 @@ class llm_graph_input_rs : public llm_graph_input_i { const llama_memory_recurrent_context * mctx; - // used in view offsets, need to match for valid graph reuse - uint32_t head; - int32_t rs_z; + // need to match for valid graph reuse + const uint32_t head; + const int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index d67f5a5f47b87..28d1b2a623901 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -1088,12 +1088,15 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {} llama_memory_recurrent_context::llama_memory_recurrent_context( - llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) { + llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), + n_rs(mem->size), head(0), rs_z(0), size(mem->size) { } llama_memory_recurrent_context::llama_memory_recurrent_context( llama_memory_recurrent * mem, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {} + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)), + n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) { +} llama_memory_recurrent_context::~llama_memory_recurrent_context() = default; @@ -1134,19 +1137,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const { } uint32_t llama_memory_recurrent_context::get_n_rs() const { - return is_full ? mem->size : mem->n; + return n_rs; } uint32_t llama_memory_recurrent_context::get_head() const { - return is_full ? 0 : mem->head; + return head; } int32_t llama_memory_recurrent_context::get_rs_z() const { - return is_full ? 0 : mem->rs_z; + return rs_z; } uint32_t llama_memory_recurrent_context::get_size() const { - return mem->size; + return size; } ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const { @@ -1158,5 +1161,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + return mem->cells[i + head].src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 077c6e3ce938d..c99b155bcbc42 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -175,8 +175,10 @@ class llama_memory_recurrent_context : public llama_memory_context_i { // // data needed for building the compute graph for the current ubatch: - // TODO: extract all the state like `head` and `n` here // - const bool is_full = false; + const uint32_t n_rs = 0; + const uint32_t head = 0; + const int32_t rs_z = -1; + const uint32_t size = 0; }; From 2744d61185b87cb0f409915675af730e9b9624a5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Oct 2025 19:41:10 +0300 Subject: [PATCH 06/64] Revert "memory : move the recurrent state into the memory context" This reverts commit 00f115fe810815d4a22a6dee0acc346131e970e1. --- src/llama-graph.cpp | 13 +++++-------- src/llama-graph.h | 8 ++++---- src/llama-memory-recurrent.cpp | 17 +++++++---------- src/llama-memory-recurrent.h | 6 ++---- 4 files changed, 18 insertions(+), 26 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0d5a9a6c06cc4..a946e5d8654a7 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -235,12 +235,6 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { } } -llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : - mctx(mctx), - head(mctx->get_head()), - rs_z(mctx->get_rs_z()) { -} - void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -269,8 +263,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { res &= s_copy_main->ne[0] == params.ubatch.n_seqs; res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; - res &= this->head == mctx->get_head(); - res &= this->rs_z == mctx->get_rs_z(); + res &= head == mctx->get_head(); + res &= rs_z == mctx->get_rs_z(); return res; } @@ -1839,6 +1833,9 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + inp->head = mctx_cur->get_head(); + inp->rs_z = mctx_cur->get_rs_z(); + return inp; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 90aadcd9caa09..a596461bb9928 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_rs(const llama_memory_recurrent_context * mctx); + llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {} virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; @@ -235,9 +235,9 @@ class llm_graph_input_rs : public llm_graph_input_i { const llama_memory_recurrent_context * mctx; - // need to match for valid graph reuse - const uint32_t head; - const int32_t rs_z; + // used in view offsets, need to match for valid graph reuse + uint32_t head; + int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 28d1b2a623901..d67f5a5f47b87 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -1088,15 +1088,12 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {} llama_memory_recurrent_context::llama_memory_recurrent_context( - llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), - n_rs(mem->size), head(0), rs_z(0), size(mem->size) { + llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) { } llama_memory_recurrent_context::llama_memory_recurrent_context( llama_memory_recurrent * mem, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)), - n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) { -} + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {} llama_memory_recurrent_context::~llama_memory_recurrent_context() = default; @@ -1137,19 +1134,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const { } uint32_t llama_memory_recurrent_context::get_n_rs() const { - return n_rs; + return is_full ? mem->size : mem->n; } uint32_t llama_memory_recurrent_context::get_head() const { - return head; + return is_full ? 0 : mem->head; } int32_t llama_memory_recurrent_context::get_rs_z() const { - return rs_z; + return is_full ? 0 : mem->rs_z; } uint32_t llama_memory_recurrent_context::get_size() const { - return size; + return mem->size; } ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const { @@ -1161,5 +1158,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + head].src0; + return mem->cells[i + mem->head].src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index c99b155bcbc42..077c6e3ce938d 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -175,10 +175,8 @@ class llama_memory_recurrent_context : public llama_memory_context_i { // // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here // - const uint32_t n_rs = 0; - const uint32_t head = 0; - const int32_t rs_z = -1; - const uint32_t size = 0; + const bool is_full = false; }; From 8c23c43588ed38b8fdfcfb8cf8b90f77c3570ea7 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 10 Oct 2025 15:29:35 -0600 Subject: [PATCH 07/64] Added: tri, cumsum. Still a mess. Cherry-picked and edited from 7ec2df64a46f4697d9c95f3f07753c3e3b1926fa The original commit contained the DELTA_NET op as well which I've removed in this cherry-picked version. Co-Authored-By: Piotr Wilkin Signed-off-by: Gabe Goodhart --- ggml/include/ggml.h | 24 +++++++++ ggml/src/ggml-cpu/ggml-cpu.c | 10 ++++ ggml/src/ggml-cpu/ops.cpp | 95 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/ops.h | 2 + ggml/src/ggml-cpu/vec.h | 32 ++++++++++++ ggml/src/ggml.c | 49 ++++++++++++++++++- 6 files changed, 210 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 60c6b63d05978..beb7ee988097a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -474,6 +474,7 @@ extern "C" { GGML_OP_COS, GGML_OP_SUM, GGML_OP_SUM_ROWS, + GGML_OP_CUMSUM, GGML_OP_MEAN, GGML_OP_ARGMAX, GGML_OP_COUNT_EQUAL, @@ -529,6 +530,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, + GGML_OP_TRI, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, @@ -615,6 +617,13 @@ extern "C" { GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; + enum ggml_tri_type { + GGML_TRI_TYPE_UPPER_DIAG = 0, + GGML_TRI_TYPE_UPPER = 1, + GGML_TRI_TYPE_LOWER_DIAG = 2, + GGML_TRI_TYPE_LOWER = 3 + }; + struct ggml_init_params { // memory pool size_t mem_size; // bytes @@ -978,6 +987,10 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a); + // mean along rows GGML_API struct ggml_tensor * ggml_mean( struct ggml_context * ctx, @@ -2141,6 +2154,17 @@ extern "C" { int shift2, int shift3); + // Make matrix into a triangular one (upper, upper + diagonal, lower or lower + diagonal) with constant value + GGML_API struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype); + + GGML_API struct ggml_tensor * ggml_tri_keep( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type tritype); // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 // timesteps: [N,] diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index eded6eb77ed69..2db1bf19e3227 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_sum_rows(params, tensor); } break; + case GGML_OP_CUMSUM: + { + ggml_compute_forward_cumsum(params, tensor); + } break; case GGML_OP_MEAN: { ggml_compute_forward_mean(params, tensor); @@ -1943,6 +1947,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_leaky_relu(params, tensor); } break; + case GGML_OP_TRI: + { + ggml_compute_forward_tri(params, tensor); + } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_compute_forward_flash_attn_ext(params, tensor); @@ -2153,6 +2161,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_ARGMAX: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: { n_tasks = 1; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8e1a2de14f983..9af57b4d3b695 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9,6 +9,7 @@ #include #include +#include // ggml_compute_forward_dup @@ -1394,6 +1395,57 @@ void ggml_compute_forward_sum( } } +// ggml_compute_forward_cumsum + +static void ggml_compute_forward_cumsum_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_f32(ne00, dst_row, src_row); + } + } + } +} + +void ggml_compute_forward_cumsum( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cumsum_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_sum_rows static void ggml_compute_forward_sum_rows_f32( @@ -2140,6 +2192,49 @@ static void ggml_compute_forward_gelu( } } +// ggml_compute_tri + +static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + float c = *((float *) &(dst->op_params[1])); + bool keep_org_val = isnan(c); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[0] == src0->ne[1]); + + GGML_TENSOR_UNARY_OP_LOCALS + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c, ttype); + } + +} + +void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_tri_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_gelu_erf static void ggml_compute_forward_gelu_erf_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 9824a03b45833..d6f8dedcd9c55 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -85,6 +86,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_back( const struct ggml_compute_params * params, diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index f95ca94e54b16..3e60f9e3d9020 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1414,6 +1414,38 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #endif } +// Applies a triangular mask to the input vector 'src' and writes the result to 'dst'. +// Parameters: +// n - number of elements +// r - current row index +// dst - output array +// src - input array +// keep_org_val - if true, keep original value where mask applies; otherwise use constant 'c' +// c - constant value to use when not keeping original value +// type - type of triangular mask (lower, upper, etc.) +inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, enum ggml_tri_type type) { + for (int i = 0; i < n; ++i) { + bool cmp; + switch (type) { + case GGML_TRI_TYPE_LOWER: cmp = i < r; break; + case GGML_TRI_TYPE_LOWER_DIAG: cmp = i <= r; break; + case GGML_TRI_TYPE_UPPER: cmp = i > r; break; + case GGML_TRI_TYPE_UPPER_DIAG: cmp = i >= r; break; + } + dst[i] = cmp ? (keep_org_val ? src[i] : c) : 0.0f; + } +} + +inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + if (i == 0) { + y[i] = x[i]; + } else { + y[i] = y[i - 1] + x[i]; + } + } +} + inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2bce1375ba3c0..7e4bfb07154b3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "COS", "SUM", "SUM_ROWS", + "CUMSUM", "MEAN", "ARGMAX", "COUNT_EQUAL", @@ -990,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", + "TRI", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", @@ -1019,7 +1021,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1039,6 +1041,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cos(x)", "Σx", "Σx_k", + "cumsum(x)", "Σx/n", "argmax(x)", "count_equal(x)", @@ -1094,6 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "leaky_relu(x)", + "tri(x)", "flash_attn_ext(x)", "flash_attn_back(x)", @@ -1123,7 +1127,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2337,6 +2341,20 @@ struct ggml_tensor * ggml_sum_rows( return result; } +// ggml_cumsum + +struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne); + + result->op = GGML_OP_CUMSUM; + result->src[0] = a; + + return result; +} + // ggml_mean struct ggml_tensor * ggml_mean( @@ -4968,6 +4986,33 @@ struct ggml_tensor * ggml_timestep_embedding( return result; } +// ggml_tri + +struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype) { + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, tritype); + ggml_set_op_params_f32(result, 1, constant); + + result->op = GGML_OP_TRI; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_tri_keep( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type tritype) { + + return ggml_tri(ctx, a, nan(""), tritype); +} + // ggml_argsort struct ggml_tensor * ggml_argsort( From 2a2e79cdc2707eb392e65be09ee8fc5f70d81e19 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 10 Oct 2025 15:50:32 -0600 Subject: [PATCH 08/64] feat(tests): Add --verbose | -v flag to test-backend-ops to print tensors Branch: Mamba2Perf Signed-off-by: Gabe Goodhart --- tests/test-backend-ops.cpp | 180 +++++++++++++++++++++++++++++++------ 1 file changed, 155 insertions(+), 25 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2fa16b497a6b7..ae45bc244d3c9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -175,6 +175,33 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t)); } +static std::vector ggml_get_float_value(uint8_t * buf, ggml_type type, size_t i, size_t bs, + bool quantized, std::vector & vq) { + const auto * tt = ggml_get_type_traits(type); + std::vector tv; + if (type == GGML_TYPE_F16) { + tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); + } else if (type == GGML_TYPE_BF16) { + tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); + } else if (type == GGML_TYPE_F32) { + tv.push_back(*(float *) &buf[i]); + } else if (type == GGML_TYPE_I64) { + tv.push_back((float)*(int64_t *) &buf[i]); + } else if (type == GGML_TYPE_I32) { + tv.push_back((float)*(int32_t *) &buf[i]); + } else if (type == GGML_TYPE_I16) { + tv.push_back((float)*(int16_t *) &buf[i]); + } else if (type == GGML_TYPE_I8) { + tv.push_back((float)*(int8_t *) &buf[i]); + } else if (quantized) { + tt->to_float(&buf[i], vq.data(), bs); + tv.insert(tv.end(), vq.begin(), vq.end()); + } else { + GGML_ABORT("fatal error"); + } + return tv; +} + static std::vector tensor_to_float(const ggml_tensor * t) { std::vector tv; tv.reserve(ggml_nelements(t)); @@ -182,7 +209,6 @@ static std::vector tensor_to_float(const ggml_tensor * t) { std::vector buf(ggml_nbytes(t)); ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); - const auto * tt = ggml_get_type_traits(t->type); size_t bs = ggml_blck_size(t->type); std::vector vq(ggml_blck_size(t->type)); bool quantized = ggml_is_quantized(t->type); @@ -193,26 +219,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) { for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; - if (t->type == GGML_TYPE_F16) { - tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); - } else if (t->type == GGML_TYPE_BF16) { - tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); - } else if (t->type == GGML_TYPE_F32) { - tv.push_back(*(float *) &buf[i]); - } else if (t->type == GGML_TYPE_I64) { - tv.push_back((float)*(int64_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I32) { - tv.push_back((float)*(int32_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I16) { - tv.push_back((float)*(int16_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I8) { - tv.push_back((float)*(int8_t *) &buf[i]); - } else if (quantized) { - tt->to_float(&buf[i], vq.data(), bs); - tv.insert(tv.end(), vq.begin(), vq.end()); - } else { - GGML_ABORT("fatal error"); - } + const auto fvs = ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq); + tv.insert(tv.end(), fvs.begin(), fvs.end()); } } } @@ -221,6 +229,107 @@ static std::vector tensor_to_float(const ggml_tensor * t) { return tv; } +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static void ggml_print_tensor(ggml_tensor * t, int64_t n = 3) { + GGML_ASSERT(t != nullptr); + GGML_ASSERT(n > 0); + + std::stringstream src_ss; + src_ss << "("; + size_t last_src = 0; + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (t->src[i] != nullptr) { + last_src = i; + } + } + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (t->src[i] != nullptr) { + src_ss << t->src[i]->name << "{" << ggml_ne_string(t->src[i]) <<"}"; + } + if (i <= last_src) { + src_ss << ", "; + } + } + src_ss << ")"; + + printf("%s: %24s = (%s) %10s%s = {%s}\n", __func__, + t->name, ggml_type_name(t->type), ggml_op_desc(t), + src_ss.str().c_str(), + ggml_ne_string(t).c_str()); + + std::vector tv; + tv.reserve(ggml_nelements(t)); + + std::vector buf(ggml_nbytes(t)); + ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); + + size_t bs = ggml_blck_size(t->type); + std::vector vq(ggml_blck_size(t->type)); + bool quantized = ggml_is_quantized(t->type); + + float sum = 0; + for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { + for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { + for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { + size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; + for (const auto & val : ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq)) { + sum += val; + } + } + } + } + } + for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { + printf(" [\n"); + for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { + if (i2 == n && t->ne[2] > 2*n) { + printf(" ..., \n"); + i2 = t->ne[2] - n; + } + printf(" [\n"); + for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { + if (i1 == n && t->ne[1] > 2*n) { + printf(" ..., \n"); + i1 = t->ne[1] - n; + } + printf(" ["); + for (int64_t i0 = 0; i0 < t->ne[0]; i0++) { + size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; + if (i0 == n && t->ne[0] > 2*n) { + printf("..., "); + i0 = t->ne[0] - n; + } + for (const auto & v : ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq)) { + printf("%12.4f", v); + } + if (i0 < t->ne[0] - 1) printf(", "); + } + printf("],\n"); + } + printf(" ],\n"); + } + printf(" ]\n"); + printf(" sum = %f\n", sum); + } + + // TODO: make this abort configurable/optional? + if (std::isnan(sum)) { + printf("encountered NaN - aborting\n"); + exit(0); + } +} + // normalized mean squared error = mse(a, b) / mse(a, 0) static double nmse(const float * a, const float * b, size_t n) { double mse_a_b = 0.0; @@ -980,6 +1089,8 @@ static std::unique_ptr create_printer(output_formats format) { GGML_ABORT("invalid output format"); } +// test case definition + struct test_case { virtual ~test_case() {} @@ -1056,6 +1167,9 @@ struct test_case { std::vector sentinels; + // set to true to print tensors + bool verbose = false; + void add_sentinel(ggml_context * ctx) { if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) { return; @@ -1201,6 +1315,7 @@ struct test_case { // compare struct callback_userdata { bool ok; + bool verbose; double max_err; ggml_backend_t backend1; ggml_backend_t backend2; @@ -1208,6 +1323,7 @@ struct test_case { callback_userdata ud { true, + verbose, max_nmse_err(), backend1, backend2 @@ -1232,6 +1348,11 @@ struct test_case { } } + if (ud->verbose) { + ggml_print_tensor(t1); + ggml_print_tensor(t2); + } + std::vector f1 = tensor_to_float(t1); std::vector f2 = tensor_to_float(t2); @@ -5769,7 +5890,7 @@ static const ggml_type other_types[] = { }; // Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low -static std::vector> make_test_cases_eval() { +static std::vector> make_test_cases_eval(bool verbose = false) { std::vector> test_cases; std::default_random_engine rng(0); @@ -6824,6 +6945,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_falcon(2)); #endif + // set verbose on all test cases + for (auto & tc : test_cases) { + tc->verbose = verbose; + } + return test_cases; } @@ -6977,7 +7103,7 @@ static std::vector> make_test_cases_perf() { } static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter, - printer * output_printer) { + printer * output_printer, bool verbose) { auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) { if (params_filter == nullptr) { return; @@ -6996,7 +7122,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op }; if (mode == MODE_TEST) { - auto test_cases = make_test_cases_eval(); + auto test_cases = make_test_cases_eval(verbose); filter_test_cases(test_cases, params_filter); ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL); if (backend_cpu == NULL) { @@ -7158,6 +7284,7 @@ static void usage(char ** argv) { printf(" --output specifies output format (default: console, options: console, sql, csv)\n"); printf(" --list-ops lists all available GGML operations\n"); printf(" --show-coverage shows test coverage\n"); + printf(" --verbose | -v print tensors during ops\n"); } int main(int argc, char ** argv) { @@ -7166,6 +7293,7 @@ int main(int argc, char ** argv) { const char * op_names_filter = nullptr; const char * backend_filter = nullptr; const char * params_filter = nullptr; + bool verbose = false; for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "test") == 0) { @@ -7213,6 +7341,8 @@ int main(int argc, char ** argv) { } else if (strcmp(argv[i], "--show-coverage") == 0) { show_test_coverage(); return 0; + } else if (strcmp(argv[i], "--verbose") == 0 || strcmp(argv[i], "-v") == 0) { + verbose = true; } else { usage(argv); return 1; @@ -7265,7 +7395,7 @@ int main(int argc, char ** argv) { false, "", ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024, true)); - bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get()); + bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get(), verbose); if (ok) { n_ok++; From 092f74040b953e9883938ad602732a67e5331f96 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 10 Oct 2025 16:06:20 -0600 Subject: [PATCH 09/64] test: Add cumsum tests to test-backend-ops Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- tests/test-backend-ops.cpp | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ae45bc244d3c9..7cbff7580da93 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4771,6 +4771,31 @@ struct test_sum_rows : public test_case { } }; +// GGML_OP_CUMSUM +struct test_cumsum : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_cumsum(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_cumsum(ctx, a); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_MEAN struct test_mean : public test_case { const ggml_type type; @@ -6871,6 +6896,8 @@ static std::vector> make_test_cases_eval(bool verbose test_cases.emplace_back(new test_arange()); test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 13, 15, 26, 15 })); for (bool v : {false, true}) { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v)); From 6949ce7b17f3757e957cb143695e3ee3a3f6eea5 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 10 Oct 2025 16:38:40 -0600 Subject: [PATCH 10/64] feat(ggml-cpu): Add cumsum support for f16 and bf16 Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cpu/ops.cpp | 70 ++++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/vec.h | 60 ++++++++++++++++++++++++-------- tests/test-backend-ops.cpp | 12 +++++-- 3 files changed, 125 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 9af57b4d3b695..d37f5ad406201 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -1428,6 +1428,68 @@ static void ggml_compute_forward_cumsum_f32( } } +static void ggml_compute_forward_cumsum_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + GGML_ASSERT(dst->nb[0] == sizeof(ggml_fp16_t)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + ggml_fp16_t * src_row = (ggml_fp16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + ggml_fp16_t * dst_row = (ggml_fp16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_f16(ne00, dst_row, src_row); + } + } + } +} + +static void ggml_compute_forward_cumsum_bf16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(ggml_bf16_t)); + GGML_ASSERT(dst->nb[0] == sizeof(ggml_bf16_t)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + ggml_bf16_t * src_row = (ggml_bf16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + ggml_bf16_t * dst_row = (ggml_bf16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_bf16(ne00, dst_row, src_row); + } + } + } +} + void ggml_compute_forward_cumsum( const ggml_compute_params * params, ggml_tensor * dst) { @@ -1439,6 +1501,14 @@ void ggml_compute_forward_cumsum( { ggml_compute_forward_cumsum_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_cumsum_f16(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_cumsum_bf16(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 3e60f9e3d9020..574df9990ebf1 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1402,6 +1402,8 @@ inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const } } +// sum + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; @@ -1414,6 +1416,32 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #endif } +inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +} + +inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_CPU_FP16_TO_FP32(x[i]); + } + *s = sum; +} + +inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_BF16_TO_FP32(x[i]); + } + *s = sum; +} + +// tri + // Applies a triangular mask to the input vector 'src' and writes the result to 'dst'. // Parameters: // n - number of elements @@ -1436,6 +1464,8 @@ inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const } } +// cumsum + inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) { if (i == 0) { @@ -1446,29 +1476,27 @@ inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) } } -inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { - ggml_float sum = 0.0; +inline static void ggml_vec_cumsum_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { - sum += (ggml_float)x[i]; + if (i == 0) { + y[i] = x[i]; + } else { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i - 1]) + GGML_CPU_FP16_TO_FP32(x[i])); + } } - *s = sum; } -inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { - float sum = 0.0f; +inline static void ggml_vec_cumsum_bf16(const int n, ggml_bf16_t * y, const ggml_bf16_t * x) { for (int i = 0; i < n; ++i) { - sum += GGML_CPU_FP16_TO_FP32(x[i]); + if (i == 0) { + y[i] = x[i]; + } else { + y[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(y[i - 1]) + GGML_BF16_TO_FP32(x[i])); + } } - *s = sum; } -inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) { - float sum = 0.0f; - for (int i = 0; i < n; ++i) { - sum += GGML_BF16_TO_FP32(x[i]); - } - *s = sum; -} +// max inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE @@ -1482,6 +1510,8 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #endif } +// norm inv + inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1.f/(*s); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7cbff7580da93..a7243b2194d03 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6896,8 +6896,11 @@ static std::vector> make_test_cases_eval(bool verbose test_cases.emplace_back(new test_arange()); test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); - test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 13, 15, 26, 15 })); + + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 15, 26, 12 })); for (bool v : {false, true}) { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v)); @@ -7055,6 +7058,11 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true)); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 15, 26, 12 })); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) { From f8fba60e8f407607d58995a61ae49bd9ddaa06d6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 13 Oct 2025 15:05:36 -0600 Subject: [PATCH 11/64] feat(ggml-cpu): Add F16 and BF16 support for tri Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cpu/ops.cpp | 64 +++++++++++++++++++++++++++++++++++++-- ggml/src/ggml-cpu/vec.h | 31 ++++++++++++++----- 2 files changed, 85 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d37f5ad406201..b14a07e5bc3a5 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2283,13 +2283,65 @@ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggm const int64_t i02 = (ir - i03*ne02*ne01)/ne01; const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + float * dst_ptr = (float *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src = (float *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c, ttype); } } +static void ggml_compute_forward_tri_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = *((float *) &(dst->op_params[1])); + bool keep_org_val = isnan(c); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[0] == src0->ne[1]); + + GGML_TENSOR_UNARY_OP_LOCALS + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_f16(ne0, i01, dst_ptr, src, keep_org_val, GGML_FP32_TO_FP16(c), ttype); + } + +} + +static void ggml_compute_forward_tri_bf16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + float c = *((float *) &(dst->op_params[1])); + bool keep_org_val = isnan(c); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[0] == src0->ne[1]); + + GGML_TENSOR_UNARY_OP_LOCALS + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_bf16_t * src = (ggml_bf16_t *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_bf16(ne0, i01, dst_ptr, src, keep_org_val, GGML_FP32_TO_BF16(c), ttype); + } + +} + void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; @@ -2298,6 +2350,14 @@ void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * { ggml_compute_forward_tri_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_tri_f16(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_tri_bf16(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 574df9990ebf1..b12a07233c9ce 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1451,16 +1451,31 @@ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16 // keep_org_val - if true, keep original value where mask applies; otherwise use constant 'c' // c - constant value to use when not keeping original value // type - type of triangular mask (lower, upper, etc.) +inline static bool _ggml_vec_tri_cmp(const int i, const int r, const enum ggml_tri_type type) { + switch (type) { + case GGML_TRI_TYPE_LOWER: return i < r; break; + case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; break; + case GGML_TRI_TYPE_UPPER: return i > r; break; + case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; break; + } +} + inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, enum ggml_tri_type type) { for (int i = 0; i < n; ++i) { - bool cmp; - switch (type) { - case GGML_TRI_TYPE_LOWER: cmp = i < r; break; - case GGML_TRI_TYPE_LOWER_DIAG: cmp = i <= r; break; - case GGML_TRI_TYPE_UPPER: cmp = i > r; break; - case GGML_TRI_TYPE_UPPER_DIAG: cmp = i >= r; break; - } - dst[i] = cmp ? (keep_org_val ? src[i] : c) : 0.0f; + dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : 0.0f; + } +} + +inline static void ggml_vec_tri_f16(const int n, const int r, ggml_fp16_t * dst, const ggml_fp16_t * src, bool keep_org_val, ggml_fp16_t c, enum ggml_tri_type type) { + for (int i = 0; i < n; ++i) { + dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : 0; + } +} + +inline static void ggml_vec_tri_bf16(const int n, const int r, ggml_bf16_t * dst, const ggml_bf16_t * src, bool keep_org_val, ggml_bf16_t c, enum ggml_tri_type type) { + const ggml_bf16_t zero = ggml_fp32_to_bf16(0); + for (int i = 0; i < n; ++i) { + dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : zero; } } From 058160a42d2e6cecac5b847c95fe2d57aa486a1d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 13 Oct 2025 15:05:48 -0600 Subject: [PATCH 12/64] test: Add test cases for tri Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- tests/test-backend-ops.cpp | 51 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a7243b2194d03..cee248e21090e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4796,6 +4796,35 @@ struct test_cumsum : public test_case { } }; +// GGML_OP_TRI +struct test_tri : public test_case { + const ggml_type type; + const std::array ne; + const ggml_tri_type tri_type; + const float c; + + std::string vars() override { + return VARS_TO_STR4(type, ne, tri_type, c); + } + + test_tri(ggml_tri_type tri_type, + ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 1, 1}, + float c = nan("")) + : type(type), ne(ne), tri_type(tri_type), c(c) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_tri(ctx, a, c, tri_type); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_MEAN struct test_mean : public test_case { const ggml_type type; @@ -6902,6 +6931,17 @@ static std::vector> make_test_cases_eval(bool verbose test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 15, 26, 12 })); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_BF16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16})); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); + for (bool v : {false, true}) { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v)); test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v)); @@ -7063,6 +7103,17 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 15, 26, 12 })); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_BF16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16})); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) { From 86ce3da9ac3ae9773f15b1aaa410e113f401a500 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 13 Oct 2025 16:50:11 -0600 Subject: [PATCH 13/64] chore: TODOs to loosen assertions in tri for ggml_is_contiguous Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cpu/ops.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b14a07e5bc3a5..2891a48895d88 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2271,6 +2271,7 @@ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggm float c = *((float *) &(dst->op_params[1])); bool keep_org_val = isnan(c); + // TODO: Is ggml_is_contiguous_rows safe and sufficient? GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->ne[0] == src0->ne[1]); @@ -2297,6 +2298,7 @@ static void ggml_compute_forward_tri_f16(const ggml_compute_params * params, ggm const float c = *((float *) &(dst->op_params[1])); bool keep_org_val = isnan(c); + // TODO: Is ggml_is_contiguous_rows safe and sufficient? GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->ne[0] == src0->ne[1]); @@ -2323,6 +2325,7 @@ static void ggml_compute_forward_tri_bf16(const ggml_compute_params * params, gg float c = *((float *) &(dst->op_params[1])); bool keep_org_val = isnan(c); + // TODO: Is ggml_is_contiguous_rows safe and sufficient? GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->ne[0] == src0->ne[1]); From 3a8958f229ffe1c1961465a1193b4264839b35a6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 13 Oct 2025 16:51:57 -0600 Subject: [PATCH 14/64] feat(ggml-metal): Initial (slow) implementation of cumsum for metal This should be using simd operations for better parallelism, but that will come next. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 25 ++++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 3 + ggml/src/ggml-metal/ggml-metal-impl.h | 19 ++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 59 +++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 70 +++++++++++++++++++++++ 7 files changed, 178 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e23abdda97405..7ab81133a5a90 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -299,6 +299,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "cumsum"; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + // shared memory buffer for a single simd group size + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 1034e4bbf6596..35c4b4dab011a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -110,6 +110,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 9527973015245..982f75c94f577 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -656,6 +656,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_COS: case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_CUMSUM: + //DEBUG -- Refine this! + return true; case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index c9dff87305869..495af32929989 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -563,6 +563,25 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_sum_rows; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cumsum; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5f9370449bb2d..dcd2c67f9ea33 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -306,6 +306,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_sum_rows(ctx, idx); } break; + case GGML_OP_CUMSUM: + { + n_fuse = ggml_metal_op_cumsum(ctx, idx); + } break; case GGML_OP_SOFT_MAX: { n_fuse = ggml_metal_op_soft_max(ctx, idx); @@ -903,6 +907,61 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_cumsum args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cumsum(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index d4cb9446212d9..7e24164aa2e26 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -51,6 +51,7 @@ int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ddc285042d284..92403b76d0c49 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1778,6 +1778,76 @@ typedef decltype(kernel_sum_rows) kernel_sum_rows_t; template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template +kernel void kernel_cumsum( + constant ggml_metal_kargs_cumsum & args, + device const char * src0, + device const char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + // Each thread is a single element of the row if ne00 < max threads per + // threadgroup, so this will loop once for each index that this thread is + // responsible for + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + //DEBUG -- This is the _very_ neive version + dst_row[i0] = src_row[i0]; + for (int64_t j = 0; j < i0; ++j) { + dst_row[i0] = static_cast(static_cast(src_row[j]) + static_cast(dst_row[i0])); + } + } + + // if (sgitg == 0) { + // shmem_f32[tiisg] = 0.0f; + // } + + + // float sumf = 0; + + // for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + // sumf += src_row[i0]; + // } + + // sumf = simd_sum(sumf); + + // threadgroup_barrier(mem_flags::mem_threadgroup); + + // if (tiisg == 0) { + // shmem_f32[sgitg] = sumf; + // } + + // threadgroup_barrier(mem_flags::mem_threadgroup); + + // sumf = shmem_f32[tiisg]; + // sumf = simd_sum(sumf); + + // if (tpitg.x == 0) { + // dst_row[0] = norm ? sumf / args.ne00 : sumf; + // } +} + +typedef decltype(kernel_cumsum) kernel_cumsum_t; + +template [[host_name("kernel_cumsum_f32")]] kernel kernel_cumsum_t kernel_cumsum; +template [[host_name("kernel_cumsum_f16")]] kernel kernel_cumsum_t kernel_cumsum; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_cumsum_bf16")]] kernel kernel_cumsum_t kernel_cumsum; +#endif + template kernel void kernel_soft_max( constant ggml_metal_kargs_soft_max & args, From cbaed8653ae6fbcf7a0a8ae68ece2f5b779047dc Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 13 Oct 2025 17:03:01 -0600 Subject: [PATCH 15/64] feat(ggml-metal): Add stubs for metal tri Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 20 ++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 11 +++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + 5 files changed, 34 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 7ab81133a5a90..69e431522b0d0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -324,6 +324,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_ return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "tri"; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + return ggml_metal_library_compile_pipeline(lib, base, name, nullptr); +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 35c4b4dab011a..4f8e9c913af0b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -111,6 +111,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 982f75c94f577..1da22fbd44d1c 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -657,6 +657,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CUMSUM: + case GGML_OP_TRI: //DEBUG -- Refine this! return true; case GGML_OP_SUM_ROWS: diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index dcd2c67f9ea33..1cd12d892662a 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -310,6 +310,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_cumsum(ctx, idx); } break; + case GGML_OP_TRI: + { + n_fuse = ggml_metal_op_tri(ctx, idx); + } break; case GGML_OP_SOFT_MAX: { n_fuse = ggml_metal_op_soft_max(ctx, idx); @@ -962,6 +966,13 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + //DEBUG + GGML_ASSERT(false); + return 1; +} + int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 7e24164aa2e26..223a7d79c72e1 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -52,6 +52,7 @@ int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); From e5964695e27288d5e773f596f579532ec4355321 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 13 Oct 2025 17:03:33 -0600 Subject: [PATCH 16/64] test: Use looser nmse for lower-precision types for cumsum Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- tests/test-backend-ops.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index cee248e21090e..0bf86eaf329d0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4784,6 +4784,15 @@ struct test_cumsum : public test_case { std::array ne = {10, 5, 4, 3}) : type(type), ne(ne) {} + + double max_nmse_err() override { + // Lower precision types have expected precision errors in lower bits + if (type == GGML_TYPE_BF16 || type == GGML_TYPE_F16) { + return 1e-5; + } + return 1e-7; + } + ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); From 112d339fe6130a06a625e3afffe62de074e6b5d2 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 15 Oct 2025 07:53:35 -0600 Subject: [PATCH 17/64] test: Allow multiple verbose flags to fully print tensors Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- tests/test-backend-ops.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 180158db9d71a..4684d179c057a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1167,8 +1167,8 @@ struct test_case { std::vector sentinels; - // set to true to print tensors - bool verbose = false; + // set to 1 to print tensors, 2 to fully print tensors + int verbose = 0; void add_sentinel(ggml_context * ctx) { if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) { @@ -1315,7 +1315,7 @@ struct test_case { // compare struct callback_userdata { bool ok; - bool verbose; + int verbose; double max_err; ggml_backend_t backend1; ggml_backend_t backend2; @@ -1349,8 +1349,8 @@ struct test_case { } if (ud->verbose) { - ggml_print_tensor(t1); - ggml_print_tensor(t2); + ggml_print_tensor(t1, ud->verbose >= 2 ? 1e10 : 3); + ggml_print_tensor(t2, ud->verbose >= 2 ? 1e10 : 3); } std::vector f1 = tensor_to_float(t1); @@ -5953,7 +5953,7 @@ static const ggml_type other_types[] = { }; // Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low -static std::vector> make_test_cases_eval(bool verbose = false) { +static std::vector> make_test_cases_eval(int verbose = 0) { std::vector> test_cases; std::default_random_engine rng(0); @@ -6938,7 +6938,8 @@ static std::vector> make_test_cases_eval(bool verbose test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); - test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 15, 26, 12 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 4, 2, 1 })); + // test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); @@ -7198,7 +7199,7 @@ static std::vector> make_test_cases_perf() { } static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter, - printer * output_printer, bool verbose) { + printer * output_printer, int verbose) { auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) { if (params_filter == nullptr) { return; @@ -7379,7 +7380,7 @@ static void usage(char ** argv) { printf(" --output specifies output format (default: console, options: console, sql, csv)\n"); printf(" --list-ops lists all available GGML operations\n"); printf(" --show-coverage shows test coverage\n"); - printf(" --verbose | -v print tensors during ops\n"); + printf(" --verbose | -v print tensors during ops (can specify multiple times)\n"); } int main(int argc, char ** argv) { @@ -7388,7 +7389,7 @@ int main(int argc, char ** argv) { const char * op_names_filter = nullptr; const char * backend_filter = nullptr; const char * params_filter = nullptr; - bool verbose = false; + int verbose = 0; for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "test") == 0) { @@ -7437,7 +7438,7 @@ int main(int argc, char ** argv) { show_test_coverage(); return 0; } else if (strcmp(argv[i], "--verbose") == 0 || strcmp(argv[i], "-v") == 0) { - verbose = true; + ++verbose; } else { usage(argv); return 1; From 78e137ffe63174bfe402238e08f2cc8716151ad0 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 26 Sep 2025 16:43:39 -0600 Subject: [PATCH 18/64] feat(llama-gguf): Print out the tensor type in llama-gguf r Branch: Mamba2Perf Signed-off-by: Gabe Goodhart --- examples/gguf/gguf.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp index f31989c8c55c6..1bf8e705e359c 100644 --- a/examples/gguf/gguf.cpp +++ b/examples/gguf/gguf.cpp @@ -184,8 +184,9 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { const char * name = gguf_get_tensor_name (ctx, i); const size_t size = gguf_get_tensor_size (ctx, i); const size_t offset = gguf_get_tensor_offset(ctx, i); + const char * type = ggml_type_name(gguf_get_tensor_type(ctx, i)); - printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s\n", __func__, i, name, size, offset, type); } } From e5587cb156ab3b16ab6ae2f8509dd6847ca7cbbd Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 15 Oct 2025 08:30:38 -0600 Subject: [PATCH 19/64] feat(ggml-metal): Efficient implementation of cumsum for metal Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 13 +++--- ggml/src/ggml-metal/ggml-metal.metal | 53 ++++++++++------------- 2 files changed, 29 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index c470bc9dedd31..41519ba4e51a2 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -330,15 +330,16 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_ snprintf(name, 256, "%s", base); + // reuse existing precompiled pipeline, but allow memory size setting ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + if (!res) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - // shared memory buffer for a single simd group size - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + // one shared memory element for each simd group in the threadgroup + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + const int nsg = (ne00 + 31)/32; + ggml_metal_pipeline_set_smem(res, nsg*sizeof(float)); return res; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 5e617c0152acf..150e2b4294e56 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1801,7 +1801,7 @@ kernel void kernel_cumsum( constant ggml_metal_kargs_cumsum & args, device const char * src0, device const char * dst, - threadgroup float * shmem_f32 [[threadgroup(0)]], + threadgroup float * shmem_f32 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], @@ -1822,40 +1822,31 @@ kernel void kernel_cumsum( // threadgroup, so this will loop once for each index that this thread is // responsible for for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { - //DEBUG -- This is the _very_ neive version - dst_row[i0] = src_row[i0]; - for (int64_t j = 0; j < i0; ++j) { - dst_row[i0] = static_cast(static_cast(src_row[j]) + static_cast(dst_row[i0])); - } - } - - // if (sgitg == 0) { - // shmem_f32[tiisg] = 0.0f; - // } - - - // float sumf = 0; - - // for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { - // sumf += src_row[i0]; - // } - // sumf = simd_sum(sumf); + // Each thread does simd_prefix_inclusive_sum => every element of row + // now holds cumsum of the simd group + float sumf = static_cast(src_row[i0]); + sumf = simd_prefix_inclusive_sum(sumf); + dst_row[i0] = static_cast(sumf); - // threadgroup_barrier(mem_flags::mem_threadgroup); - - // if (tiisg == 0) { - // shmem_f32[sgitg] = sumf; - // } - - // threadgroup_barrier(mem_flags::mem_threadgroup); + // If this is the last element of the simd group, store its value in + // shared memory + if (tiisg == N_SIMDWIDTH - 1 || i0 == args.ne00 - 1) { + const ushort shmem_idx = i0 / N_SIMDWIDTH; + shmem_f32[shmem_idx] = sumf; + } + } - // sumf = shmem_f32[tiisg]; - // sumf = simd_sum(sumf); + // Ensure all simd groups sync here before proceeding + threadgroup_barrier(mem_flags::mem_threadgroup); - // if (tpitg.x == 0) { - // dst_row[0] = norm ? sumf / args.ne00 : sumf; - // } + // Each element then adds the final value of all preceding simd groups + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + const ushort shmem_idx = i0 / N_SIMDWIDTH; + for (ushort j = 0; j < shmem_idx; ++j) { + dst_row[i0] += static_cast(shmem_f32[j]); + } + } } typedef decltype(kernel_cumsum) kernel_cumsum_t; From 0468b99e28809b11ac134aa62f32b25979b78957 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 15 Oct 2025 08:32:17 -0600 Subject: [PATCH 20/64] test: More verbose printing and better cumsum tests Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- tests/test-backend-ops.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4684d179c057a..c20d652223c1a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1382,11 +1382,12 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); - //for (int i = 0; i < (int) f1.size(); i++) { - // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); - //} - //printf("\n"); - //exit(1); + if (ud->verbose) { + for (int i = 0; i < (int) f1.size(); i++) { + printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); + } + printf("\n"); + } ud->ok = false; } return true; @@ -6938,8 +6939,7 @@ static std::vector> make_test_cases_eval(int verbose test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); - test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 4, 2, 1 })); - // test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); @@ -7111,7 +7111,7 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); - test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 15, 26, 12 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); From c71e35ecfe2d6109665bf32c51b7108b6750eb9d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 15 Oct 2025 08:33:07 -0600 Subject: [PATCH 21/64] fix(ggml-metal): better granularity for support bool for CUMSUM and TRI Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.m | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 5c5ce02b575ce..b63ddd861daae 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -656,10 +656,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_COS: case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_CUMSUM: case GGML_OP_TRI: - //DEBUG -- Refine this! - return true; + return ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_CUMSUM: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: From 5f0d2a1ed6f19852369ff2ba801c03fd1a192cf7 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 15 Oct 2025 10:16:57 -0600 Subject: [PATCH 22/64] feat(ggml-metal): Metal impl of tri Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 53 +++++++++++++++++++++++++- ggml/src/ggml-metal/ggml-metal.metal | 50 ++++++++++++++++++++++++ tests/test-backend-ops.cpp | 2 + 4 files changed, 124 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 11cd056117ea7..5b0d5ad7cf05c 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -586,6 +586,27 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_cumsum; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float c; + uint32_t ttype; +} ggml_metal_kargs_tri; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 42fc2ca63c4ea..1209ad7a1637e 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1004,8 +1004,57 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); - //DEBUG - GGML_ASSERT(false); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0]; + const float c = *((float *) &(op->op_params[1])); + + ggml_metal_kargs_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.c =*/ c, + /*.ttype =*/ static_cast(ttype) + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_tri(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 150e2b4294e56..c8c081b2aa129 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1857,6 +1857,56 @@ template [[host_name("kernel_cumsum_f16")]] kernel kernel_cumsum_t kernel_cumsum template [[host_name("kernel_cumsum_bf16")]] kernel kernel_cumsum_t kernel_cumsum; #endif +inline static bool _ggml_vec_tri_cmp(const int i, const int r, const uint32_t type) { + switch (type) { + // ggml.h:620 + case /* GGML_TRI_TYPE_LOWER */ 3: return i < r; break; + case /* GGML_TRI_TYPE_LOWER_DIAG */ 2: return i <= r; break; + case /* GGML_TRI_TYPE_UPPER */ 1: return i > r; break; + case /* GGML_TRI_TYPE_UPPER_DIAG */ 0: return i >= r; break; + } +} + +template +kernel void kernel_tri( + constant ggml_metal_kargs_tri & args, + device const char * src0, + device const char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + // Each thread is a single element of the row if ne00 < max threads per + // threadgroup, so this will loop once for each index that this thread is + // responsible for + const bool keep_org_val = isnan(args.c); + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + dst_row[i0] = _ggml_vec_tri_cmp(i0, i1, args.ttype) + ? (keep_org_val ? src_row[i0] : static_cast(args.c)) + : static_cast(0.f); + } +} + +typedef decltype(kernel_tri) kernel_tri_t; + +template [[host_name("kernel_tri_f32")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16")]] kernel kernel_tri_t kernel_tri; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_tri_bf16")]] kernel kernel_tri_t kernel_tri; +#endif + template kernel void kernel_soft_max( constant ggml_metal_kargs_soft_max & args, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c20d652223c1a..02e028ee93866 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6951,6 +6951,7 @@ static std::vector> make_test_cases_eval(int verbose test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1})); for (bool v : {false, true}) { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v)); @@ -7123,6 +7124,7 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1})); for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { From ba3b8db3a2e73fe39c93ff05a49b99508a5c01d0 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 15 Oct 2025 10:33:50 -0600 Subject: [PATCH 23/64] fix(ggml-cpu): Fix warnings from build with gcc Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cpu/ops.cpp | 18 +++++++++--------- ggml/src/ggml-cpu/vec.h | 1 + 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ddcf0d41cfc68..4aaf08aa2ea6a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2267,9 +2267,9 @@ static void ggml_compute_forward_gelu( static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; - float c = *((float *) &(dst->op_params[1])); - bool keep_org_val = isnan(c); + const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = ggml_get_op_params_f32(dst, 1); + const bool keep_org_val = isnan(c); // TODO: Is ggml_is_contiguous_rows safe and sufficient? GGML_ASSERT(ggml_is_contiguous(src0)); @@ -2294,9 +2294,9 @@ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggm static void ggml_compute_forward_tri_f16(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; - const float c = *((float *) &(dst->op_params[1])); - bool keep_org_val = isnan(c); + const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = ggml_get_op_params_f32(dst, 1); + const bool keep_org_val = isnan(c); // TODO: Is ggml_is_contiguous_rows safe and sufficient? GGML_ASSERT(ggml_is_contiguous(src0)); @@ -2321,9 +2321,9 @@ static void ggml_compute_forward_tri_f16(const ggml_compute_params * params, ggm static void ggml_compute_forward_tri_bf16(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; - float c = *((float *) &(dst->op_params[1])); - bool keep_org_val = isnan(c); + const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = ggml_get_op_params_f32(dst, 1); + const bool keep_org_val = isnan(c); // TODO: Is ggml_is_contiguous_rows safe and sufficient? GGML_ASSERT(ggml_is_contiguous(src0)); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 27681bee9e7c6..94031a1b01008 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1459,6 +1459,7 @@ inline static bool _ggml_vec_tri_cmp(const int i, const int r, const enum ggml_t case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; break; case GGML_TRI_TYPE_UPPER: return i > r; break; case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; break; + default: GGML_ABORT("Invalid tri type"); } } From dfae90926107a155be4d41ae347e6c979c611c7d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 16 Oct 2025 10:44:46 -0600 Subject: [PATCH 24/64] feat(ggml-cuda): common implementation of prefix sum Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/common.cuh | 50 +++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 41ff89c4d6922..199d8f9debc07 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -406,6 +406,56 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // FP16_AVAILABLE } +template +static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const T t = __shfl_up_sync(mask, x, offset, width); + if (lane_id >= offset) { + x += t; + } + } + return x; +} + +template +static __device__ __forceinline__ float warp_prefix_inclusive_sum(float2 a) { + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const float t_x = __shfl_up_sync(mask, a.x, offset, width); + const float t_y = __shfl_up_sync(mask, a.y, offset, width); + if (lane_id >= offset) { + a.x += t_x; + a.y += t_y; + } + } + return a; +} + +template +static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { +#ifdef FP16_AVAILABLE + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const t = __hadd2(__shfl_up_sync(mask, a, offset, width)); + if (lane_id >= offset) { + a += t; + } + } + return a; + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE +} + template static __device__ __forceinline__ int warp_reduce_all(int x) { if (width == ggml_cuda_get_physical_warp_size()) { From d1f86582ed31649005c57ce886544571e2bb0f8e Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 16 Oct 2025 14:13:59 -0600 Subject: [PATCH 25/64] feat(ggml-cuda): CUDA implementation of CUMSUM Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/cumsum.cu | 126 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/cumsum.cuh | 5 ++ ggml/src/ggml-cuda/ggml-cuda.cu | 5 ++ 3 files changed, 136 insertions(+) create mode 100644 ggml/src/ggml-cuda/cumsum.cu create mode 100644 ggml/src/ggml-cuda/cumsum.cuh diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu new file mode 100644 index 0000000000000..e14be0721c699 --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -0,0 +1,126 @@ +#include "cumsum.cuh" + +// Kernel to compute cumulative sum along the innermost dimension (ne[0]) +// Each block processes one row (ne[0] elements) +// Algorithm matches Metal implementation: +// 1. Each warp computes prefix sum within itself +// 2. Last thread of each warp stores result in shared memory +// 3. All warps sync +// 4. Each element adds the sum of all preceding warps + +template +static __global__ void cumsum_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + + // Shared memory to store warp sums (always use float for accumulation) + extern __shared__ float shmem[]; + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); + T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + const int tid = threadIdx.x; + const int lane_id = tid % WARP_SIZE; + + // Phase 1: Each thread processes elements at stride blockDim.x + // Compute warp-level prefix sums + for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { + // Load value and compute prefix sum within warp + float val = static_cast(src_row[i0]); + val = warp_prefix_inclusive_sum(val); + dst_row[i0] = static_cast(val); + + // Last thread of warp stores its sum to shared memory at position based on data index + if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) { + const int shmem_idx = i0 / WARP_SIZE; + shmem[shmem_idx] = val; + } + } + + // Sync once after all warp prefix sums are computed + __syncthreads(); + + // Phase 2: Add the sum of all preceding warp groups to each element + for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { + const int shmem_idx = i0 / WARP_SIZE; + float sum = 0.0f; + for (int j = 0; j < shmem_idx; ++j) { + sum += shmem[j]; + } + dst_row[i0] = static_cast(static_cast(dst_row[i0]) + sum); + } +} + +template +static void cumsum_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + cudaStream_t stream) { + + dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + + // Shared memory size: one float per warp + const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; + const size_t shmem_size = num_warps * sizeof(float); + + cumsum_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + nb0, nb1, nb2, nb3 + ); +} + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == dst->type); + switch(src0->type) { + case GGML_TYPE_F32: + { + cumsum_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + case GGML_TYPE_F16: + { + cumsum_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + case GGML_TYPE_BF16: + { + cumsum_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh new file mode 100644 index 0000000000000..782d1d92e9bb1 --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CUMSUM_BLOCK_SIZE 256 + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75fd6db14c514..84b26a7bb2118 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -19,6 +19,7 @@ #include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" @@ -2512,6 +2513,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; case GGML_OP_RWKV_WKV6: ggml_cuda_op_rwkv_wkv6(ctx, dst); break; @@ -3650,6 +3654,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_CUMSUM: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; From 5071fbd5786558ab21127489069895a58c4bb838 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 16 Oct 2025 14:31:58 -0600 Subject: [PATCH 26/64] feat(ggml-cuda): CUDA implementation of TRI Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/ggml-cuda.cu | 5 ++ ggml/src/ggml-cuda/tri.cu | 109 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/tri.cuh | 5 ++ 3 files changed, 119 insertions(+) create mode 100644 ggml/src/ggml-cuda/tri.cu create mode 100644 ggml/src/ggml-cuda/tri.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 84b26a7bb2118..7e11108c9684f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -47,6 +47,7 @@ #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/topk-moe.cuh" +#include "ggml-cuda/tri.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" @@ -2516,6 +2517,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CUMSUM: ggml_cuda_op_cumsum(ctx, dst); break; + case GGML_OP_TRI: + ggml_cuda_op_tri(ctx, dst); + break; case GGML_OP_RWKV_WKV6: ggml_cuda_op_rwkv_wkv6(ctx, dst); break; @@ -3655,6 +3659,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_CUMSUM: + case GGML_OP_TRI: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu new file mode 100644 index 0000000000000..d9c4aa025dbaf --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cu @@ -0,0 +1,109 @@ +#include "tri.cuh" +#include "ggml.h" +#include + +// Triangle type comparison - determines which elements to keep +__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) { + switch (type) { + case GGML_TRI_TYPE_LOWER: return i < r; + case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; + case GGML_TRI_TYPE_UPPER: return i > r; + case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; + default: return false; + } +} + +template +static __global__ void tri_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const float c, const ggml_tri_type ttype) { + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); + T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + const bool keep_org_val = isnan(c); + + // Each thread processes elements at stride blockDim.x + for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = tri_compare(i0, i1, ttype) + ? (keep_org_val ? src_row[i0] : static_cast(c)) + : static_cast(0.f); + } +} + +template +static void tri_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const float c, const ggml_tri_type ttype, + cudaStream_t stream) { + + dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + nb0, nb1, nb2, nb3, + c, ttype + ); +} + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + const ggml_tri_type ttype = static_cast(ggml_get_op_params_i32(dst, 0)); + const float c = ggml_get_op_params_f32(dst, 1); + + GGML_ASSERT(src0->type == dst->type); + + switch(src0->type) { + case GGML_TYPE_F32: + { + tri_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, stream + ); + } break; + case GGML_TYPE_F16: + { + tri_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, stream + ); + } break; + case GGML_TYPE_BF16: + { + tri_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/tri.cuh b/ggml/src/ggml-cuda/tri.cuh new file mode 100644 index 0000000000000..a4cc66750d3b5 --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_TRI_BLOCK_SIZE 256 + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From be23a291d91a434b87c881a8945c826db75834d9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 25 Sep 2025 16:57:31 -0600 Subject: [PATCH 27/64] test: Add test-backend-ops perf tests for ssm conv and scan Branch: Mamba2Perf Signed-off-by: Gabe Goodhart --- tests/test-backend-ops.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f272a39e57a0d..0cdadf57919a7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1430,7 +1430,7 @@ struct test_case { ggml_tensor * out = build_graph(ctx.get()); std::string current_op_name = op_desc(out); if (!matches_filter(out, op_names_filter)) { - //printf(" %s: skipping\n", op_desc(out).c_str()); + // printf(" %s: skipping\n", op_desc(out).c_str()); return true; } @@ -7221,6 +7221,18 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it)); } + for (int64_t d_conv : {3, 4}) { + for (int64_t d_inner: {1024, 1536, 2048}) { + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 4, 1}, {d_conv, d_inner, 1, 1})); + } + } + + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1 + return test_cases; } From 71e2289b7ecf8b564f4fb3b5d4cbe627af7abbfe Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 17 Oct 2025 15:41:47 -0600 Subject: [PATCH 28/64] feat(ggml-cpu): Rename ggml_softplus to ggml_op_softplus to make room for tensor op Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cpu/ops.cpp | 4 ++-- ggml/src/ggml-impl.h | 2 +- ggml/src/ggml.c | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 4aaf08aa2ea6a..29fd2f7b8ac4f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8861,7 +8861,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = ggml_softplus(dt[h]); + const float dt_soft_plus = ggml_op_softplus(dt[h]); const float dA = expf(dt_soft_plus * A[h]); const int g = h / (nh / ng); // repeat_interleave @@ -8958,7 +8958,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = ggml_softplus(dt[h]); + const float dt_soft_plus = ggml_op_softplus(dt[h]); const int g = h / (nh / ng); // repeat_interleave // dim diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index d0fb3bccad225..377ed2d215e7e 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } -static inline float ggml_softplus(float input) { +static inline float ggml_op_softplus(float input) { return (input > 20.0f) ? input : logf(1 + expf(input)); } // diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7e4bfb07154b3..326be7f5ed67a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2682,8 +2682,8 @@ struct ggml_tensor * ggml_xielu( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU); - ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n)); - ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p)); + ggml_set_op_params_f32(result, 1, beta + ggml_op_softplus(alpha_n)); + ggml_set_op_params_f32(result, 2, ggml_op_softplus(alpha_p)); ggml_set_op_params_f32(result, 3, beta); ggml_set_op_params_f32(result, 4, eps); From f6d60e39f373b4c28c0306b0deeff1fb127effc3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 17 Oct 2025 15:51:02 -0600 Subject: [PATCH 29/64] feat(ggml-cpu): Add ggml_softplus tensor op for CPU Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/include/ggml.h | 7 +++++++ ggml/src/ggml-cpu/ggml-cpu.c | 1 + ggml/src/ggml-cpu/ops.cpp | 4 ++++ ggml/src/ggml-cpu/unary-ops.cpp | 4 ++++ ggml/src/ggml-cpu/unary-ops.h | 1 + ggml/src/ggml.c | 11 ++++++++++- 6 files changed, 27 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index beb7ee988097a..173f6a0979734 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -579,6 +579,7 @@ extern "C" { GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_XIELU, + GGML_UNARY_OP_SOFTPLUS, GGML_UNARY_OP_COUNT, }; @@ -1164,6 +1165,12 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // softplus(x) = log(1 + exp(beta * x)) / beta, for x * beta <= threshold + // = x, otherwise + GGML_API struct ggml_tensor * ggml_softplus( + struct ggml_context * ctx, + struct ggml_tensor * a); + // xIELU activation function // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 9bd36ea8cacef..f044d7d3b443f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2203,6 +2203,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_XIELU: + case GGML_UNARY_OP_SOFTPLUS: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 29fd2f7b8ac4f..b63e3a64e0b34 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9225,6 +9225,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_xielu(params, dst); } break; + case GGML_UNARY_OP_SOFTPLUS: + { + ggml_compute_forward_softplus(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index cf1a4615d042c..c27548f045671 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -287,3 +287,7 @@ void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor unary_op_functor(params, dst, xielu_op_params); } +void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index 697c1e0da0ace..69ffa13501d7e 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -23,6 +23,7 @@ void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 326be7f5ed67a..39b687509e7fc 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1148,9 +1148,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "EXP", "GELU_ERF", "XIELU", + "SOFTPLUS", }; -static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); +static_assert(GGML_UNARY_OP_COUNT == 17, "GGML_UNARY_OP_COUNT != 17"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2738,6 +2739,14 @@ struct ggml_tensor * ggml_exp_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); } +// ggml_softplus + +struct ggml_tensor * ggml_softplus( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS); +} + // ggml_glu static struct ggml_tensor * ggml_glu_impl( From 778e835aa59fab0b4363cd4677afb8ce6e54df27 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 17 Oct 2025 15:51:47 -0600 Subject: [PATCH 30/64] test: Better verbosity output for inputs in test-backend-ops Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- tests/test-backend-ops.cpp | 95 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0cdadf57919a7..451e0e44ce7e1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -240,7 +241,8 @@ static std::string ggml_ne_string(const ggml_tensor * t) { return str; } -static void ggml_print_tensor(ggml_tensor * t, int64_t n = 3) { +static void ggml_print_tensor(ggml_tensor * t, int64_t verbose = 0) { + int n = verbose >= 2 ? std::numeric_limits::max() : 3; GGML_ASSERT(t != nullptr); GGML_ASSERT(n > 0); @@ -1137,6 +1139,9 @@ struct test_case { virtual void initialize_tensors(ggml_context * ctx) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -1349,8 +1354,8 @@ struct test_case { } if (ud->verbose) { - ggml_print_tensor(t1, ud->verbose >= 2 ? 1e10 : 3); - ggml_print_tensor(t2, ud->verbose >= 2 ? 1e10 : 3); + ggml_print_tensor(t1, ud->verbose); + ggml_print_tensor(t2, ud->verbose); } std::vector f1 = tensor_to_float(t1); @@ -1930,6 +1935,9 @@ struct test_unary : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check for NaNs in GELU init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -1995,6 +2003,9 @@ struct test_glu : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check for NaNs in GELU init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2053,6 +2064,9 @@ struct test_glu_split : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check for NaNs in GELU init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2113,6 +2127,9 @@ struct test_swiglu_oai : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check for NaNs in GELU init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2170,6 +2187,9 @@ struct test_get_rows : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2223,6 +2243,9 @@ struct test_get_rows_back : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2304,6 +2327,9 @@ struct test_set_rows : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -2369,6 +2395,9 @@ struct test_argmax : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -2430,6 +2459,9 @@ struct test_count_equal : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2684,6 +2716,9 @@ struct test_cpy : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check if casting between f32 and i32 is consistent init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2782,6 +2817,9 @@ struct test_bin_bcast : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -2856,6 +2894,9 @@ struct test_add_id : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -3116,6 +3157,9 @@ struct test_rms_norm : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -10.f, 10.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3159,6 +3203,9 @@ struct test_rms_norm_back : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -10.f, 10.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -3215,6 +3262,9 @@ struct test_rms_norm_mul_add : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -10.f, 10.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3305,6 +3355,9 @@ struct test_ssm_scan : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -3598,6 +3651,9 @@ struct test_mul_mat_id : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3712,6 +3768,9 @@ struct test_sqrt : public test_case { // fill with positive values for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, 50.0f, 100.0f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3752,6 +3811,9 @@ struct test_log : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass: init_tensor_uniform(t, 0.9f, 1.1f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3787,6 +3849,9 @@ struct test_sin : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi]. + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3830,6 +3895,9 @@ struct test_cos : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi]. + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -4132,6 +4200,9 @@ struct test_rope : public test_case { init_tensor_uniform(t); } } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -4659,6 +4730,9 @@ struct test_argsort : public test_case { } else { GGML_ABORT("fatal error"); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -5356,6 +5430,9 @@ struct test_flash_attn_ext : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -5400,6 +5477,9 @@ struct test_cross_entropy_loss : public test_case { // For larger abs. diffs between logits softmax is more linear, therefore more precise num. gradients. for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -100.0f, 100.0f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -5485,6 +5565,9 @@ struct test_opt_step_adamw : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values. + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -5524,6 +5607,9 @@ struct test_opt_step_sgd : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, 0.0f, 1.0f); // sgd_params need non-negative values. + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -5666,6 +5752,9 @@ struct test_llm : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; From 4228002db15838b07b638155bd791bb2483c4216 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 17 Oct 2025 15:58:52 -0600 Subject: [PATCH 31/64] feat(ggml-metal): Add ggml_softplus support for metal Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal.metal | 16 ++++++++++++++++ 3 files changed, 18 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 41519ba4e51a2..8581d61ba00eb 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -211,6 +211,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; case GGML_UNARY_OP_EXP: op_str = "exp"; break; + case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index e93ad6534e2cc..9a5cd0f56d970 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -619,6 +619,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e4c9579981dba..43ecf8bdc066d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1398,6 +1398,22 @@ kernel void kernel_silu_f32_4( dst[tpig] = x / (1.0f + exp(-x)); } +kernel void kernel_softplus_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 20.0f) ? x : log(1.0f + exp(x)); +} + +kernel void kernel_softplus_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); +} + kernel void kernel_elu_f32( device const float * src0, device float * dst, From 97bd17dca1514d0f45987f4e55495429df0919d3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 17 Oct 2025 16:11:21 -0600 Subject: [PATCH 32/64] feat(ggml-cuda): Add support for ggml_softplus Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++++ ggml/src/ggml-cuda/unary.cu | 8 ++++++++ ggml/src/ggml-cuda/unary.cuh | 2 ++ 3 files changed, 14 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7e11108c9684f..c904bf5139432 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2345,6 +2345,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_cuda_op_elu(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cuda_op_softplus(ctx, dst); + break; case GGML_UNARY_OP_XIELU: ggml_cuda_op_xielu(ctx, dst); break; @@ -3365,6 +3368,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_SOFTPLUS: return ggml_is_contiguous(op->src[0]); default: return false; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 3c564566a51ff..6b602beffca9f 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -88,6 +88,10 @@ static __device__ __forceinline__ float op_elu(float x) { return (x > 0.f) ? x : expm1f(x); } +static __device__ __forceinline__ float op_softplus(float x) { + return (x > 20.0f) ? x : logf(1.0f + expf(x)); +} + template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -204,6 +208,10 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } + +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} /* gated ops */ template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 8e7644fcd9a48..b696b91e1382b 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -62,6 +62,8 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From ffd88ff8188bbc05ae2041c4d303010bca6d1823 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 20 Oct 2025 16:59:34 -0600 Subject: [PATCH 33/64] style: comments on ggml tri types Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/include/ggml.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 173f6a0979734..ae17244e78fe9 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -619,10 +619,10 @@ extern "C" { }; enum ggml_tri_type { - GGML_TRI_TYPE_UPPER_DIAG = 0, - GGML_TRI_TYPE_UPPER = 1, - GGML_TRI_TYPE_LOWER_DIAG = 2, - GGML_TRI_TYPE_LOWER = 3 + GGML_TRI_TYPE_UPPER_DIAG = 0, // upper including diag + GGML_TRI_TYPE_UPPER = 1, // upper excluding diag + GGML_TRI_TYPE_LOWER_DIAG = 2, // lower including diag + GGML_TRI_TYPE_LOWER = 3 // lower excluding diag }; struct ggml_init_params { From 7409d9e5c4eae4490dc0e7baab77e87e26e33c68 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 20 Oct 2025 17:00:16 -0600 Subject: [PATCH 34/64] WIP(llama-model): Partial work on graph-based SSD implementation The chunk slicing and matrix permutations _should_ be right... Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 89 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5002bd42ff04e..8f6ba6c71f89a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11852,6 +11852,95 @@ struct llm_graph_context_mamba : public llm_graph_context { auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + if (n_seq_tokens == 1) { + //DEBUG + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); + // If single-token, use ssm_scan op + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + } else { + //DEBUG + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); + + // otherwise, use the SSD formulation + + // TODO: make this configurable + const uint32_t chunk_size = 256; + + // step 1: compute dt softplus + // NOTE: In other implementations, the bias is added after + // the softplus. This shouldn't be a problem, but it's a + // difference. + ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} + + // step 2: compute dtA and dtX + ggml_tensor * dtA = ggml_mul(ctx, dt, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt, 1, dt->ne[0], dt->ne[1], dt->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + + // loop over all chunks + uint32_t repeats = n_head / n_group; + for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { + + // chunk views + const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i * chunk_size)); + // slice dtA on dim 1 + ggml_tensor * dtA_chunk = ggml_view_3d(ctx, dtA, + dtA->ne[0], chunk_size_i, dtA->ne[2], + dtA->nb[1], dtA->nb[2], + chunk_i * dtA->nb[1]); + // slice dtX on dim 2 + ggml_tensor * dtX_chunk = ggml_view_4d(ctx, dtX, + dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], + dtX->nb[1], dtX->nb[2], dtX->nb[3], + chunk_i * dtX->nb[2]); + // slice B on dim 2 + ggml_tensor * B_chunk = ggml_view_4d(ctx, B, + B->ne[0], B->ne[1], chunk_size_i, B->ne[3], + B->nb[1], B->nb[2], B->nb[3], + chunk_i * B->nb[2]); + // slice C on dim 2 + ggml_tensor * C_chunk = ggml_view_4d(ctx, C, + C->ne[0], C->ne[1], chunk_size_i, C->ne[3], + C->nb[1], C->nb[2], C->nb[3], + chunk_i * C->nb[2]); + + // step 3: compute CB + ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + ggml_tensor * CB = ggml_mul_mat(ctx, C_perm, B_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} + CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs} + + // step 4: compute decay + ggml_tensor * dtA_tmp0 = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, n_seq_tokens n_head, n_seqs} + ggml_tensor * dtA_tmp1 = ggml_repeat_4d(ctx, dtA_tmp0, + dtA_tmp0->ne[0] * chunk_size_i, dtA_tmp0->ne[1], dtA_tmp0->ne[2], dtA_tmp0->ne[3]); // {n_seq_tokens, n_seq_tokens n_head, n_seqs} + ggml_tensor * dtA_tmp2 = ggml_tri_keep(ctx, dtA_tmp1, GGML_TRI_TYPE_LOWER); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} + ggml_tensor * dtA_tmp3 = ggml_permute(ctx, dtA_tmp2, 1, 0, 2, 3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + decay = ggml_permute(ctx, segsum, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} + + // step 5: compute surrogate_attention_matrix + + // step 6: compute y + + // step 7: compute dtxdecay + + // step 8: compute next_state + + // update previous state if present + if (true) { + // step 9: compute exp_dtA_cumsum + + // step 10: compute y_prev + + // step 11: update y from y_prev + } + } + + //DEBUG + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + } + // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); From ba74006ef9a7937667cfb6845981add0410f77c9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 21 Oct 2025 16:41:16 -0600 Subject: [PATCH 35/64] TEMP: Increase the max graph nodes to handle all the nodes for SSD This is definitely an indication that the current implementation is bloated with view/permute operations! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-context.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e7526e7d0a557..381c154fc0924 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1362,7 +1362,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { - return std::max(1024u, 8u*model.n_tensors()); + return std::max(1024u, 16u*model.n_tensors()); //DEBUG!!!!!!!!! } llm_graph_result * llama_context::get_gf_res_reserve() const { From 29b30c6f9ac4aded90ad5837a4dc79c8c0b9e208 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 21 Oct 2025 16:44:29 -0600 Subject: [PATCH 36/64] WIP: Shape-correct impl of SSD w/out multi-chunk support This is broken in a few ways currently: 1. It only supports a single chunk w/out a previous state 2. The output is incorrect (the obvious one) 3. There are way too many graph nodes necessitating an increas in the max nodes heuristic Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8f6ba6c71f89a..8bc0f619c8748 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11910,22 +11910,46 @@ struct llm_graph_context_mamba : public llm_graph_context { CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs} // step 4: compute decay - ggml_tensor * dtA_tmp0 = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, n_seq_tokens n_head, n_seqs} + ggml_tensor * dtA_tmp0 = ggml_cont(ctx, ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0)); // {1, n_seq_tokens n_head, n_seqs} ggml_tensor * dtA_tmp1 = ggml_repeat_4d(ctx, dtA_tmp0, dtA_tmp0->ne[0] * chunk_size_i, dtA_tmp0->ne[1], dtA_tmp0->ne[2], dtA_tmp0->ne[3]); // {n_seq_tokens, n_seq_tokens n_head, n_seqs} ggml_tensor * dtA_tmp2 = ggml_tri_keep(ctx, dtA_tmp1, GGML_TRI_TYPE_LOWER); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} ggml_tensor * dtA_tmp3 = ggml_permute(ctx, dtA_tmp2, 1, 0, 2, 3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp3)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} decay = ggml_permute(ctx, segsum, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} // step 5: compute surrogate_attention_matrix + ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); + ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER); // step 6: compute y + ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) + // ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 2, 1, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) + ggml_tensor * y = ggml_mul_mat(ctx, surrogate_attention_matrix, dtX_chunk_perm); // step 7: compute dtxdecay + ggml_tensor * decay_last = ggml_view_4d(ctx, decay, + decay->ne[0], 1, decay->ne[2], decay->ne[3], + decay->nb[1], decay->nb[2], decay->nb[3], + (decay->ne[1] - 1) * decay->nb[1]); + decay_last = ggml_permute(ctx, decay_last, 2, 0, 1, 3); + B_perm = ggml_cont(ctx, B_perm); + B_perm = ggml_repeat_4d(ctx, B_perm, + B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); + ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); + dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); // step 8: compute next_state + ggml_tensor * next_state = ggml_mul_mat(ctx, dtxdecay, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3))); + next_state = ggml_permute(ctx, next_state, 1, 0, 2, 3); + + //DEBUG -- Single chunk w/out prev state + ggml_tensor * out = ggml_concat(ctx, + ggml_view_1d(ctx, y, ggml_nelements(y), 0), + ggml_view_1d(ctx, next_state, ggml_nelements(next_state), 0), + 0); + return out; // update previous state if present if (true) { From fb6896771d9ef5ec396e9726d98cc60356ebb896 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 23 Oct 2025 08:52:54 -0600 Subject: [PATCH 37/64] fix: Add names to tensors for better debugging and fix several wiring bugs Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8bc0f619c8748..7e02dac8c2a20 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11871,10 +11871,13 @@ struct llm_graph_context_mamba : public llm_graph_context { // the softplus. This shouldn't be a problem, but it's a // difference. ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} + ggml_set_name(dt_softplus, "dt_softplus"); // step 2: compute dtA and dtX - ggml_tensor * dtA = ggml_mul(ctx, dt, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} - ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt, 1, dt->ne[0], dt->ne[1], dt->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + ggml_set_name(dtA, "dtA"); + ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + ggml_set_name(dtX, "dtX"); // loop over all chunks uint32_t repeats = n_head / n_group; @@ -11887,27 +11890,32 @@ struct llm_graph_context_mamba : public llm_graph_context { dtA->ne[0], chunk_size_i, dtA->ne[2], dtA->nb[1], dtA->nb[2], chunk_i * dtA->nb[1]); + ggml_set_name(dtA_chunk, "dtA_chunk"); // slice dtX on dim 2 ggml_tensor * dtX_chunk = ggml_view_4d(ctx, dtX, dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], dtX->nb[1], dtX->nb[2], dtX->nb[3], chunk_i * dtX->nb[2]); + ggml_set_name(dtX_chunk, "dtX_chunk"); // slice B on dim 2 ggml_tensor * B_chunk = ggml_view_4d(ctx, B, B->ne[0], B->ne[1], chunk_size_i, B->ne[3], B->nb[1], B->nb[2], B->nb[3], chunk_i * B->nb[2]); + ggml_set_name(B_chunk, "B_chunk"); // slice C on dim 2 ggml_tensor * C_chunk = ggml_view_4d(ctx, C, C->ne[0], C->ne[1], chunk_size_i, C->ne[3], C->nb[1], C->nb[2], C->nb[3], chunk_i * C->nb[2]); + ggml_set_name(C_chunk, "C_chunk"); // step 3: compute CB ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} ggml_tensor * CB = ggml_mul_mat(ctx, C_perm, B_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs} + ggml_set_name(CB, "CB"); // step 4: compute decay ggml_tensor * dtA_tmp0 = ggml_cont(ctx, ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0)); // {1, n_seq_tokens n_head, n_seqs} @@ -11916,17 +11924,21 @@ struct llm_graph_context_mamba : public llm_graph_context { ggml_tensor * dtA_tmp2 = ggml_tri_keep(ctx, dtA_tmp1, GGML_TRI_TYPE_LOWER); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} ggml_tensor * dtA_tmp3 = ggml_permute(ctx, dtA_tmp2, 1, 0, 2, 3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp3)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + ggml_set_name(segsum, "segsum"); ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - decay = ggml_permute(ctx, segsum, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} + decay = ggml_permute(ctx, decay, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} + ggml_set_name(decay, "decay"); // step 5: compute surrogate_attention_matrix ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); - ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER); + ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); + ggml_set_name(surrogate_attention_matrix, "surrogate_attention_matrix"); // step 6: compute y ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) // ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 2, 1, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) ggml_tensor * y = ggml_mul_mat(ctx, surrogate_attention_matrix, dtX_chunk_perm); + ggml_set_name(y, "y"); // step 7: compute dtxdecay ggml_tensor * decay_last = ggml_view_4d(ctx, decay, @@ -11939,10 +11951,12 @@ struct llm_graph_context_mamba : public llm_graph_context { B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); + ggml_set_name(dtxdecay, "dtxdecay"); // step 8: compute next_state ggml_tensor * next_state = ggml_mul_mat(ctx, dtxdecay, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3))); next_state = ggml_permute(ctx, next_state, 1, 0, 2, 3); + ggml_set_name(next_state, "next_state"); //DEBUG -- Single chunk w/out prev state ggml_tensor * out = ggml_concat(ctx, From cd73f4d07ca6da6e0a26a469ba194d92b6589aa7 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 23 Oct 2025 16:29:52 -0600 Subject: [PATCH 38/64] fix(wip): Fix matmul order for CB and y Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 81 +++++++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 39 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7e02dac8c2a20..0c5461130551c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11838,24 +11838,30 @@ struct llm_graph_context_mamba : public llm_graph_context { { // These correspond to V K Q in SSM/attention duality ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + cb(x, "x", il); ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + cb(B, "B", il); ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + cb(C, "C", il); // {n_head, n_seq_tokens, n_seqs} + cb(dt, "dt", il); dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + cb(dt, "dt_b", il); ggml_tensor * A = model.layers[il].ssm_a; + cb(A, "A", il); // use the states and the indices provided by build_recurrent_state // (this is necessary in order to properly use the states before they are overwritten, // while avoiding to make unnecessary copies of the states) auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { - ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); if (n_seq_tokens == 1) { //DEBUG LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); // If single-token, use ssm_scan op + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); } else { //DEBUG @@ -11871,13 +11877,14 @@ struct llm_graph_context_mamba : public llm_graph_context { // the softplus. This shouldn't be a problem, but it's a // difference. ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} - ggml_set_name(dt_softplus, "dt_softplus"); + dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); + cb(dt_softplus, "dt_softplus", il); // step 2: compute dtA and dtX - ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} - ggml_set_name(dtA, "dtA"); - ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} - ggml_set_name(dtX, "dtX"); + /* !! */ ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + cb(dtA, "dtA", il); + /* !! */ ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + cb(dtX, "dtX", il); // loop over all chunks uint32_t repeats = n_head / n_group; @@ -11890,32 +11897,32 @@ struct llm_graph_context_mamba : public llm_graph_context { dtA->ne[0], chunk_size_i, dtA->ne[2], dtA->nb[1], dtA->nb[2], chunk_i * dtA->nb[1]); - ggml_set_name(dtA_chunk, "dtA_chunk"); + cb(dtA_chunk, "dtA_chunk", il); // slice dtX on dim 2 ggml_tensor * dtX_chunk = ggml_view_4d(ctx, dtX, dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], dtX->nb[1], dtX->nb[2], dtX->nb[3], chunk_i * dtX->nb[2]); - ggml_set_name(dtX_chunk, "dtX_chunk"); + cb(dtX_chunk, "dtX_chunk", il); // slice B on dim 2 ggml_tensor * B_chunk = ggml_view_4d(ctx, B, B->ne[0], B->ne[1], chunk_size_i, B->ne[3], B->nb[1], B->nb[2], B->nb[3], chunk_i * B->nb[2]); - ggml_set_name(B_chunk, "B_chunk"); + cb(B_chunk, "B_chunk", il); // slice C on dim 2 ggml_tensor * C_chunk = ggml_view_4d(ctx, C, C->ne[0], C->ne[1], chunk_size_i, C->ne[3], C->nb[1], C->nb[2], C->nb[3], chunk_i * C->nb[2]); - ggml_set_name(C_chunk, "C_chunk"); + cb(C_chunk, "C_chunk", il); // step 3: compute CB ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} - ggml_tensor * CB = ggml_mul_mat(ctx, C_perm, B_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} + /* !! */ ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs} - ggml_set_name(CB, "CB"); + cb(CB, "CB", il); // step 4: compute decay ggml_tensor * dtA_tmp0 = ggml_cont(ctx, ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0)); // {1, n_seq_tokens n_head, n_seqs} @@ -11923,22 +11930,22 @@ struct llm_graph_context_mamba : public llm_graph_context { dtA_tmp0->ne[0] * chunk_size_i, dtA_tmp0->ne[1], dtA_tmp0->ne[2], dtA_tmp0->ne[3]); // {n_seq_tokens, n_seq_tokens n_head, n_seqs} ggml_tensor * dtA_tmp2 = ggml_tri_keep(ctx, dtA_tmp1, GGML_TRI_TYPE_LOWER); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} ggml_tensor * dtA_tmp3 = ggml_permute(ctx, dtA_tmp2, 1, 0, 2, 3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp3)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - ggml_set_name(segsum, "segsum"); - ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + /* !! */ ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp3)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + cb(segsum, "segsum", il); + /* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} decay = ggml_permute(ctx, decay, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} - ggml_set_name(decay, "decay"); + cb(decay, "decay", il); // step 5: compute surrogate_attention_matrix - ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); + /* !! */ ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); - ggml_set_name(surrogate_attention_matrix, "surrogate_attention_matrix"); + cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); // step 6: compute y ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) - // ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 2, 1, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) - ggml_tensor * y = ggml_mul_mat(ctx, surrogate_attention_matrix, dtX_chunk_perm); - ggml_set_name(y, "y"); + /* !! */ ggml_tensor * y = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); + cb(y, "y", il); // step 7: compute dtxdecay ggml_tensor * decay_last = ggml_view_4d(ctx, decay, @@ -11949,14 +11956,14 @@ struct llm_graph_context_mamba : public llm_graph_context { B_perm = ggml_cont(ctx, B_perm); B_perm = ggml_repeat_4d(ctx, B_perm, B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); - ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); + /* !! */ ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); - ggml_set_name(dtxdecay, "dtxdecay"); + cb(dtxdecay, "dtxdecay", il); // step 8: compute next_state - ggml_tensor * next_state = ggml_mul_mat(ctx, dtxdecay, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3))); + /* !! */ ggml_tensor * next_state = ggml_mul_mat(ctx, dtxdecay, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3))); next_state = ggml_permute(ctx, next_state, 1, 0, 2, 3); - ggml_set_name(next_state, "next_state"); + cb(next_state, "next_state", il); //DEBUG -- Single chunk w/out prev state ggml_tensor * out = ggml_concat(ctx, @@ -11965,34 +11972,30 @@ struct llm_graph_context_mamba : public llm_graph_context { 0); return out; - // update previous state if present - if (true) { - // step 9: compute exp_dtA_cumsum + // // update previous state if present + // if (true) { + // // step 9: compute exp_dtA_cumsum - // step 10: compute y_prev + // // step 10: compute y_prev - // step 11: update y from y_prev - } + // // step 11: update y from y_prev + // } } - - //DEBUG - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); } - - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states + ggml_tensor * ssm_state = ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]); + cb(ssm_state, "ssm-state", il); ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ssm_state, ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + cb(y, "y-reshaped", il); // TODO: skip computing output earlier for unused tokens From 52be1ab6e903ca1123bf7202e4374cc9ce70d44d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 23 Oct 2025 16:36:20 -0600 Subject: [PATCH 39/64] fix: Working output!! One more backwards mulmat Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0c5461130551c..652d1a9a8a752 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11858,6 +11858,7 @@ struct llm_graph_context_mamba : public llm_graph_context { auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { if (n_seq_tokens == 1) { + // if (true) { //DEBUG LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); // If single-token, use ssm_scan op @@ -11961,8 +11962,7 @@ struct llm_graph_context_mamba : public llm_graph_context { cb(dtxdecay, "dtxdecay", il); // step 8: compute next_state - /* !! */ ggml_tensor * next_state = ggml_mul_mat(ctx, dtxdecay, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3))); - next_state = ggml_permute(ctx, next_state, 1, 0, 2, 3); + /* !! */ ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); cb(next_state, "next_state", il); //DEBUG -- Single chunk w/out prev state From f57dafe73a04ec0ff4ce692a9de360f27a8d6498 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 10:11:08 -0600 Subject: [PATCH 40/64] feat(eval-callback): Use -vb to set tensor print width and number of elements Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- examples/eval-callback/eval-callback.cpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index cefa39a57c886..da26dfbb316c1 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -6,9 +6,17 @@ #include #include +#include #include #include +// verbosity flag set via the params.verbosity CLI flag. This is used for two +// things: +// 1. If > 0, tensors are printed with 8 digits of precision instead of 5 +// 2. If > 1, all tensor values are printed instead of the pretty-printed +// partial output +static int verbosity = 0; + /** * This the arbitrary data which will be passed to each callback. * Later on we can for example add operation or tensor name filter from the CLI arg, or a file descriptor to dump the tensor. @@ -61,6 +69,10 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * } static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + std::stringstream ss; + const int float_digits = verbosity > 0 ? 8 : 4; + ss << "%12." << float_digits << "f"; + const auto float_fmt = ss.str(); GGML_ASSERT(n > 0); float sum = 0; for (int64_t i3 = 0; i3 < ne[3]; i3++) { @@ -93,7 +105,7 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne i0 = ne[0] - n; } const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); - LOG("%12.4f", v); + LOG(float_fmt.c_str(), v); if (i0 < ne[0] - 1) LOG(", "); } LOG("],\n"); @@ -153,8 +165,9 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { } if (!ggml_is_quantized(t->type)) { + const int print_width = verbosity > 1 ? INT_MAX : 3; uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); - ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + ggml_print_tensor(data, t->type, t->ne, t->nb, print_width); } return true; @@ -192,6 +205,9 @@ int main(int argc, char ** argv) { common_init(); + // set verbosity for printing + verbosity = params.verbosity; + llama_backend_init(); llama_numa_init(params.numa); From 8a8706311d8c4aa9a7a43b9c66b78c6b18100ff1 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 13:15:03 -0600 Subject: [PATCH 41/64] feat(ggml-cpu): Add ggml_tri_dims to support non-standard dims (with tests) This will help avoid ggml_permute and ggml_cont requirements in the SSD impl Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/include/ggml.h | 10 +++ ggml/src/ggml-cpu/ops.cpp | 125 ++++++++++++++++++++++++------------- ggml/src/ggml.c | 21 ++++++- tests/test-backend-ops.cpp | 28 +++++++-- 4 files changed, 136 insertions(+), 48 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index ae17244e78fe9..600917700e761 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2162,6 +2162,16 @@ extern "C" { int shift3); // Make matrix into a triangular one (upper, upper + diagonal, lower or lower + diagonal) with constant value + // dim_x and dim_y specify which two dimensions to compare for triangular masking. They must have equal size. + // Default is dim_x=0, dim_y=1 (compares indices in dim 0 vs indices in dim 1) + GGML_API struct ggml_tensor * ggml_tri_dims( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype, + int dim_x, + int dim_y); + GGML_API struct ggml_tensor * ggml_tri( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b63e3a64e0b34..3927b72cc3303 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2264,46 +2264,73 @@ static void ggml_compute_forward_gelu( // ggml_compute_tri -static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; +// General implementation for arbitrary dimensions +template +static void ggml_compute_forward_tri_general( + const ggml_compute_params * params, + ggml_tensor * dst, + ggml_tri_type ttype, + T c_val, + bool keep_org_val, + int dim_x, + int dim_y) { - const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; - const float c = ggml_get_op_params_f32(dst, 1); - const bool keep_org_val = isnan(c); + const ggml_tensor * src0 = dst->src[0]; - // TODO: Is ggml_is_contiguous_rows safe and sufficient? - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->ne[0] == src0->ne[1]); + GGML_ASSERT(dim_x >= 0 && dim_x < GGML_MAX_DIMS); + GGML_ASSERT(dim_y >= 0 && dim_y < GGML_MAX_DIMS); + GGML_ASSERT(dim_x != dim_y); + GGML_ASSERT(src0->ne[dim_x] == src0->ne[dim_y]); GGML_TENSOR_UNARY_OP_LOCALS - const auto [ir0, ir1] = get_thread_range(params, src0); for (int64_t ir = ir0; ir < ir1; ++ir) { const int64_t i03 = ir/(ne02*ne01); const int64_t i02 = (ir - i03*ne02*ne01)/ne01; const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - float * dst_ptr = (float *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src = (float *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); - ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c, ttype); + for (int64_t i00 = 0; i00 < ne0; ++i00) { + const T * src = (const T *)((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + T * dst_ptr = ( T *)(( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 + i00*nb00); + int64_t i_vals[4] = {i00, i01, i02, i03}; + int64_t iX = i_vals[dim_x]; + int64_t iY = i_vals[dim_y]; + dst_ptr[0] = _ggml_vec_tri_cmp(iX, iY, ttype) ? + (keep_org_val ? src[0] : c_val) : + type_conversion_table::from_f32(0.f); + } } - } -static void ggml_compute_forward_tri_f16(const ggml_compute_params * params, ggml_tensor * dst) { +static void ggml_compute_forward_tri_f32( + const ggml_compute_params * params, + ggml_tensor * dst, + ggml_tri_type ttype, + float c_val, + bool keep_org_val) { const ggml_tensor * src0 = dst->src[0]; + GGML_TENSOR_UNARY_OP_LOCALS + const auto [ir0, ir1] = get_thread_range(params, src0); - const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; - const float c = ggml_get_op_params_f32(dst, 1); - const bool keep_org_val = isnan(c); + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - // TODO: Is ggml_is_contiguous_rows safe and sufficient? - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->ne[0] == src0->ne[1]); + float * dst_ptr = (float *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const float * src = (const float *)((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c_val, ttype); + } +} +static void ggml_compute_forward_tri_f16( + const ggml_compute_params * params, + ggml_tensor * dst, + ggml_tri_type ttype, + ggml_fp16_t c_val, + bool keep_org_val) { + const ggml_tensor * src0 = dst->src[0]; GGML_TENSOR_UNARY_OP_LOCALS - const auto [ir0, ir1] = get_thread_range(params, src0); for (int64_t ir = ir0; ir < ir1; ++ir) { @@ -2312,25 +2339,19 @@ static void ggml_compute_forward_tri_f16(const ggml_compute_params * params, ggm const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); ggml_fp16_t * dst_ptr = (ggml_fp16_t *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - ggml_fp16_t * src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); - ggml_vec_tri_f16(ne0, i01, dst_ptr, src, keep_org_val, GGML_FP32_TO_FP16(c), ttype); + const ggml_fp16_t * src = (const ggml_fp16_t *)((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_f16(ne0, i01, dst_ptr, src, keep_org_val, c_val, ttype); } - } -static void ggml_compute_forward_tri_bf16(const ggml_compute_params * params, ggml_tensor * dst) { +static void ggml_compute_forward_tri_bf16( + const ggml_compute_params * params, + ggml_tensor * dst, + ggml_tri_type ttype, + ggml_bf16_t c_val, + bool keep_org_val) { const ggml_tensor * src0 = dst->src[0]; - - const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; - const float c = ggml_get_op_params_f32(dst, 1); - const bool keep_org_val = isnan(c); - - // TODO: Is ggml_is_contiguous_rows safe and sufficient? - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->ne[0] == src0->ne[1]); - GGML_TENSOR_UNARY_OP_LOCALS - const auto [ir0, ir1] = get_thread_range(params, src0); for (int64_t ir = ir0; ir < ir1; ++ir) { @@ -2339,27 +2360,47 @@ static void ggml_compute_forward_tri_bf16(const ggml_compute_params * params, gg const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); ggml_bf16_t * dst_ptr = (ggml_bf16_t *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - ggml_bf16_t * src = (ggml_bf16_t *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); - ggml_vec_tri_bf16(ne0, i01, dst_ptr, src, keep_org_val, GGML_FP32_TO_BF16(c), ttype); + const ggml_bf16_t * src = (const ggml_bf16_t *)((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_bf16(ne0, i01, dst_ptr, src, keep_org_val, c_val, ttype); } - } void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = ggml_get_op_params_f32(dst, 1); + const int dim_x = ggml_get_op_params_i32(dst, 2); + const int dim_y = ggml_get_op_params_i32(dst, 3); + const bool use_general = dim_x != 0 || dim_y != 1 || !ggml_is_contiguous(src0); + const bool keep_org_val = isnan(c); + switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_tri_f32(params, dst); + if (use_general) { + ggml_compute_forward_tri_general(params, dst, ttype, c, keep_org_val, dim_x, dim_y); + } else { + ggml_compute_forward_tri_f32(params, dst, ttype, c, keep_org_val); + } } break; case GGML_TYPE_F16: { - ggml_compute_forward_tri_f16(params, dst); + ggml_fp16_t c_val = GGML_FP32_TO_FP16(c); + if (use_general) { + ggml_compute_forward_tri_general(params, dst, ttype, c_val, keep_org_val, dim_x, dim_y); + } else { + ggml_compute_forward_tri_f16(params, dst, ttype, c_val, keep_org_val); + } } break; case GGML_TYPE_BF16: { - ggml_compute_forward_tri_bf16(params, dst); + ggml_bf16_t c_val = GGML_FP32_TO_BF16(c); + if (use_general) { + ggml_compute_forward_tri_general(params, dst, ttype, c_val, keep_org_val, dim_x, dim_y); + } else { + ggml_compute_forward_tri_bf16(params, dst, ttype, c_val, keep_org_val); + } } break; default: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 39b687509e7fc..5ef728ea9751e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4997,16 +4997,25 @@ struct ggml_tensor * ggml_timestep_embedding( // ggml_tri -struct ggml_tensor * ggml_tri( +struct ggml_tensor * ggml_tri_dims( struct ggml_context * ctx, struct ggml_tensor * a, float constant, - enum ggml_tri_type tritype) { + enum ggml_tri_type tritype, + int dim_x, + int dim_y) { + + GGML_ASSERT(dim_x >= 0 && dim_x < GGML_MAX_DIMS); + GGML_ASSERT(dim_y >= 0 && dim_y < GGML_MAX_DIMS); + GGML_ASSERT(dim_x != dim_y); + GGML_ASSERT(a->ne[dim_x] == a->ne[dim_y]); struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, tritype); ggml_set_op_params_f32(result, 1, constant); + ggml_set_op_params_i32(result, 2, dim_x); + ggml_set_op_params_i32(result, 3, dim_y); result->op = GGML_OP_TRI; result->src[0] = a; @@ -5014,6 +5023,14 @@ struct ggml_tensor * ggml_tri( return result; } +struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype) { + return ggml_tri_dims(ctx, a, constant, tritype, 0, 1); +} + struct ggml_tensor * ggml_tri_keep( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 451e0e44ce7e1..8182d2a8b169c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4897,23 +4897,35 @@ struct test_tri : public test_case { const std::array ne; const ggml_tri_type tri_type; const float c; + const int64_t dim_x; + const int64_t dim_y; + const std::array permute; std::string vars() override { - return VARS_TO_STR4(type, ne, tri_type, c); + return VARS_TO_STR7(type, ne, tri_type, c, dim_x, dim_y, permute); } test_tri(ggml_tri_type tri_type, ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 1, 1}, - float c = nan("")) - : type(type), ne(ne), tri_type(tri_type), c(c) {} + float c = nan(""), + int64_t dim_x = 0, + int64_t dim_y = 1, + // don't permute by default + std::array permute = {-1, -1, -1, -1}) + : type(type), ne(ne), tri_type(tri_type), c(c), + dim_x(dim_x), dim_y(dim_y), permute(permute) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); ggml_set_name(a, "a"); - ggml_tensor * out = ggml_tri(ctx, a, c, tri_type); + if (permute[0] != -1) { + a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]); + } + + ggml_tensor * out = ggml_tri_dims(ctx, a, c, tri_type, dim_x, dim_y); ggml_set_name(out, "out"); return out; @@ -7056,6 +7068,10 @@ static std::vector> make_test_cases_eval(int verbose test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1})); + // non-contiguous + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {1, 8, 8, 1}, nan(""), 0, 1, {2, 0, 1, 3})); + // alternate dims + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {1, 8, 8, 1}, nan(""), 1, 2)); for (bool v : {false, true}) { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v)); @@ -7229,6 +7245,10 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1})); + // non-contiguous + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {1, 8, 8, 1}, nan(""), 0, 1, {2, 0, 1, 3})); + // alternate dims + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {1, 8, 8, 1}, nan(""), 1, 2)); for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { From 79bce3e34986aa4d7d64987e8498354ff4e1e01e Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 13:23:33 -0600 Subject: [PATCH 42/64] feat(ggml-metal): Extend metal tri imple for arbitrary dims and non-contiguous Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-impl.h | 2 ++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 6 +++++- ggml/src/ggml-metal/ggml-metal.metal | 19 +++++++++++++------ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 8cb22668857b4..7e023b47aa242 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -606,6 +606,8 @@ typedef struct { uint64_t nb3; float c; uint32_t ttype; + int32_t dim_x; + int32_t dim_y; } ggml_metal_kargs_tri; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 03520b7c29f77..f696e3ecb6829 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1028,6 +1028,8 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0]; const float c = *((float *) &(op->op_params[1])); + const int32_t dim_x = (int32_t) op->op_params[2]; + const int32_t dim_y = (int32_t) op->op_params[3]; ggml_metal_kargs_tri args = { /*.ne00 =*/ ne00, @@ -1047,7 +1049,9 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, /*.nb3 =*/ nb3, /*.c =*/ c, - /*.ttype =*/ static_cast(ttype) + /*.ttype =*/ static_cast(ttype), + /*.dim_x =*/ dim_x, + /*.dim_y =*/ dim_y }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_tri(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 43ecf8bdc066d..f4d1a62b4f89d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1931,17 +1931,24 @@ kernel void kernel_tri( return; } - device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); - device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + const bool keep_org_val = isnan(args.c); + const T c_val = static_cast(args.c); + const T zero_val = static_cast(0.f); // Each thread is a single element of the row if ne00 < max threads per // threadgroup, so this will loop once for each index that this thread is // responsible for - const bool keep_org_val = isnan(args.c); for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { - dst_row[i0] = _ggml_vec_tri_cmp(i0, i1, args.ttype) - ? (keep_org_val ? src_row[i0] : static_cast(args.c)) - : static_cast(0.f); + int64_t i_vals[4] = {i0, i1, i2, i3}; + int64_t iX = i_vals[args.dim_x]; + int64_t iY = i_vals[args.dim_y]; + + device const T * src_ptr = (device const T *) ((device const char *) src0 + i0*args.nb00 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_ptr = (device T *) ((device char *) dst + i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + dst_ptr[0] = _ggml_vec_tri_cmp(iX, iY, args.ttype) + ? (keep_org_val ? src_ptr[0] : c_val) + : zero_val; } } From 1ceb15ea61519153985f48b7f86026cb2c309da0 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 13:29:13 -0600 Subject: [PATCH 43/64] feat(ggml-cuda): Extend CUDA impl of tri to support arbitrary dims and non-cont Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/tri.cu | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index d9c4aa025dbaf..1940e368ac451 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -19,7 +19,7 @@ static __global__ void tri_kernel( const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const float c, const ggml_tri_type ttype) { + const float c, const ggml_tri_type ttype, const int dim_x, const int dim_y) { const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; @@ -29,16 +29,23 @@ static __global__ void tri_kernel( return; } - const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); - T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - const bool keep_org_val = isnan(c); + const T c_val = static_cast(c); + const T zero_val = static_cast(0.f); // Each thread processes elements at stride blockDim.x for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { - dst_row[i0] = tri_compare(i0, i1, ttype) - ? (keep_org_val ? src_row[i0] : static_cast(c)) - : static_cast(0.f); + // Create index array matching CPU implementation + int64_t i_vals[4] = {i0, i1, i2, i3}; + int64_t iX = i_vals[dim_x]; + int64_t iY = i_vals[dim_y]; + + const T * src_ptr = (const T *) ((const char *) src + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03); + T * dst_ptr = (T *) (( char *) dst + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + dst_ptr[0] = tri_compare(iX, iY, ttype) + ? (keep_org_val ? src_ptr[0] : c_val) + : zero_val; } } @@ -48,7 +55,7 @@ static void tri_cuda( const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const float c, const ggml_tri_type ttype, + const float c, const ggml_tri_type ttype, const int dim_x, const int dim_y, cudaStream_t stream) { dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); @@ -59,7 +66,7 @@ static void tri_cuda( ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, nb0, nb1, nb2, nb3, - c, ttype + c, ttype, dim_x, dim_y ); } @@ -69,6 +76,8 @@ void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tri_type ttype = static_cast(ggml_get_op_params_i32(dst, 0)); const float c = ggml_get_op_params_f32(dst, 1); + const int dim_x = ggml_get_op_params_i32(dst, 2); + const int dim_y = ggml_get_op_params_i32(dst, 3); GGML_ASSERT(src0->type == dst->type); @@ -80,7 +89,7 @@ void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], - c, ttype, stream + c, ttype, dim_x, dim_y, stream ); } break; case GGML_TYPE_F16: @@ -90,7 +99,7 @@ void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], - c, ttype, stream + c, ttype, dim_x, dim_y, stream ); } break; case GGML_TYPE_BF16: @@ -100,7 +109,7 @@ void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], - c, ttype, stream + c, ttype, dim_x, dim_y, stream ); } break; default: From ef12069eeb7ba6bc979ac516c62b8dd7aa8bfca8 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 13:33:09 -0600 Subject: [PATCH 44/64] fix: Fix INT_MAX to use numeric_limits for better compiler compat Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- examples/eval-callback/eval-callback.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index da26dfbb316c1..5c58110384aa6 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -9,6 +9,7 @@ #include #include #include +#include // verbosity flag set via the params.verbosity CLI flag. This is used for two // things: @@ -165,7 +166,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { } if (!ggml_is_quantized(t->type)) { - const int print_width = verbosity > 1 ? INT_MAX : 3; + const int print_width = verbosity > 1 ? std::numeric_limits::max() : 3; uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); ggml_print_tensor(data, t->type, t->ne, t->nb, print_width); } From 3da5c97bb18c77397bd99dc3c76e6f588a1e4553 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 13:52:08 -0600 Subject: [PATCH 45/64] fix(temp): Fix CBdecay to make decay contiguous for metal We shouldn't need this once cumsum can operate on other dims and we can avoid all the various permutes elsewhere. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 652d1a9a8a752..ca475e0da9646 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11938,7 +11938,7 @@ struct llm_graph_context_mamba : public llm_graph_context { cb(decay, "decay", il); // step 5: compute surrogate_attention_matrix - /* !! */ ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); + ggml_tensor * CBdecay = ggml_mul(ctx, CB, ggml_cont(ctx, decay)); ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); From 3336f3c45bc83b1540bb2d6dd82b3179022f5922 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 13:53:24 -0600 Subject: [PATCH 46/64] fix: Use ggml_tri_dims to avoid perm/cont for initial decay step Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ca475e0da9646..7fa58b79bb629 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11926,12 +11926,11 @@ struct llm_graph_context_mamba : public llm_graph_context { cb(CB, "CB", il); // step 4: compute decay - ggml_tensor * dtA_tmp0 = ggml_cont(ctx, ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0)); // {1, n_seq_tokens n_head, n_seqs} - ggml_tensor * dtA_tmp1 = ggml_repeat_4d(ctx, dtA_tmp0, - dtA_tmp0->ne[0] * chunk_size_i, dtA_tmp0->ne[1], dtA_tmp0->ne[2], dtA_tmp0->ne[3]); // {n_seq_tokens, n_seq_tokens n_head, n_seqs} - ggml_tensor * dtA_tmp2 = ggml_tri_keep(ctx, dtA_tmp1, GGML_TRI_TYPE_LOWER); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} - ggml_tensor * dtA_tmp3 = ggml_permute(ctx, dtA_tmp2, 1, 0, 2, 3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - /* !! */ ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp3)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, + dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i); + ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); + ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} cb(segsum, "segsum", il); /* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} decay = ggml_permute(ctx, decay, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} From d1e15c024a981bbe3c3ad4eeea34328731f4fbe3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 14:24:24 -0600 Subject: [PATCH 47/64] feat(ggml-cpu): Add dim arg to ggml_cumsum With tests Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/include/ggml.h | 7 ++++ ggml/src/ggml-cpu/ops.cpp | 71 ++++++++++++++++++++++++++++++++++---- ggml/src/ggml.c | 13 ++++++- tests/test-backend-ops.cpp | 22 ++++++++++-- 4 files changed, 103 insertions(+), 10 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 600917700e761..3f9ba10057838 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -988,7 +988,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // Cumulative sum along the specified dimension GGML_API struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a, + int dim); + + // Convenience function: cumulative sum along dimension 0 + GGML_API struct ggml_tensor * ggml_cumsum_0( struct ggml_context * ctx, struct ggml_tensor * a); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3927b72cc3303..4e020b41afa4f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -1397,6 +1397,50 @@ void ggml_compute_forward_sum( // ggml_compute_forward_cumsum +// General implementation for arbitrary dimensions +template +static void ggml_compute_forward_cumsum_general( + const ggml_compute_params * params, + ggml_tensor * dst, + int dim) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + + GGML_TENSOR_UNARY_OP_LOCALS + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + for (int64_t i0 = 0; i0 < ne00; i0++) { + const T * src_ptr = (const T *)((const char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + T * dst_ptr = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + // Determine position in the cumsum dimension + int64_t i_vals[4] = {i0, i1, i2, i3}; + int64_t i_dim = i_vals[dim]; + + if (i_dim == 0) { + // First element: just copy + dst_ptr[0] = src_ptr[0]; + } else { + // Accumulate: add current value to previous cumsum value + const T * prev_dst_ptr = (const T *)((const char *) dst_ptr - dst->nb[dim]); + dst_ptr[0] = type_conversion_table::from_f32( + type_conversion_table::to_f32(prev_dst_ptr[0]) + + type_conversion_table::to_f32(src_ptr[0])); + } + } + } + } + } +} + static void ggml_compute_forward_cumsum_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -1420,7 +1464,7 @@ static void ggml_compute_forward_cumsum_f32( for (int64_t i3 = 0; i3 < ne03; i3++) { for (int64_t i2 = 0; i2 < ne02; i2++) { for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + const float * src_row = (const float *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); ggml_vec_cumsum_f32(ne00, dst_row, src_row); } @@ -1451,7 +1495,7 @@ static void ggml_compute_forward_cumsum_f16( for (int64_t i3 = 0; i3 < ne03; i3++) { for (int64_t i2 = 0; i2 < ne02; i2++) { for (int64_t i1 = 0; i1 < ne01; i1++) { - ggml_fp16_t * src_row = (ggml_fp16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + const ggml_fp16_t * src_row = (const ggml_fp16_t *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); ggml_fp16_t * dst_row = (ggml_fp16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); ggml_vec_cumsum_f16(ne00, dst_row, src_row); } @@ -1482,7 +1526,7 @@ static void ggml_compute_forward_cumsum_bf16( for (int64_t i3 = 0; i3 < ne03; i3++) { for (int64_t i2 = 0; i2 < ne02; i2++) { for (int64_t i1 = 0; i1 < ne01; i1++) { - ggml_bf16_t * src_row = (ggml_bf16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + const ggml_bf16_t * src_row = (const ggml_bf16_t *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); ggml_bf16_t * dst_row = (ggml_bf16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); ggml_vec_cumsum_bf16(ne00, dst_row, src_row); } @@ -1496,18 +1540,33 @@ void ggml_compute_forward_cumsum( const ggml_tensor * src0 = dst->src[0]; + const int dim = ggml_get_op_params_i32(dst, 0); + const bool use_general = dim != 0 || !ggml_is_contiguous_rows(src0); + switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_cumsum_f32(params, dst); + if (use_general) { + ggml_compute_forward_cumsum_general(params, dst, dim); + } else { + ggml_compute_forward_cumsum_f32(params, dst); + } } break; case GGML_TYPE_F16: { - ggml_compute_forward_cumsum_f16(params, dst); + if (use_general) { + ggml_compute_forward_cumsum_general(params, dst, dim); + } else { + ggml_compute_forward_cumsum_f16(params, dst); + } } break; case GGML_TYPE_BF16: { - ggml_compute_forward_cumsum_bf16(params, dst); + if (use_general) { + ggml_compute_forward_cumsum_general(params, dst, dim); + } else { + ggml_compute_forward_cumsum_bf16(params, dst); + } } break; default: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5ef728ea9751e..145d993fdfd72 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2346,16 +2346,27 @@ struct ggml_tensor * ggml_sum_rows( struct ggml_tensor * ggml_cumsum( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a, + int dim) { + + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne); + ggml_set_op_params_i32(result, 0, dim); + result->op = GGML_OP_CUMSUM; result->src[0] = a; return result; } +struct ggml_tensor * ggml_cumsum_0( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cumsum(ctx, a, 0); +} + // ggml_mean struct ggml_tensor * ggml_mean( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8182d2a8b169c..922dc64e48798 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4861,14 +4861,18 @@ struct test_sum_rows : public test_case { struct test_cumsum : public test_case { const ggml_type type; const std::array ne; + const int64_t dim; + const std::array permute; std::string vars() override { return VARS_TO_STR2(type, ne); } test_cumsum(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 5, 4, 3}) - : type(type), ne(ne) {} + std::array ne = {10, 5, 4, 3}, + int64_t dim = 0, + std::array permute = {-1, -1, -1, -1}) + : type(type), ne(ne), dim(dim), permute(permute) {} double max_nmse_err() override { @@ -4884,7 +4888,11 @@ struct test_cumsum : public test_case { ggml_set_param(a); ggml_set_name(a, "a"); - ggml_tensor * out = ggml_cumsum(ctx, a); + if (permute[0] != -1) { + a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]); + } + + ggml_tensor * out = ggml_cumsum(ctx, a, dim); ggml_set_name(out, "out"); return out; @@ -7056,6 +7064,10 @@ static std::vector> make_test_cases_eval(int verbose test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); + // non-contiguous + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 0, {1, 0, 2, 3})); + // alternate dim + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 1)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); @@ -7233,6 +7245,10 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); + // non-contiguous + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 0, {1, 0, 2, 3})); + // alternate dim + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 1)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); From ee13af13b20542efe5bba10e75732b043a1c9481 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 14:57:17 -0600 Subject: [PATCH 48/64] feat(ggml-metal): Support arbitrary dim and non-cont in cumsum Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 +- ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 21 ++++++- ggml/src/ggml-metal/ggml-metal.metal | 67 ++++++++++++++++------- tests/test-backend-ops.cpp | 2 +- 6 files changed, 68 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 8581d61ba00eb..10da4d926e17a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -320,8 +320,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar } ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_t lib, const ggml_tensor * op) { - GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); - char base[256]; char name[256]; @@ -338,7 +336,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_ } // one shared memory element for each simd group in the threadgroup - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); const int nsg = (ne00 + 31)/32; ggml_metal_pipeline_set_smem(res, nsg*sizeof(float)); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 9a5cd0f56d970..70538fef85fd1 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -665,6 +665,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_TRI: return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_CUMSUM: + return has_simdgroup_reduction; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_SUM_ROWS: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7e023b47aa242..1f17f38a7124f 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -585,6 +585,7 @@ typedef struct { uint64_t nb1; uint64_t nb2; uint64_t nb3; + int32_t dim; } ggml_metal_kargs_cumsum; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index f696e3ecb6829..6aeb8bae2087e 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -971,6 +971,8 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + const int32_t dim = (int32_t) op->op_params[0]; + ggml_metal_kargs_cumsum args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -988,18 +990,31 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { /*.nb1 =*/ nb1, /*.nb2 =*/ nb2, /*.nb3 =*/ nb3, + /*.dim =*/ dim }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cumsum(lib, op); + // Dimension being accumulated + const int64_t ne_dim = op->src[0]->ne[dim]; + + // Grid dimensions: the GGML_MAX_DIMS-1 non-cumsum dimensions + int64_t grid_dims[GGML_MAX_DIMS - 1]; + int grid_idx = 0; + for (int d = 0; d < GGML_MAX_DIMS; ++d) { + if (d != dim) { + grid_dims[grid_idx++] = op->src[0]->ne[d]; + } + } + int nth = 32; // SIMD width - while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (nth < ne_dim && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00); + nth = std::min(nth, (int)ne_dim); const size_t smem = ggml_metal_pipeline_get_smem(pipeline); @@ -1010,7 +1025,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, grid_dims[0], grid_dims[1], grid_dims[2], nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f4d1a62b4f89d..9ea7bf5e160ad 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1853,32 +1853,54 @@ kernel void kernel_cumsum( ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + // Figure out the dize and stride of the cumsum dim + const int64_t ne_dim = (args.dim == 0) ? args.ne00 : (args.dim == 1) ? args.ne01 : (args.dim == 2) ? args.ne02 : args.ne03; + const int64_t nb_dim_src = (args.dim == 0) ? args.nb00 : (args.dim == 1) ? args.nb01 : (args.dim == 2) ? args.nb02 : args.nb03; + const int64_t nb_dim_dst = (args.dim == 0) ? args.nb0 : (args.dim == 1) ? args.nb1 : (args.dim == 2) ? args.nb2 : args.nb3; + + // Map threadgroup indices to actual tensor dimensions + // tgpig.x, tgpig.y, tgpig.z represent the 3 non-cumsum dimensions + // tpitg.x represents position in the cumsum dimension + int64_t grid_indices[3] = {int64_t(tgpig.x), int64_t(tgpig.y), int64_t(tgpig.z)}; + int64_t i_vals[4]; + + int grid_idx = 0; + for (int d = 0; d < 4; ++d) { + if (d == args.dim) { + i_vals[d] = 0; // Will be set in the loop below + } else { + i_vals[d] = grid_indices[grid_idx++]; + } + } + + // Base index offsets. The cumsum dim will be further offset by the position + // in the threadgroup + const int64_t i0 = i_vals[0]; + const int64_t i1 = i_vals[1]; + const int64_t i2 = i_vals[2]; + const int64_t i3 = i_vals[3]; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01 || i0 >= args.ne00) { return; } - device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); - device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + // Each thread processes elements at stride ntg.x along the cumsum dimension + for (int64_t i_dim = tpitg.x; i_dim < ne_dim; i_dim += ntg.x) { + const int64_t offset_src = i0*args.nb00 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03 + i_dim*nb_dim_src; + const int64_t offset_dst = i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 + i_dim*nb_dim_dst; - // Each thread is a single element of the row if ne00 < max threads per - // threadgroup, so this will loop once for each index that this thread is - // responsible for - for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + device const T * src_ptr = (device const T *) ((device const char *) src0 + offset_src); + device T * dst_ptr = (device T *) ((device char *) dst + offset_dst); - // Each thread does simd_prefix_inclusive_sum => every element of row - // now holds cumsum of the simd group - float sumf = static_cast(src_row[i0]); + // Each thread does simd_prefix_inclusive_sum + float sumf = static_cast(src_ptr[0]); sumf = simd_prefix_inclusive_sum(sumf); - dst_row[i0] = static_cast(sumf); + dst_ptr[0] = static_cast(sumf); - // If this is the last element of the simd group, store its value in - // shared memory - if (tiisg == N_SIMDWIDTH - 1 || i0 == args.ne00 - 1) { - const ushort shmem_idx = i0 / N_SIMDWIDTH; + // If this is the last element of the simd group, store its value in shared memory + if (tiisg == N_SIMDWIDTH - 1 || i_dim == ne_dim - 1) { + const ushort shmem_idx = i_dim / N_SIMDWIDTH; shmem_f32[shmem_idx] = sumf; } } @@ -1887,10 +1909,13 @@ kernel void kernel_cumsum( threadgroup_barrier(mem_flags::mem_threadgroup); // Each element then adds the final value of all preceding simd groups - for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { - const ushort shmem_idx = i0 / N_SIMDWIDTH; + for (int64_t i_dim = tpitg.x; i_dim < ne_dim; i_dim += ntg.x) { + const int64_t offset_dst = i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 + i_dim*nb_dim_dst; + device T * dst_ptr = (device T *) ((device char *) dst + offset_dst); + + const ushort shmem_idx = i_dim / N_SIMDWIDTH; for (ushort j = 0; j < shmem_idx; ++j) { - dst_row[i0] += static_cast(shmem_f32[j]); + dst_ptr[0] += static_cast(shmem_f32[j]); } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 922dc64e48798..c19e4e73d045e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4865,7 +4865,7 @@ struct test_cumsum : public test_case { const std::array permute; std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR4(type, ne, dim, permute); } test_cumsum(ggml_type type = GGML_TYPE_F32, From 3b4055e25c72b9b55fbb82c3b02d2ff7098398af Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 24 Oct 2025 15:03:08 -0600 Subject: [PATCH 49/64] feat(ggml-cuda): Support arbitrary dims and non-cont in cumsum Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/cumsum.cu | 87 +++++++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index e14be0721c699..f6c5e1c3d1aac 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -1,7 +1,7 @@ #include "cumsum.cuh" -// Kernel to compute cumulative sum along the innermost dimension (ne[0]) -// Each block processes one row (ne[0] elements) +// Kernel to compute cumulative sum along an arbitrary dimension +// Each block processes one position in the non-cumsum dimensions // Algorithm matches Metal implementation: // 1. Each warp computes prefix sum within itself // 2. Last thread of each warp stores result in shared memory @@ -13,36 +13,60 @@ static __global__ void cumsum_kernel( const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const int dim) { // Shared memory to store warp sums (always use float for accumulation) extern __shared__ float shmem[]; - const int64_t i3 = blockIdx.z; - const int64_t i2 = blockIdx.y; - const int64_t i1 = blockIdx.x; + // Map block indices to actual tensor dimensions + // blockIdx.x, blockIdx.y, blockIdx.z represent the 3 non-cumsum dimensions + // threadIdx.x represents position in the cumsum dimension + int64_t grid_indices[3] = {blockIdx.x, blockIdx.y, blockIdx.z}; + int64_t i_vals[4]; + + int grid_idx = 0; + for (int d = 0; d < 4; ++d) { + if (d == dim) { + i_vals[d] = 0; // Will be set in the loop below + } else { + i_vals[d] = grid_indices[grid_idx++]; + } + } + + const int64_t i0 = i_vals[0]; + const int64_t i1 = i_vals[1]; + const int64_t i2 = i_vals[2]; + const int64_t i3 = i_vals[3]; - if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01 || i0 >= ne00) { return; } - const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); - T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + const int64_t ne_dim = (dim == 0) ? ne00 : (dim == 1) ? ne01 : (dim == 2) ? ne02 : ne03; + const int64_t nb_dim_src = (dim == 0) ? nb00 : (dim == 1) ? nb01 : (dim == 2) ? nb02 : nb03; + const int64_t nb_dim_dst = (dim == 0) ? nb0 : (dim == 1) ? nb1 : (dim == 2) ? nb2 : nb3; const int tid = threadIdx.x; const int lane_id = tid % WARP_SIZE; // Phase 1: Each thread processes elements at stride blockDim.x // Compute warp-level prefix sums - for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { + for (int64_t i_dim = tid; i_dim < ne_dim; i_dim += blockDim.x) { + const int64_t offset_src = i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03 + i_dim*nb_dim_src; + const int64_t offset_dst = i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 + i_dim*nb_dim_dst; + + const T * src_ptr = (const T *) ((const char *) src + offset_src); + T * dst_ptr = (T *) (( char *) dst + offset_dst); + // Load value and compute prefix sum within warp - float val = static_cast(src_row[i0]); + float val = static_cast(src_ptr[0]); val = warp_prefix_inclusive_sum(val); - dst_row[i0] = static_cast(val); + dst_ptr[0] = static_cast(val); // Last thread of warp stores its sum to shared memory at position based on data index - if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) { - const int shmem_idx = i0 / WARP_SIZE; + if (lane_id == WARP_SIZE - 1 || i_dim == ne_dim - 1) { + const int shmem_idx = i_dim / WARP_SIZE; shmem[shmem_idx] = val; } } @@ -51,13 +75,16 @@ static __global__ void cumsum_kernel( __syncthreads(); // Phase 2: Add the sum of all preceding warp groups to each element - for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { - const int shmem_idx = i0 / WARP_SIZE; + for (int64_t i_dim = tid; i_dim < ne_dim; i_dim += blockDim.x) { + const int64_t offset_dst = i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 + i_dim*nb_dim_dst; + T * dst_ptr = (T *) ((char *) dst + offset_dst); + + const int shmem_idx = i_dim / WARP_SIZE; float sum = 0.0f; for (int j = 0; j < shmem_idx; ++j) { sum += shmem[j]; } - dst_row[i0] = static_cast(static_cast(dst_row[i0]) + sum); + dst_ptr[0] = static_cast(static_cast(dst_ptr[0]) + sum); } } @@ -67,20 +94,35 @@ static void cumsum_cuda( const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const int dim, cudaStream_t stream) { + // Dimension being accumulated + const int64_t ne_dims[4] = {ne00, ne01, ne02, ne03}; + const int64_t ne_dim = ne_dims[dim]; + + // Grid dimensions: the GGML_MAX_DIMS-1 non-cumsum dimensions + int64_t grid_dims_arr[GGML_MAX_DIMS - 1]; + int grid_idx = 0; + for (int d = 0; d < GGML_MAX_DIMS; ++d) { + if (d != dim) { + grid_dims_arr[grid_idx++] = ne_dims[d]; + } + } + dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1); - dim3 grid_dims(ne01, ne02, ne03); + dim3 grid_dims(grid_dims_arr[0], grid_dims_arr[1], grid_dims_arr[2]); // Shared memory size: one float per warp - const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; + const int num_warps = (ne_dim + WARP_SIZE - 1) / WARP_SIZE; const size_t shmem_size = num_warps * sizeof(float); cumsum_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, - nb0, nb1, nb2, nb3 + nb0, nb1, nb2, nb3, + dim ); } @@ -88,6 +130,8 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; cudaStream_t stream = ctx.stream(); + const int dim = ggml_get_op_params_i32(dst, 0); + GGML_ASSERT(src0->type == dst->type); switch(src0->type) { case GGML_TYPE_F32: @@ -97,6 +141,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + dim, stream ); } break; @@ -107,6 +152,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + dim, stream ); } break; @@ -117,6 +163,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + dim, stream ); } break; From 3963a72d22d0bc6483deda9f445353c8d646519a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 28 Oct 2025 17:02:23 -0600 Subject: [PATCH 50/64] feat(wip): Partially working implementation with update from previous state We will probably remove the chunking loop in favor of just using the microbatching, but we'll still need this in that case for subsequent microbatches. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 68 ++++++++++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7fa58b79bb629..f20196daa5036 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11856,13 +11856,16 @@ struct llm_graph_context_mamba : public llm_graph_context { // (this is necessary in order to properly use the states before they are overwritten, // while avoiding to make unnecessary copies of the states) auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + + // Empty y that will be extended with each chunk of tokens + ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); if (n_seq_tokens == 1) { // if (true) { //DEBUG LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); // If single-token, use ssm_scan op - ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); } else { //DEBUG @@ -11930,10 +11933,10 @@ struct llm_graph_context_mamba : public llm_graph_context { dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i); ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2), 0); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} cb(segsum, "segsum", il); /* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - decay = ggml_permute(ctx, decay, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} + decay = ggml_cont(ctx, ggml_permute(ctx, decay, 1, 0, 2, 3)); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} cb(decay, "decay", il); // step 5: compute surrogate_attention_matrix @@ -11943,16 +11946,17 @@ struct llm_graph_context_mamba : public llm_graph_context { // step 6: compute y ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) - /* !! */ ggml_tensor * y = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); - y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); - cb(y, "y", il); + /* !! */ ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); + cb(y_chunk, "y_chunk", il); // step 7: compute dtxdecay ggml_tensor * decay_last = ggml_view_4d(ctx, decay, decay->ne[0], 1, decay->ne[2], decay->ne[3], decay->nb[1], decay->nb[2], decay->nb[3], (decay->ne[1] - 1) * decay->nb[1]); - decay_last = ggml_permute(ctx, decay_last, 2, 0, 1, 3); + decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); + cb(decay_last, "decay_last", il); B_perm = ggml_cont(ctx, B_perm); B_perm = ggml_repeat_4d(ctx, B_perm, B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); @@ -11964,22 +11968,42 @@ struct llm_graph_context_mamba : public llm_graph_context { /* !! */ ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); cb(next_state, "next_state", il); - //DEBUG -- Single chunk w/out prev state - ggml_tensor * out = ggml_concat(ctx, - ggml_view_1d(ctx, y, ggml_nelements(y), 0), - ggml_view_1d(ctx, next_state, ggml_nelements(next_state), 0), - 0); - return out; - - // // update previous state if present - // if (true) { - // // step 9: compute exp_dtA_cumsum - - // // step 10: compute y_prev - - // // step 11: update y from y_prev - // } + // TODO: Skip y and state updates if no previous state + // FIXME!!! These chunk-recursion parts are not working yet + + // update from previous state + ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1)); + cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); + ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, + exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], + exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], + (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); + cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); + next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3)))); + cb(next_state, "next_state_updated", il); + + // update from previous y + ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm); + cb(y_prev, "y_prev", il); + y_prev = ggml_mul(ctx, ggml_cont(ctx, + ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3))), + ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 0, 3))); + cb(y_prev, "y_prev_mul", il); + y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place + cb(y_chunk, "y_chunk_updated", il); + + // recurse + y = ggml_concat(ctx, y, y_chunk, 2); + cb(y, "y", il); + ssm = next_state; } + + // Concat the output y and state + ggml_tensor * out = ggml_concat(ctx, + ggml_view_1d(ctx, y, ggml_nelements(y), 0), + ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), + 0); + return out; } }; From 188ae84f191c7e0699a86d78a06327ec6817264c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 28 Oct 2025 17:06:20 -0600 Subject: [PATCH 51/64] refact: Avoid permute and cont for first cumsum Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f20196daa5036..708151c620613 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11885,9 +11885,9 @@ struct llm_graph_context_mamba : public llm_graph_context { cb(dt_softplus, "dt_softplus", il); // step 2: compute dtA and dtX - /* !! */ ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} cb(dtA, "dtA", il); - /* !! */ ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} cb(dtX, "dtX", il); // loop over all chunks @@ -11924,19 +11924,18 @@ struct llm_graph_context_mamba : public llm_graph_context { // step 3: compute CB ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} - /* !! */ ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} + ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs} cb(CB, "CB", il); // step 4: compute decay ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, - dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i); - ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); - ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2), 0); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1} + ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1} + ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1} cb(segsum, "segsum", il); - /* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} - decay = ggml_cont(ctx, ggml_permute(ctx, decay, 1, 0, 2, 3)); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs} + ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1} + decay = ggml_permute(ctx, decay, 2, 1, 3, 0); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} cb(decay, "decay", il); // step 5: compute surrogate_attention_matrix @@ -11946,7 +11945,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // step 6: compute y ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) - /* !! */ ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); cb(y_chunk, "y_chunk", il); @@ -11960,18 +11959,17 @@ struct llm_graph_context_mamba : public llm_graph_context { B_perm = ggml_cont(ctx, B_perm); B_perm = ggml_repeat_4d(ctx, B_perm, B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); - /* !! */ ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); + ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); cb(dtxdecay, "dtxdecay", il); // step 8: compute next_state - /* !! */ ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); + ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); cb(next_state, "next_state", il); // TODO: Skip y and state updates if no previous state - // FIXME!!! These chunk-recursion parts are not working yet - // update from previous state + // step 9: update from previous state ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1)); cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, @@ -11982,7 +11980,7 @@ struct llm_graph_context_mamba : public llm_graph_context { next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3)))); cb(next_state, "next_state_updated", il); - // update from previous y + // step 10: update from previous y ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm); cb(y_prev, "y_prev", il); y_prev = ggml_mul(ctx, ggml_cont(ctx, @@ -11992,7 +11990,7 @@ struct llm_graph_context_mamba : public llm_graph_context { y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place cb(y_chunk, "y_chunk_updated", il); - // recurse + // step 11: recurse y = ggml_concat(ctx, y, y_chunk, 2); cb(y, "y", il); ssm = next_state; From 0441ccbf8a6f4566bc6361bf41baa9003558f0b8 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 29 Oct 2025 11:15:47 -0600 Subject: [PATCH 52/64] fix: Subset input states to match ids The code now runs cleanly for parallel requests Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 708151c620613..a5cce1775ce1c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11876,6 +11876,17 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: make this configurable const uint32_t chunk_size = 256; + // extract the state(s) for the sequences identified by ids + if (ssm->ne[3] != ids->ne[0]) { + ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 + // ggml_tensor * ids_perm = ggml_permute(ctx, ids, 1, 2, 3, 0); // put the taget dim in dim 0 + ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, + ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape + ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows + ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape + GGML_ASSERT(ssm->ne[3] == ids->ne[0]); + } + // step 1: compute dt softplus // NOTE: In other implementations, the bias is added after // the softplus. This shouldn't be a problem, but it's a @@ -11944,7 +11955,7 @@ struct llm_graph_context_mamba : public llm_graph_context { cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); // step 6: compute y - ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3) + ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); cb(y_chunk, "y_chunk", il); From aba30d6ff698796cb1e1dc6c19d7c397b30035db Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 29 Oct 2025 11:16:12 -0600 Subject: [PATCH 53/64] fix: Fix the chunk size computation Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a5cce1775ce1c..fa684a7b3b395 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11906,7 +11906,7 @@ struct llm_graph_context_mamba : public llm_graph_context { for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { // chunk views - const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i * chunk_size)); + const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); // slice dtA on dim 1 ggml_tensor * dtA_chunk = ggml_view_3d(ctx, dtA, dtA->ne[0], chunk_size_i, dtA->ne[2], From 62ac8977377bb6be979061dd3de994d7124b27ae Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 29 Oct 2025 11:38:39 -0600 Subject: [PATCH 54/64] fix: Fix handling of batch size > 1 in chunk updates Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fa684a7b3b395..02df96f0c13b0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11981,24 +11981,25 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: Skip y and state updates if no previous state // step 9: update from previous state - ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1)); + ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1)); // {n_head, chunk_size_i, n_seqs} cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], - (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); + (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs} cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); - next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3)))); + ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 1, 2, 3, 0); + next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm))); cb(next_state, "next_state_updated", il); // step 10: update from previous y ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm); cb(y_prev, "y_prev", il); - y_prev = ggml_mul(ctx, ggml_cont(ctx, - ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3))), - ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 0, 3))); + y_prev = ggml_mul(ctx, + ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), + ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 3, 0))); cb(y_prev, "y_prev_mul", il); - y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place + y_chunk = ggml_add(ctx, y_chunk, y_prev); cb(y_chunk, "y_chunk_updated", il); // step 11: recurse From 36244fe2f896478ea257f39fd6470841fc2318ac Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 29 Oct 2025 12:37:18 -0600 Subject: [PATCH 55/64] fix: Fix permutation for nemotron-h shape Something is definitely still broken for nemotron-h which may be the g > 1 aspect of the model Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 02df96f0c13b0..580b6e1691ebf 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11879,13 +11879,13 @@ struct llm_graph_context_mamba : public llm_graph_context { // extract the state(s) for the sequences identified by ids if (ssm->ne[3] != ids->ne[0]) { ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 - // ggml_tensor * ids_perm = ggml_permute(ctx, ids, 1, 2, 3, 0); // put the taget dim in dim 0 ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape GGML_ASSERT(ssm->ne[3] == ids->ne[0]); } + // ssm -> {d_state, head_dim, n_head, n_seqs} // step 1: compute dt softplus // NOTE: In other implementations, the bias is added after @@ -11988,7 +11988,7 @@ struct llm_graph_context_mamba : public llm_graph_context { exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs} cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); - ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 1, 2, 3, 0); + ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm))); cb(next_state, "next_state_updated", il); From 8b6f38a2522b3126a866fc7e2af01c3f897d743d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 4 Nov 2025 14:12:13 -0700 Subject: [PATCH 56/64] feat(off-topic): print the number of elements in tensors with llama-gguf Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- examples/gguf/gguf.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp index 1bf8e705e359c..d28ede1c12927 100644 --- a/examples/gguf/gguf.cpp +++ b/examples/gguf/gguf.cpp @@ -184,9 +184,12 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { const char * name = gguf_get_tensor_name (ctx, i); const size_t size = gguf_get_tensor_size (ctx, i); const size_t offset = gguf_get_tensor_offset(ctx, i); - const char * type = ggml_type_name(gguf_get_tensor_type(ctx, i)); + const auto type = gguf_get_tensor_type(ctx, i); + const char * type_name = ggml_type_name(type); + const size_t type_size = ggml_type_size(type); + const size_t n_elements = size / type_size; - printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s\n", __func__, i, name, size, offset, type); + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s, n_elts = %zu\n", __func__, i, name, size, offset, type_name, n_elements); } } From 82bba1daec150362e5ebaee793c79672c07170d2 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 4 Nov 2025 16:32:01 -0700 Subject: [PATCH 57/64] feat(ggml-cpu): Add f16 and bf16 support for ssm_conv Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cpu/ops.cpp | 157 ++++++++++++++++++++++++++++++++++---- 1 file changed, 143 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index fd9e52277ce7e..97419d324fe66 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8918,7 +8918,8 @@ void ggml_compute_forward_flash_attn_back( // ggml_compute_forward_ssm_conv -static void ggml_compute_forward_ssm_conv_f32( +template +static void ggml_compute_forward_ssm_conv_impl( const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // conv_x @@ -8934,9 +8935,10 @@ static void ggml_compute_forward_ssm_conv_f32( const int n_s = dst->ne[2]; // number of sequences in the batch GGML_ASSERT( dst->ne[0] == nr); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(src1->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(src_t)); + GGML_ASSERT(src1->nb[0] == sizeof(conv_t)); + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(src_t)); + GGML_ASSERT(dst->type == src0->type); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -8950,9 +8952,9 @@ static void ggml_compute_forward_ssm_conv_f32( for (int i2 = 0; i2 < n_t; ++i2) { // {d_conv - 1 + n_t, d_inner, n_seqs} // sliding window - const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} - float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + const src_t * s = (const src_t *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const conv_t * c = (const conv_t *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + src_t * x = ( src_t *) (( char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} // TODO: transpose the output for smaller strides for big batches? // d_inner @@ -8963,13 +8965,80 @@ static void ggml_compute_forward_ssm_conv_f32( // d_conv for (int i0 = 0; i0 < nc; ++i0) { - sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; + sumf += type_conversion_table::to_f32(s[i0 + i1*ncs]) * type_conversion_table::to_f32(c[i0 + i1*nc]); } - x[i1] = sumf; - } - } - } -} + x[i1] = type_conversion_table::from_f32(sumf); + } + } + } +} + +// static void ggml_compute_forward_ssm_conv_q_f32( +// const ggml_compute_params * params, +// ggml_tensor * dst) { +// const ggml_tensor * src0 = dst->src[0]; // conv_x +// const ggml_tensor * src1 = dst->src[1]; // conv1d.weight + +// const int ith = params->ith; +// const int nth = params->nth; + +// const int nc = src1->ne[0]; // d_conv +// const int ncs = src0->ne[0]; // d_conv - 1 + n_t +// const int nr = src0->ne[1]; // d_inner +// const int n_t = dst->ne[1]; // tokens per sequence +// const int n_s = dst->ne[2]; // number of sequences in the batch + +// const ggml_type type0 = src0->type; +// const size_t type0_size = ggml_type_size(type0); +// ggml_to_float_t const dequantize_row0_q = ggml_get_type_traits(type0)->to_float; +// ggml_from_float_t const quantize_row0_q = ggml_get_type_traits_cpu(type0)->from_float; + +// const ggml_type type1 = src1->type; +// const size_t type1_size = ggml_type_size(type1); +// ggml_to_float_t const dequantize_row1_q = ggml_get_type_traits(type1)->to_float; +// ggml_from_float_t const quantize_row1_q = ggml_get_type_traits_cpu(type1)->from_float; + +// GGML_ASSERT( dst->ne[0] == nr); +// GGML_ASSERT(src0->nb[0] == type0_size); +// GGML_ASSERT(src1->nb[0] == type1_size); +// GGML_ASSERT(src0->nb[1] == src0->ne[0]*type0_size); +// GGML_ASSERT(dst->type == src0->type); + +// // rows per thread +// const int dr = (nr + nth - 1)/nth; + +// // row range for this thread +// const int ir0 = dr*ith; +// const int ir1 = MIN(ir0 + dr, nr); +// const int ir = ir1 - ir0; + +// // temporary storage for dequantized lines +// float * wdata = (float *) params->wdata + (src0->ne[0] + CACHE_LINE_SIZE_F32) * ith; + +// for (int i3 = 0; i3 < n_s; ++i3) { +// for (int i2 = 0; i2 < n_t; ++i2) { +// // {d_conv - 1 + n_t, d_inner, n_seqs} +// // sliding window +// const void * s = (const void *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} +// const void * c = (const void *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} +// void * x = ( void *) (( char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + +// // TODO: transpose the output for smaller strides for big batches? +// // d_inner +// for (int i1 = 0; i1 < ir; ++i1) { +// // rowwise dot product +// // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision +// float sumf = 0.0f; + +// // d_conv +// for (int i0 = 0; i0 < nc; ++i0) { +// sumf += type_conversion_table::to_f32(s[i0 + i1*ncs]) * type_conversion_table::to_f32(c[i0 + i1*nc]); +// } +// x[i1] = type_conversion_table::from_f32(sumf); +// } +// } +// } +// } void ggml_compute_forward_ssm_conv( const ggml_compute_params * params, @@ -8977,8 +9046,68 @@ void ggml_compute_forward_ssm_conv( switch (dst->src[0]->type) { case GGML_TYPE_F32: { - ggml_compute_forward_ssm_conv_f32(params, dst); + switch (dst->src[1]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } + } break; + case GGML_TYPE_F16: + { + switch (dst->src[1]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } + } break; + case GGML_TYPE_BF16: + { + switch (dst->src[1]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } } break; + // TODO: Support quantized types default: { GGML_ABORT("fatal error"); From 7ad0f37e667c8ca787834d26482a49792bb59de7 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 4 Nov 2025 16:33:46 -0700 Subject: [PATCH 58/64] feat(llama-quant): Allow F16 and BF16 quants of ssm_conv1d.weight This is experimantal! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-quant.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index a56b2626ae1c5..428f3b8db9751 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -421,6 +421,18 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } ++qs.i_ffn_up; } + else if (name.find("ssm_conv1d") != std::string::npos) { + // go as low as F16 for now + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: + case LLAMA_FTYPE_MOSTLY_BF16: + break; + default: + { + new_type = GGML_TYPE_F16; + } + } + } // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; //} @@ -859,9 +871,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - // do not quantize Mamba's small yet 2D weights + // do not quantize shortconv 2D weights // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; quantize &= name.find("shortconv.conv.weight") == std::string::npos; // do not quantize RWKV's small yet 2D weights From 6256f9a8118beca03f1645ca2c22671f0a7ed68c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 4 Nov 2025 16:35:06 -0700 Subject: [PATCH 59/64] feat(ggml-cpu): Add partial implementation of scale for f16 This is used to zero-out the state in build_rs, so it's required to support F16 cache states for recurrent models. The bias route does not get hit in that case, but would need to be implemented if used elsewhere. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cpu/ops.cpp | 58 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 97419d324fe66..009da7cfe6c3c 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4558,6 +4558,60 @@ static void ggml_compute_forward_scale_f32( } } +static void ggml_compute_forward_scale_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + float s; // scale factor + float b; // bias + + memcpy(&s, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&b, (float *) dst->op_params + 1, sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const size_t nb01 = src0->nb[1]; + + const size_t nb1 = dst->nb[1]; + + if (b == 0.0f) { + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(ggml_fp16_t)); + } + ggml_vec_scale_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*nb1), s); + } + } else { + //TODO: support bias! + GGML_ABORT("fatal error"); + // for (int i1 = ir0; i1 < ir1; i1++) { + // ggml_vec_mad1_f16(nc, + // (ggml_fp16_t *) ((char *) dst->data + i1*nb1), + // (ggml_fp16_t *) ((char *) src0->data + i1*nb1), + // s, b); + // } + } +} + void ggml_compute_forward_scale( const ggml_compute_params * params, ggml_tensor * dst) { @@ -4569,6 +4623,10 @@ void ggml_compute_forward_scale( { ggml_compute_forward_scale_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_scale_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); From 204cd80ed37403f3dc17d121a4909e4d882b3573 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 4 Nov 2025 16:35:50 -0700 Subject: [PATCH 60/64] feat(wip): Use type_k/type_v for hybrid cache types Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 896725466ce24..4114c484b60cb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6786,8 +6786,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_n_pad */ 1, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_type_r */ params.type_k, + /* recurrent_type_s */ params.type_v, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, From 86788a2431cd73719202518626d20aa06f0f5e59 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 4 Nov 2025 16:36:36 -0700 Subject: [PATCH 61/64] temp: Cast ssm to F32 This will be needed until F16 support is added for SSM_SCAN Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/models/graph-context-mamba.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index 08bfd38b5e6ac..575f2290e52b8 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -242,6 +242,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i // while avoiding to make unnecessary copies of the states) auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); // Empty y that will be extended with each chunk of tokens ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); From de43d0b9f3d04bb80c6be535c725910a15cdd571 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 4 Nov 2025 16:57:19 -0700 Subject: [PATCH 62/64] feat(ggml-metal): Add support for F16 and BF16 ssm_conv weights Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 5 ++-- ggml/src/ggml-metal/ggml-metal.metal | 28 ++++++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 6b7efa6a2650d..20106d89fc69f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -218,7 +218,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t }; const char * suffix = ""; - if (n % 4 == 0) { + if (n % 4 == 0 && op->type == GGML_TYPE_F32) { suffix = "_4"; } @@ -394,7 +394,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(ggml_is_contiguous(op->src[1])); @@ -404,7 +403,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar const char * suffix = ""; - if (op->src[1]->ne[0] % 4 == 0) { + if (op->src[1]->ne[0] % 4 == 0 && op->src[1]->type == GGML_TYPE_F32) { suffix = "_4"; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ae6791165ae38..482a3a400b847 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1235,6 +1235,14 @@ kernel void kernel_scale_f32( dst[tpig] = src0[tpig] * args.scale + args.bias; } +kernel void kernel_scale_f16( + constant ggml_metal_kargs_scale & args, + device const half * src0, + device half * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * args.scale + args.bias; +} + kernel void kernel_scale_f32_4( constant ggml_metal_kargs_scale & args, device const float4 * src0, @@ -2207,8 +2215,9 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; -// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 -kernel void kernel_ssm_conv_f32_f32( +// ref: ggml.c:ggml_compute_forward_ssm_conv_impl +template +kernel void kernel_ssm_conv_impl( constant ggml_metal_kargs_ssm_conv & args, device const void * src0, device const void * src1, @@ -2226,14 +2235,14 @@ kernel void kernel_ssm_conv_f32_f32( //const int64_t n_t = args.ne1; //const int64_t n_s = args.ne2; - device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); - device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); - device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + device const src_t * s = (device const src_t *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const conv_t * c = (device const conv_t *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - sumf += s[i0] * c[i0]; + sumf += static_cast(s[i0]) * static_cast(c[i0]); } x[0] = sumf; @@ -2270,6 +2279,13 @@ kernel void kernel_ssm_conv_f32_f32_4( x[0] = sumf; } +typedef decltype(kernel_ssm_conv_impl) kernel_ssm_conv_t; +template [[host_name("kernel_ssm_conv_f32_f32")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl; +template [[host_name("kernel_ssm_conv_f32_f16")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_ssm_conv_f32_bf16")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl; +#endif + // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, From 426a97c8dec200f84c251ff48c77e1bd0199f908 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 5 Nov 2025 11:47:34 -0700 Subject: [PATCH 63/64] feat: Keep ssm in f16 until output on SSD code path Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/models/graph-context-mamba.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index 575f2290e52b8..cab60f5c2096c 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -242,7 +242,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i // while avoiding to make unnecessary copies of the states) auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); // Empty y that will be extended with each chunk of tokens ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); @@ -252,6 +251,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i //DEBUG LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); // If single-token, use ssm_scan op + ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); } else { //DEBUG @@ -362,6 +362,9 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i // step 8: compute next_state ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); + if (next_state->type != ssm->type) { + next_state = ggml_cast(ctx, next_state, ssm->type); + } cb(next_state, "next_state", il); // TODO: Skip y and state updates if no previous state @@ -395,6 +398,9 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i } // Concat the output y and state + if (ssm->type != y->type) { + ssm = ggml_cast(ctx, ssm, y->type); + } ggml_tensor * out = ggml_concat(ctx, ggml_view_1d(ctx, y, ggml_nelements(y), 0), ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), From 6733bdaba995f0724ba9f9af6e1ba79508745c82 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 5 Nov 2025 12:07:38 -0700 Subject: [PATCH 64/64] feat: Remove sub-ubatch batching Unlike Qwen3Next, we don't hit big commplexity scaling issues here, so removing all of the batching gives a big reduction in complexity and a big boost to performance! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/models/graph-context-mamba.cpp | 185 +++++++++++------------------ 1 file changed, 72 insertions(+), 113 deletions(-) diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index cab60f5c2096c..13b70f88e3e39 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -243,9 +243,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - // Empty y that will be extended with each chunk of tokens - ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); - if (n_seq_tokens == 1) { // if (true) { //DEBUG @@ -259,9 +256,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i // otherwise, use the SSD formulation - // TODO: make this configurable - const uint32_t chunk_size = 256; - // extract the state(s) for the sequences identified by ids if (ssm->ne[3] != ids->ne[0]) { ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 @@ -287,115 +281,80 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} cb(dtX, "dtX", il); - // loop over all chunks + + // step 3: compute CB uint32_t repeats = n_head / n_group; - for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { - - // chunk views - const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); - // slice dtA on dim 1 - ggml_tensor * dtA_chunk = ggml_view_3d(ctx, dtA, - dtA->ne[0], chunk_size_i, dtA->ne[2], - dtA->nb[1], dtA->nb[2], - chunk_i * dtA->nb[1]); - cb(dtA_chunk, "dtA_chunk", il); - // slice dtX on dim 2 - ggml_tensor * dtX_chunk = ggml_view_4d(ctx, dtX, - dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], - dtX->nb[1], dtX->nb[2], dtX->nb[3], - chunk_i * dtX->nb[2]); - cb(dtX_chunk, "dtX_chunk", il); - // slice B on dim 2 - ggml_tensor * B_chunk = ggml_view_4d(ctx, B, - B->ne[0], B->ne[1], chunk_size_i, B->ne[3], - B->nb[1], B->nb[2], B->nb[3], - chunk_i * B->nb[2]); - cb(B_chunk, "B_chunk", il); - // slice C on dim 2 - ggml_tensor * C_chunk = ggml_view_4d(ctx, C, - C->ne[0], C->ne[1], chunk_size_i, C->ne[3], - C->nb[1], C->nb[2], C->nb[3], - chunk_i * C->nb[2]); - cb(C_chunk, "C_chunk", il); - - // step 3: compute CB - ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} - ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} - ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} - CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs} - cb(CB, "CB", il); - - // step 4: compute decay - ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, - dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1} - ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1} - ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1} - cb(segsum, "segsum", il); - ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1} - decay = ggml_permute(ctx, decay, 2, 1, 3, 0); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} - cb(decay, "decay", il); - - // step 5: compute surrogate_attention_matrix - ggml_tensor * CBdecay = ggml_mul(ctx, CB, ggml_cont(ctx, decay)); - ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); - cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); - - // step 6: compute y - ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); - ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); - y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); - cb(y_chunk, "y_chunk", il); - - // step 7: compute dtxdecay - ggml_tensor * decay_last = ggml_view_4d(ctx, decay, - decay->ne[0], 1, decay->ne[2], decay->ne[3], - decay->nb[1], decay->nb[2], decay->nb[3], - (decay->ne[1] - 1) * decay->nb[1]); - decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); - cb(decay_last, "decay_last", il); - B_perm = ggml_cont(ctx, B_perm); - B_perm = ggml_repeat_4d(ctx, B_perm, - B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); - ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); - dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); - cb(dtxdecay, "dtxdecay", il); - - // step 8: compute next_state - ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); - if (next_state->type != ssm->type) { - next_state = ggml_cast(ctx, next_state, ssm->type); - } - cb(next_state, "next_state", il); - - // TODO: Skip y and state updates if no previous state - - // step 9: update from previous state - ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1)); // {n_head, chunk_size_i, n_seqs} - cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); - ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, - exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], - exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], - (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs} - cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); - ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} - next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm))); - cb(next_state, "next_state_updated", il); - - // step 10: update from previous y - ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm); - cb(y_prev, "y_prev", il); - y_prev = ggml_mul(ctx, - ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), - ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 3, 0))); - cb(y_prev, "y_prev_mul", il); - y_chunk = ggml_add(ctx, y_chunk, y_prev); - cb(y_chunk, "y_chunk_updated", il); - - // step 11: recurse - y = ggml_concat(ctx, y, y_chunk, 2); - cb(y, "y", il); - ssm = next_state; + ggml_tensor * C_perm = ggml_permute(ctx, C, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + ggml_tensor * B_perm = ggml_permute(ctx, B, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} + CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs} + cb(CB, "CB", il); + + // step 4: compute decay + ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA, + dtA->ne[0], dtA->ne[1], dtA->ne[2], dtA->ne[3] * n_seq_tokens); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1} + ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1} + ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1, 1); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1} + cb(segsum, "segsum", il); + ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1} + decay = ggml_permute(ctx, decay, 2, 1, 3, 0); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + cb(decay, "decay", il); + + // step 5: compute surrogate_attention_matrix + ggml_tensor * CBdecay = ggml_mul(ctx, CB, ggml_cont(ctx, decay)); + ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); + cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); + + // step 6: compute y + ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX, 1, 2, 0, 3)); + ggml_tensor * y = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); + cb(y, "y", il); + + // step 7: compute dtxdecay + ggml_tensor * decay_last = ggml_view_4d(ctx, decay, + decay->ne[0], 1, decay->ne[2], decay->ne[3], + decay->nb[1], decay->nb[2], decay->nb[3], + (decay->ne[1] - 1) * decay->nb[1]); + decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); + cb(decay_last, "decay_last", il); + B_perm = ggml_cont(ctx, B_perm); + B_perm = ggml_repeat_4d(ctx, B_perm, + B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); + ggml_tensor * dtxdecay = ggml_mul(ctx, dtX, decay_last); + dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); + cb(dtxdecay, "dtxdecay", il); + + // step 8: compute next_state + ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); + if (next_state->type != ssm->type) { + next_state = ggml_cast(ctx, next_state, ssm->type); } + cb(next_state, "next_state", il); + + // TODO: Skip y and state updates if no previous state + + // step 9: update from previous state + ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA, 1)); // {n_head, chunk_size_i, n_seqs} + cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); + ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, + exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], + exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], + (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs} + cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); + ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} + next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm))); + cb(next_state, "next_state_updated", il); + + // step 10: update from previous y + ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C, 0, 2, 1, 3), ssm); + cb(y_prev, "y_prev", il); + y_prev = ggml_mul(ctx, + ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), + ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 3, 0))); + cb(y_prev, "y_prev_mul", il); + y = ggml_add(ctx, y, y_prev); + cb(y, "y_updated", il); // Concat the output y and state if (ssm->type != y->type) {