@@ -76,13 +76,13 @@ llama_context::llama_context(
     }
 
     if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
+        cparams.attn_type = hparams.causal_attn ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL;
     } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+        cparams.attn_type = params.attention_type;
     }
 
     // with causal attention, the batch size is limited by the context size
-    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+    cparams.n_batch = cparams.use_past_tokens() ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
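The `cparams.use_past_tokens()` helper used here is not defined in this excerpt; presumably it lives alongside `llama_cparams` and reports whether the configured attention type attends to cached past tokens. A minimal sketch, assuming it simply maps the causal case of the enum to a bool:

    // sketch only, not part of this diff: assumes only causal attention
    // needs tokens from previous positions in the KV cache
    bool llama_cparams::use_past_tokens() const {
        return attn_type == LLAMA_ATTENTION_TYPE_CAUSAL;
    }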
@@ -102,7 +102,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
+    LLAMA_LOG_INFO("%s: attn_type     = %d\n",   __func__, cparams.attn_type);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
@@ -966,10 +966,10 @@ void llama_context::set_embeddings(bool value) {
     cparams.embeddings = value;
 }
 
-void llama_context::set_causal_attn(bool value) {
+void llama_context::set_attn_type(enum llama_attention_type value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
-    cparams.causal_attn = value;
+    cparams.attn_type = value;
 }
 
 void llama_context::set_warmup(bool value) {
@@ -1074,12 +1074,12 @@ int llama_context::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    const auto causal_attn_org = cparams.causal_attn;
+    const auto attn_type_org = cparams.attn_type;
 
     // always use non-causal attention for encoder graphs
     // TODO: this is a tmp solution until we have a proper way to support enc-dec models
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
-    cparams.causal_attn = false;
+    cparams.attn_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;
 
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
@@ -1088,7 +1088,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
-    cparams.causal_attn = causal_attn_org;
+    cparams.attn_type = attn_type_org;
 
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
@@ -1242,7 +1242,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+    GGML_ASSERT((!cparams.use_past_tokens() || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
     if (t_compute_start_us == 0) {
         t_compute_start_us = ggml_time_us();
@@ -1495,7 +1495,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     // synchronize();
 
     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+    if (cparams.use_past_tokens() && cparams.defrag_thold > 0.0f) {
         // - do not defrag small contexts (i.e. < 2048 tokens)
        // - count the padding towards the number of used tokens
         const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
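For illustration (the numbers are hypothetical): with kv_self->n = 4096 cells, kv_self->used = 1024 and a padding of 32, the estimate is 1.0f - (1024 + 32)/4096 ≈ 0.74, which exceeds a threshold such as defrag_thold = 0.1, so a defrag would be scheduled; any cache smaller than 2048 cells reports 0.0f and is never defragmented.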
@@ -2410,8 +2410,12 @@ void llama_set_embeddings(llama_context * ctx, bool embeddings) {
     ctx->set_embeddings(embeddings);
 }
 
+void llama_set_attn_type(llama_context * ctx, llama_attention_type type) {
+    ctx->set_attn_type(type);
+}
+
 void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
-    ctx->set_causal_attn(causal_attn);
+    ctx->set_attn_type(causal_attn ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL);
 }
 
 void llama_set_warmup(llama_context * ctx, bool warmup) {
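Assuming the PR also declares `llama_set_attn_type` in `llama.h` (the header change is not shown in this excerpt), a caller could switch attention modes on an existing context without going through the old bool setter. A minimal sketch with placeholder names:

    // sketch: run an embedding pass with non-causal attention,
    // then restore causal attention for normal decoding
    void run_embedding_pass(llama_context * ctx, llama_batch batch) {
        llama_set_attn_type(ctx, LLAMA_ATTENTION_TYPE_NON_CAUSAL);
        llama_encode(ctx, batch); // or llama_decode(), depending on the model

        llama_set_attn_type(ctx, LLAMA_ATTENTION_TYPE_CAUSAL);
        // the existing llama_set_causal_attn(ctx, true) keeps working and
        // now forwards to set_attn_type(), as shown above
    }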