@@ -1721,50 +1721,67 @@ void llama_context::build_attn_inp(
 ggml_tensor * llama_context::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
-    const auto & hparams = model.hparams;
+    GGML_UNUSED(il);
 
-    const auto & n_ctx = cparams.n_ctx;
+    const auto & kq_mask = inp.kq_mask_cnv;
 
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);
 
-    const auto & kq_mask = inp.kq_mask_cnv;
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+    //cb(k, "k", il);
 
-    const int64_t n_head    = hparams.n_head(il);
-    const int64_t n_head_kv = hparams.n_head_kv(il);
+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(k, "v", il);
 
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
+    ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, false, kq_scale);
 
-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    const auto n_kv = n_tokens;
+    return cur;
+}
 
-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
+ggml_tensor * llama_context::build_attn_mha(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * kq_b,
+        ggml_tensor * kq_mask,
+        bool v_trans,
+        float kq_scale) {
+    const auto & hparams = model.hparams;
 
-    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, k_cur, 0, 2, 1, 3));
-    //cb(k, "k", il);
+    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    //const int64_t n_head    = hparams.n_head(il);
+    //const int64_t n_head_kv = hparams.n_head_kv(il);
+
+    //const auto & n_embd_head_k = hparams.n_embd_head_k;
+    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+
+    const auto n_tokens = q->ne[1];
+    const auto n_head   = q->ne[2];
+    const auto n_kv     = k->ne[1];
 
     struct ggml_tensor * cur;
 
-    //if (cparams.flash_attn) {
-    if (false) { // TODO: need to pad the batch size to a multiple of GGML_KQ_MASK_PAD
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
         GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
 
-        GGML_ASSERT(kq_b == nullptr);
+        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3));
-        v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv);
+        if (v_trans) {
+            v = ggml_transpose(ctx0, v);
+        }
 
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
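Note on the new helper: build_attn_mha no longer receives head counts or token counts from hparams; it recovers them from the tensor shapes, and the flash path is gated on n_kv being a multiple of 256 (replacing the earlier TODO about padding the mask). A comment-only sketch of the layouts it assumes, using standard ggml ne ordering (fastest dimension first); the assert at the end is illustrative, not part of the commit:

    // layouts expected by build_attn_mha:
    //   q: [n_embd_head,   n_tokens, n_head   ]   (q_cur permuted with 0, 2, 1, 3)
    //   k: [n_embd_head,   n_kv,     n_head_kv]
    //   v: [n_embd_head_v, n_kv,     n_head_kv]   when v_trans == false
    //   v: [n_kv, n_embd_head_v,     n_head_kv]   when v_trans == true (transposed V cache)
    // which is why n_tokens, n_head and n_kv can be read straight from ne[]
    GGML_ASSERT(q->ne[0] == k->ne[0]); // Q and K must share the same head size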
@@ -1774,7 +1791,6 @@ ggml_tensor * llama_context::build_attn(
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);
 
         // note: this op tends to require high floating point range
         // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1802,22 +1818,17 @@ ggml_tensor * llama_context::build_attn(
         }
 
         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens)));
 
-        v = ggml_reshape_3d(ctx0, v, n_kv, n_embd_head_v, n_head_kv);
-        //cb(v, "v", il);
+        if (!v_trans) {
+            // note: avoid this branch
+            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+        }
 
         struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);
 
         struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);
 
         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
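Why the added !v_trans branch pays a ggml_cont(ggml_transpose(...)): ggml_mul_mat contracts the first (fastest) dimension of both operands, so the V operand must already have n_kv in ne[0] before the kqv multiply. A shape sketch, assuming the standard ggml ne ordering:

    // kq  = ggml_mul_mat(k, q)   -> [n_kv, n_tokens, n_head]
    // kqv = ggml_mul_mat(v, kq)  -> requires v->ne[0] == kq->ne[0] == n_kv,
    //       i.e. V in the "transposed" [n_kv, n_embd_head_v, n_head_kv] layout;
    //       the transposed KV cache already stores V this way (v_trans == true),
    //       while the non-cached path has to transpose on the fly, hence the
    //       "note: avoid this branch" comment above.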
@@ -1827,18 +1838,6 @@ ggml_tensor * llama_context::build_attn(
 
     ggml_build_forward_expand(gf, cur);
 
-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
     return cur;
 }
 
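With wo and wo_b dropped from the signature, the attention output projection presumably moves to the graph-building code that calls build_attn. A plausible caller-side sketch, not shown in this diff; Qcur/Kcur/Vcur and the layer tensors follow the usual llama.cpp naming and are illustrative only:

    ggml_tensor * cur = build_attn(ctx0, gf, Qcur, Kcur, Vcur, nullptr, kq_scale, il);

    // output projection, previously applied inside build_attn via wo/wo_b
    cur = build_lora_mm(ctx0, model.layers[il].wo, cur);
    if (model.layers[il].bo) {
        cur = ggml_add(ctx0, cur, model.layers[il].bo);
    }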
@@ -3274,13 +3273,10 @@ void llama_context_kv_self::build_attn_inp(
 ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
     const auto & hparams = model.hparams;
@@ -3290,6 +3286,10 @@ ggml_tensor * llama_context_kv_self::build_attn(
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+    const auto n_tokens = q_cur->ne[2];
+
+    const bool v_trans = !cparams.flash_attn;
+
     // store to KV cache
     {
         GGML_ASSERT(!kv_self.recurrent);
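Small note on the added n_tokens read: q_cur arrives here in the usual [n_embd_head, n_head, n_tokens] layout, so the token count sits in ne[2], whereas build_attn_mha reads ne[1] because it works on the already permuted q. An illustrative comment sketch (not part of the commit):

    // q_cur: [n_embd_head_k, n_head, n_tokens]  ->  n_tokens = q_cur->ne[2]
    // after ggml_permute(ctx0, q_cur, 0, 2, 1, 3):
    // q:     [n_embd_head_k, n_tokens, n_head]  ->  n_tokens = q->ne[1]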
@@ -3308,7 +3308,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
 
         struct ggml_tensor * v_cache_view = nullptr;
 
-        if (cparams.flash_attn) {
+        if (!v_trans) {
             v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
         } else {
             // note: the V cache is transposed when not using flash attention
@@ -3351,117 +3351,35 @@ ggml_tensor * llama_context_kv_self::build_attn(
 
     const auto n_kv = kv_self.n;
 
-    const int64_t n_head    = hparams.n_head(il);
     const int64_t n_head_kv = hparams.n_head_kv(il);
 
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);
 
-    struct ggml_tensor * k =
+    ggml_tensor * k =
         ggml_view_3d(ctx0, kv_self.k_l[il],
                 n_embd_head_k, n_kv, n_head_kv,
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                 0);
     //cb(k, "k", il);
 
-    struct ggml_tensor * cur;
-
-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
-        GGML_ASSERT(kq_b == nullptr);
-
-        // split cached v into n_head heads (not transposed)
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                    n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
-                    0);
-        //cb(v, "v", il);
-
-        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
-    } else {
-        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);
-
-        // note: this op tends to require high floating point range
-        // while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
-        if (model.arch == LLM_ARCH_GROK) {
-            // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
-            // and then :
-            // kq = 30 * tanh(kq / 30)
-            // before the softmax below
-
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx0, kq, 30);
-        }
-
-        if (hparams.attn_soft_cap) {
-            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh(ctx0, kq);
-            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
-        }
-
-        if (kq_b) {
-            kq = ggml_add(ctx0, kq, kq_b);
-        }
-
-        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);
-
-        GGML_ASSERT(kv_self.size == n_ctx);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                    n_kv, n_embd_head_v, n_head_kv,
-                    ggml_element_size(kv_self.v_l[il])*n_ctx,
-                    ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                    0);
-        //cb(v, "v", il);
-
-        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);
-
-        struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);
-
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);
-
-        if (!cparams.offload_kqv) {
-            // all nodes between the KV store and the attention output are run on the CPU
-            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
-        }
-    }
-
-    ggml_build_forward_expand(gf, cur);
-
-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
+    ggml_tensor * v = !v_trans ?
+        ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
+                0) :
+        ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_kv, n_embd_head_v, n_head_kv,
+                ggml_element_size(kv_self.v_l[il])*n_ctx,
+                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                0);
 
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
+    struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
 
     return cur;
 }
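Taken together, both build_attn overrides now reduce to the shared helper. As a wrap-up, a comment-only summary of what build_attn_mha computes per head in the generic (non-flash) branch, using the names from the diff:

    // kq  = K^T Q                              (ggml_mul_mat(k, q), forced to F32 precision)
    // kq  = soft_cap(kq)                       (optional tanh logit soft-capping)
    // kq  = kq + kq_b                          (optional additive bias, skipped on the flash path)
    // p   = softmax(kq_scale * kq + kq_mask)   (ggml_soft_max_ext, ALiBi folded in via f_max_alibi_bias)
    // out = attention weights p applied to V   (ggml_mul_mat(v, kq), V in [n_kv, n_embd_head_v] layout)
    // cur = heads merged back into [n_embd_head_v*n_head, n_tokens]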