 #include <stdexcept>
 #include <cinttypes>
 
+//
+// helpers
+//
+
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
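Aside (reviewer's illustration, not part of the patch): a minimal standalone sketch of how this T5-style bucketing behaves. It restates only the unidirectional branch of the helper above, with the log-binning applied lazily; the harness, its sample distances, and the bucket count of 32 are assumptions made for the example.

// hypothetical illustration only, mirrors the unidirectional branch of
// llama_relative_position_bucket above; builds and runs outside llama.cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

static int32_t bucket_unidirectional(int32_t x, int32_t y, int64_t n_buckets) {
    const int64_t max_distance = 128;
    const int64_t max_exact    = n_buckets >> 1;

    // only "looking back" (x <= y) yields a non-zero distance
    const int32_t relative_position = -std::min<int32_t>(x - y, 0);

    int32_t relative_bucket = relative_position;
    if (relative_position >= max_exact) {
        const int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
        relative_bucket = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    }

    return relative_bucket;
}

int main() {
    // with 32 buckets: distances 0..15 map to buckets 0..15 exactly,
    // larger distances are binned logarithmically and clamp to bucket 31
    const int dists[] = {0, 1, 15, 16, 32, 64, 128, 1000};
    for (int d : dists) {
        printf("distance %4d -> bucket %2d\n", d, bucket_unidirectional(0, d, 32));
    }
    return 0;
}

The bidirectional (encoder) case halves the bucket range first and uses the upper half for positive offsets, as in the if (bidirectional) branch of the helper.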
+
 //
 // llama_context
 //
 
 llama_context::llama_context(
         const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
         llama_graph_type gtype) :
     llama_graph_i(gtype),
     model(model) {
-    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
+    LLAMA_LOG_INFO("%s: constructing llama_context, gtype = %d\n", __func__, gtype);
 
     t_start_us = model.t_start_us;
     t_load_us  = model.t_load_us;
 
+    switch (gtype) {
+        case LLAMA_GRAPH_TYPE_DEFAULT:
+        case LLAMA_GRAPH_TYPE_DECODER:
+            {
+            } break;
+        case LLAMA_GRAPH_TYPE_ENCODER:
+            {
+                params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;
+                params.embeddings     = true;
+            } break;
+    }
+
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
@@ -45,20 +88,6 @@ llama_context::llama_context(
     cparams.rope_freq_base  = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
-    // with causal attention, the batch size is limited by the context size
-    cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
-
-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
-
-    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
-
     cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
                               hparams.n_ctx_train;
@@ -95,13 +124,28 @@ llama_context::llama_context(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
+    // with causal attention, the batch size is limited by the context size
+    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+
+    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+        cparams.n_batch = GGML_KQ_MASK_PAD;
+    }
+
+    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
     LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
@@ -1207,6 +1251,23 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
         }
     }
 
+    if (inp.pos_bucket) {
+        const int64_t n_tokens = ubatch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp.pos_bucket->buffer));
+        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+
+        int32_t * data = (int32_t *) inp.pos_bucket->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_tokens; ++i) {
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, true);
+                }
+            }
+        }
+    }
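Reading aid (my gloss on the loop above, not text from the patch): the buffer is filled with the key index i varying fastest, so the bucket for query token j and key token i lives at data[j*n_tokens + i], and this encoder-side path passes bidirectional = true. A hypothetical accessor that mirrors that layout:

// hypothetical helper, only illustrates the indexing convention used above
static int32_t pos_bucket_at(const int32_t * data, int64_t n_tokens, int64_t i_key, int64_t j_query) {
    // single h == 0 plane, row-major with the key index varying fastest
    return data[j_query*n_tokens + i_key];
}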
+
     GGML_ASSERT(
             // (!a || b) is a logical implication (a -> b)
             // !hparams.causal_attn -> !cparams.causal_attn
@@ -1604,6 +1665,15 @@ ggml_tensor * llama_context::build_inp_pos(
     return inp.pos;
 }
 
+ggml_tensor * llama_context::build_inp_pos_bucket(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    inp.pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
+    ggml_set_input(inp.pos_bucket);
+
+    return inp.pos_bucket;
+}
+
 ggml_tensor * llama_context::build_inp_out_ids(
         ggml_context * ctx0) {
     const int32_t n_out_ids = n_outputs;
@@ -1656,6 +1726,7 @@ ggml_tensor * llama_context::build_attn(
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
         int32_t n_tokens,
         float kq_scale,
         int il) {
@@ -1690,6 +1761,8 @@ ggml_tensor * llama_context::build_attn(
     GGML_UNUSED(model);
     GGML_UNUSED(n_ctx);
 
+    GGML_ASSERT(kq_b == nullptr);
+
     struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3));
     v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv);
 
@@ -1720,10 +1793,14 @@ ggml_tensor * llama_context::build_attn(
 
     if (hparams.attn_soft_cap) {
         kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-        kq = ggml_tanh(ctx0, kq);
+        kq = ggml_tanh (ctx0, kq);
         kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
     }
 
+    if (kq_b) {
+        kq = ggml_add(ctx0, kq, kq_b);
+    }
+
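Context note (my reading, not part of the patch): kq_b is an optional additive bias on the raw attention scores, for example a T5-style relative position bias, folded into kq before the masked softmax; this cache-less path asserts it is null above, so the ggml_add branch is scaffolding for graphs that do supply one. A small numeric sketch of the resulting math for one query row, assuming kq_scale plays the usual 1/sqrt(n_embd_head) role and ignoring the KQ mask:

// illustration of softmax((kq + kq_b) * kq_scale) for a single query row;
// the values and the 4-key row length are made up for the example
#include <cmath>
#include <cstdio>

int main() {
    const float kq[4]   = { 0.2f, 1.0f, -0.5f, 0.3f}; // raw Q.K scores
    const float kq_b[4] = { 0.0f, -1.0f, 0.5f, 0.0f}; // additive bias, e.g. relative position bias
    const float scale   = 0.125f;                     // 1/sqrt(64), assuming 64-dim heads

    float p[4];
    float sum = 0.0f;
    for (int i = 0; i < 4; ++i) {
        p[i] = expf((kq[i] + kq_b[i])*scale); // the scale is applied to the already-biased scores
        sum += p[i];
    }
    for (int i = 0; i < 4; ++i) {
        printf("p[%d] = %.4f\n", i, p[i]/sum);
    }
    return 0;
}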
     kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
     // cb(kq, "kq_soft_max_ext", il);
 
@@ -2281,7 +2358,7 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_
 
 llama_context_kv_self::llama_context_kv_self(
         const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
         llama_graph_type gtype) :
     llama_context(model, params, gtype),
     kv_self(model.hparams) {
@@ -3053,53 +3130,19 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
         }
     }
 
-    if (inp_pos_bucket) {
+    if (inp.self_pos_bucket) {
         const int64_t n_tokens = ubatch.n_tokens;
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_pos_bucket->buffer));
         GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
-        static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-            // TODO move to hparams if a T5 variant appears that uses a different value
-            const int64_t max_distance = 128;
-
-            if (bidirectional) {
-                n_buckets >>= 1;
-            }
+        int32_t * data = (int32_t *) inp.self_pos_bucket->data;
 
-            const int64_t max_exact = n_buckets >> 1;
-
-            int32_t relative_position = x - y;
-            int32_t relative_bucket = 0;
-            if (bidirectional) {
-                relative_bucket += (relative_position > 0) * n_buckets;
-                relative_position = abs(relative_position);
-            } else {
-                relative_position = -std::min<int32_t>(relative_position, 0);
-            }
-            int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-            relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-            relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-            return relative_bucket;
-        };
-
-        int32_t * data = (int32_t *) inp_pos_bucket->data;
-
-        if (!is_encoding) {
-            const int64_t n_kv = kv_self.n;
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
-                    }
-                }
-            }
-        } else {
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
-                    }
+        const int64_t n_kv = kv_self.n;
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false);
                 }
             }
         }
@@ -3146,7 +3189,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
 
 ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_embd_enc      = nullptr;
-    inp_pos_bucket    = nullptr;
     inp_kq_mask_cross = nullptr;
 
     inp = {};
@@ -3161,6 +3203,17 @@ ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0)
     return inp.self_k_shift;
 }
 
+ggml_tensor * llama_context_kv_self::build_inp_pos_bucket(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    const auto n_kv = kv_self.n;
+
+    inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+    ggml_set_input(inp.self_pos_bucket);
+
+    return inp.self_pos_bucket;
+}
+
 void llama_context_kv_self::build_attn_inp(
         ggml_context * ctx0,
         int32_t n_tokens,
@@ -3199,6 +3252,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
        int32_t n_tokens,
         float kq_scale,
         int il) {
@@ -3293,6 +3347,8 @@ ggml_tensor * llama_context_kv_self::build_attn(
     GGML_UNUSED(model);
     GGML_UNUSED(n_ctx);
 
+    GGML_ASSERT(kq_b == nullptr);
+
     // split cached v into n_head heads (not transposed)
     struct ggml_tensor * v =
             ggml_view_3d(ctx0, kv_self.v_l[il],
@@ -3329,10 +3385,14 @@
 
     if (hparams.attn_soft_cap) {
         kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-        kq = ggml_tanh(ctx0, kq);
+        kq = ggml_tanh (ctx0, kq);
         kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
     }
 
+    if (kq_b) {
+        kq = ggml_add(ctx0, kq, kq_b);
+    }
+
     kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
     // cb(kq, "kq_soft_max_ext", il);
 
@@ -3753,7 +3813,7 @@ size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq
 
 llama_context_recurrent::llama_context_recurrent(
         const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
         llama_graph_type gtype) :
     llama_context(model, params, gtype),
     kv_self(model.hparams) {
@@ -4629,7 +4689,7 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s
 
 llama_context_enc_dec::llama_context_enc_dec(
         const llama_model & model,
-        const llama_context_params & params) :
+        llama_context_params params) :
     llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER),
     ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) {
     LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__);