@@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan(
72707270 struct ggml_tensor * dt,
72717271 struct ggml_tensor * A,
72727272 struct ggml_tensor * B,
7273- struct ggml_tensor * C) {
7273+ struct ggml_tensor * C,
7274+ struct ggml_tensor * D) {
72747275 GGML_ASSERT(ggml_is_contiguous(s));
7275- GGML_ASSERT(ggml_is_contiguous(x));
72767276 GGML_ASSERT(ggml_is_contiguous(dt));
72777277 GGML_ASSERT(ggml_is_contiguous(A));
7278- GGML_ASSERT(ggml_is_matrix(A));
7279- GGML_ASSERT(ggml_is_3d(B));
7280- GGML_ASSERT(ggml_is_3d(s));
7278+ GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
72817279 GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
72827280 GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
7283- GGML_ASSERT(ggml_are_same_shape(x, dt));
7281+ GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
7282+ GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
7283+ GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
72847284 GGML_ASSERT(ggml_are_same_shape(B, C));
72857285
72867286 {
72877287 const int64_t d_state = s->ne[0];
7288- const int64_t d_inner = s->ne[1];
7289- const int64_t n_seq_tokens = x->ne[1];
7290- const int64_t n_seqs = x->ne[2];
7291-
7292- GGML_ASSERT(s->ne[2] == n_seqs);
7293- GGML_ASSERT(x->ne[0] == d_inner);
7294- GGML_ASSERT(A->ne[0] == d_state);
7295- GGML_ASSERT(A->ne[1] == d_inner);
7288+ const int64_t head_dim = x->ne[0];
7289+ const int64_t n_head = x->ne[1];
7290+ const int64_t n_seq_tokens = x->ne[2];
7291+ const int64_t n_seqs = x->ne[3];
7292+
7293+ GGML_ASSERT(dt->ne[0] == n_head);
7294+ GGML_ASSERT(dt->ne[1] == n_seq_tokens);
7295+ GGML_ASSERT(dt->ne[2] == n_seqs);
7296+ GGML_ASSERT(ggml_is_3d(dt));
7297+ GGML_ASSERT(s->ne[1] == head_dim);
7298+ GGML_ASSERT(s->ne[2] == n_head);
7299+ GGML_ASSERT(s->ne[3] == n_seqs);
72967300 GGML_ASSERT(B->ne[0] == d_state);
7297- GGML_ASSERT(B->ne[1] == n_seq_tokens);
7298- GGML_ASSERT(B->ne[2] == n_seqs);
7301+ GGML_ASSERT(B->ne[2] == n_seq_tokens);
7302+ GGML_ASSERT(B->ne[3] == n_seqs);
7303+ GGML_ASSERT(D->ne[0] == n_head);
7304+ GGML_ASSERT(ggml_is_vector(D));
7305+
7306+ if (ggml_is_vector(A)) {
7307+ // Mamba-2
7308+ GGML_ASSERT(A->ne[0] == n_head);
7309+ } else {
7310+ // Mamba-1
7311+ GGML_ASSERT(A->ne[0] == d_state);
7312+ GGML_ASSERT(A->ne[1] == n_head);
7313+ GGML_ASSERT(ggml_is_matrix(A));
7314+ }
72997315 }
73007316
73017317 bool is_node = false;
@@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan(
73167332 result->src[3] = A;
73177333 result->src[4] = B;
73187334 result->src[5] = C;
7335+ result->src[6] = D;
73197336
73207337 return result;
73217338}
@@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv(
1584015857static void ggml_compute_forward_ssm_scan_f32(
1584115858 const struct ggml_compute_params * params,
1584215859 struct ggml_tensor * dst) {
15843- const struct ggml_tensor * src0 = dst->src[0]; // s
15844- const struct ggml_tensor * src1 = dst->src[1]; // x
15845- const struct ggml_tensor * src2 = dst->src[2]; // dt
15846- const struct ggml_tensor * src3 = dst->src[3]; // A
15847- const struct ggml_tensor * src4 = dst->src[4]; // B
15848- const struct ggml_tensor * src5 = dst->src[5]; // C
15860+ const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs}
15861+ const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
15862+ const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
15863+ const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head}
15864+ const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
15865+ const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
15866+ const struct ggml_tensor * src6 = dst->src[6]; // D {n_head}
1584915867
1585015868 const int ith = params->ith;
1585115869 const int nth = params->nth;
1585215870
15853- const int64_t nc = src0->ne[0]; // d_state
15854- const int64_t nr = src0->ne[1]; // d_inner
15855- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
15856- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
15871+ const int64_t nc = src0->ne[0]; // d_state
15872+ const int64_t nr = src0->ne[1]; // dim
15873+ const int64_t nh = src1->ne[1]; // n_head
15874+ const int64_t ng = src4->ne[1];
15875+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
15876+ const int64_t ns = src0->ne[3]; // number of sequences in the batch
15877+
15878+ const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);
1585715879
1585815880 GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
1585915881 GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32(
1586215884 GGML_ASSERT(src3->nb[0] == sizeof(float));
1586315885 GGML_ASSERT(src4->nb[0] == sizeof(float));
1586415886 GGML_ASSERT(src5->nb[0] == sizeof(float));
15865- // required for the dot product between s and C
15866- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
15867- // required for per-sequence offsets for states
15868- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
15869- // required to get correct offset for state destination (i.e. src1->nb[3])
15870- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
15871-
15872- // rows per thread
15873- const int dr = (nr + nth - 1)/nth;
15874-
15875- // row range for this thread
15876- const int ir0 = dr*ith;
15877- const int ir1 = MIN(ir0 + dr, nr);
15878- const int ir = ir1 - ir0;
15879-
15880- for (int i3 = 0; i3 < n_s; ++i3) {
15881- for (int i2 = 0; i2 < n_t; ++i2) {
15882- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
15883- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
15884- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
15885- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
15886- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
15887- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
15888- float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
15889- float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
15890-
15891- // use the output as the source for the next token-wise iterations
15887+ GGML_ASSERT(src6->nb[0] == sizeof(float));
15888+ // allows optimizing the modulo since n_group should be a power of 2
15889+ GGML_ASSERT((ng & -ng) == ng);
15890+
15891+ // heads per thread
15892+ const int dh = (nh + nth - 1)/nth;
15893+
15894+ // head range for this thread
15895+ const int ih0 = dh*ith;
15896+ const int ih1 = MIN(ih0 + dh, nh);
15897+
15898+ for (int i3 = 0; i3 < ns; ++i3) {
15899+ for (int i2 = 0; i2 < nt; ++i2) {
15900+ const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
15901+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
15902+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
15903+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
15904+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
15905+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
15906+ const float * D = (const float *) ((const char *) src6->data); // {nh}
15907+ float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
15908+ float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
15909+
15910+ // use the output as the source when it's not the first token-wise iteration
1589215911 if (i2 > 0) { s0 = s; }
1589315912
15894- // d_inner
15895- for (int i1 = 0; i1 < ir; ++i1) {
15896- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
15897- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
15898- float x_dt = x[i1] * dt_soft_plus;
15899- float sumf = 0.0f;
15900- // d_state
15901- for (int i0 = 0; i0 < nc; ++i0) {
15902- int i = i0 + i1*nc;
15903- // state = prev_state * dA + dB * x
15904- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
15905- // y = rowwise_dotprod(state, C)
15906- sumf += state * C[i0];
15907- s[i] = state;
15913+ if (ggml_is_vector(src3)) {
15914+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
15915+
15916+ // n_head
15917+ for (int h = ih0; h < ih1; ++h) {
15918+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
15919+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
15920+ const float dA = expf(dt_soft_plus * A[h]);
15921+
15922+ // TODO: SIMD implementation
15923+ // dim
15924+ for (int i1 = 0; i1 < nr; ++i1) {
15925+ const int i = i1 + h*nr;
15926+ const float x_dt = x[i] * dt_soft_plus;
15927+ float sumf = 0.0f;
15928+ // d_state
15929+ for (int i0 = 0; i0 < nc; ++i0) {
15930+ const int ii = i0 + i*nc;
15931+ const int ig = i0 + (h & (ng - 1))*nc;
15932+ // state = prev_state * dA + dB * x
15933+ const float state = (s0[ii] * dA) + (B[ig] * x_dt);
15934+ // y = rowwise_dotprod(state, C)
15935+ sumf += state * C[ig];
15936+ s[ii] = state;
15937+ }
15938+ y[i] = sumf + x[i] * D[h];
15939+ }
15940+ }
15941+ } else {
15942+ // Mamba-1 has an element-wise decay factor for the states
15943+
15944+ // n_head
15945+ for (int h = ih0; h < ih1; ++h) {
15946+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
15947+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
15948+
15949+ // dim
15950+ for (int i1 = 0; i1 < nr; ++i1) {
15951+ const int i = i1 + h*nr;
15952+ const float x_dt = x[i] * dt_soft_plus;
15953+ float sumf = 0.0f;
15954+ // d_state
15955+ for (int i0 = 0; i0 < nc; ++i0) {
15956+ const int ii = i0 + i*nc;
15957+ const int ig = i0 + (h & (ng - 1))*nc;
15958+ // state = prev_state * dA + dB * x
15959+ const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
15960+ // y = rowwise_dotprod(state, C)
15961+ sumf += state * C[ig];
15962+ s[ii] = state;
15963+ }
15964+ y[i] = sumf + x[i] * D[h];
15965+ }
1590815966 }
15909- y[i1] = sumf;
1591015967 }
1591115968 }
1591215969 }
0 commit comments