@@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32(
1020710207 GGML_ASSERT( nb0 == sizeof(float));
1020810208 GGML_ASSERT(nb00 == sizeof(float));
1020910209
10210- if (nb10 == sizeof(float)) {
10210+ if (ne00 > 1 && ne10 == 1) {
10211+ // fast broadcast path
10212+ for (int64_t ir = ith; ir < nr; ir += nth) {
10213+ // src0 and dst are same shape => same indices
10214+ const int64_t i03 = ir/(ne02*ne01);
10215+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10216+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10217+
10218+ const int64_t i13 = i03 % ne13;
10219+ const int64_t i12 = i02 % ne12;
10220+ const int64_t i11 = i01 % ne11;
10221+
10222+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
10223+
10224+ const float scale = src1_ptr[0];
10225+
10226+ if (scale == 0.0f) {
10227+ // NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
10228+ // but it is useful when resetting the state of recurrent models.
10229+ memset((char *)dst->data + ir*nb1, 0, nb1);
10230+ } else {
10231+ if (dst->data != src0->data) {
10232+ // src0 is same shape as dst => same indices
10233+ memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
10234+ }
10235+ if (scale != 1.0f) {
10236+ ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
10237+ }
10238+ }
10239+ }
10240+ } else if (nb10 == sizeof(float)) {
1021110241 for (int64_t ir = ith; ir < nr; ir += nth) {
1021210242 // src0 and dst are same shape => same indices
1021310243 const int64_t i03 = ir/(ne02*ne01);
@@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32(
1591915949 const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
1592015950 const float dA = expf(dt_soft_plus * A[h]);
1592115951
15922- // TODO: SIMD implementation
1592315952 // dim
1592415953 for (int i1 = 0; i1 < nr; ++i1) {
15925- const int i = i1 + h*nr;
15926- const float x_dt = x[i ] * dt_soft_plus;
15954+ const int ii = i1 + h*nr;
15955+ const float x_dt = x[ii ] * dt_soft_plus;
1592715956 float sumf = 0.0f;
15957+ #if defined(GGML_SIMD)
15958+ const int np = (nc & ~(GGML_F32_STEP - 1));
15959+
15960+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
15961+
15962+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
15963+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
15964+
15965+ GGML_F32_VEC ax[GGML_F32_ARR];
15966+ GGML_F32_VEC ay[GGML_F32_ARR];
15967+ GGML_F32_VEC az[GGML_F32_ARR];
15968+
15969+ for (int i = 0; i < np; i += GGML_F32_STEP) {
15970+ for (int j = 0; j < GGML_F32_ARR; j++) {
15971+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
15972+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
15973+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
15974+
15975+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
15976+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
15977+
15978+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
15979+
15980+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
15981+
15982+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
15983+ }
15984+ }
15985+
15986+ // reduce the GGML_F32_ARR partial sums in sum[] into the scalar sumf
15987+ GGML_F32_VEC_REDUCE(sumf, sum);
15988+ #else
15989+ const int np = 0;
15990+ #endif
1592815991 // d_state
15929- for (int i0 = 0 ; i0 < nc; ++i0) {
15930- const int ii = i0 + i *nc;
15992+ for (int i0 = np ; i0 < nc; ++i0) {
15993+ const int i = i0 + ii *nc;
1593115994 const int ig = i0 + (h & (ng - 1))*nc;
1593215995 // state = prev_state * dA + dB * x
15933- const float state = (s0[ii ] * dA) + (B[ig] * x_dt);
15996+ const float state = (s0[i ] * dA) + (B[ig] * x_dt);
1593415997 // y = rowwise_dotprod(state, C)
1593515998 sumf += state * C[ig];
15936- s[ii ] = state;
15999+ s[i ] = state;
1593716000 }
15938- y[i ] = sumf + x[i ] * D[h];
16001+ y[ii ] = sumf + x[ii ] * D[h];
1593916002 }
1594016003 }
1594116004 } else {
@@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32(
1594816011
1594916012 // dim
1595016013 for (int i1 = 0; i1 < nr; ++i1) {
15951- const int i = i1 + h*nr;
15952- const float x_dt = x[i ] * dt_soft_plus;
16014+ const int ii = i1 + h*nr;
16015+ const float x_dt = x[ii ] * dt_soft_plus;
1595316016 float sumf = 0.0f;
16017+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
16018+ // and also because expf is used within the loop.
1595416019 // d_state
1595516020 for (int i0 = 0; i0 < nc; ++i0) {
15956- const int ii = i0 + i *nc;
16021+ const int i = i0 + ii *nc;
1595716022 const int ig = i0 + (h & (ng - 1))*nc;
1595816023 // state = prev_state * dA + dB * x
15959- const float state = (s0[ii ] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
16024+ const float state = (s0[i ] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
1596016025 // y = rowwise_dotprod(state, C)
1596116026 sumf += state * C[ig];
15962- s[ii ] = state;
16027+ s[i ] = state;
1596316028 }
15964- y[i ] = sumf + x[i ] * D[h];
16029+ y[ii ] = sumf + x[ii ] * D[h];
1596516030 }
1596616031 }
1596716032 }
0 commit comments