@@ -313,7 +313,7 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_base::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     if (kq_mask) {
         if (cparams.causal_attn) {
             const int64_t n_kv = ubatch->n_tokens;
@@ -400,7 +400,7 @@ void llm_graph_input_attn_base::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn) {
@@ -523,9 +523,7 @@ void llm_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_dec::set_input(const llama_ubatch * ubatch) {
-    inp_kv_self->set_input(ubatch);
-
+void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     if (cross_kq_mask) {
         const int64_t n_enc = cross_kq_mask->ne[0];
         const int64_t n_tokens = ubatch->n_tokens;
@@ -1066,7 +1064,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     auto & cur = inp->s_copy;
 
     cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    // cb(cur, "inp_s_copy", -1);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1084,7 +1081,6 @@ ggml_tensor * llm_graph_context::build_inp_s_mask() const {
     auto & cur = inp->s_mask;
 
     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    // cb(cur, "inp_s_mask", -1);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1151,15 +1147,11 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     cb(pos_bucket_1d, "pos_bucket_1d", -1);
 
     ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
-    cb(pos_bias, "pos_bias", -1);
 
     pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
-    cb(pos_bias, "pos_bias", -1);
+    pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
+    pos_bias = ggml_cont(ctx0, pos_bias);
 
-    pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
-    cb(pos_bias, "pos_bias", -1);
-
-    pos_bias = ggml_cont(ctx0, pos_bias);
     cb(pos_bias, "pos_bias", -1);
 
     return pos_bias;
@@ -1257,26 +1249,21 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     return cur;
 }
 
-llm_graph_input_attn_base * llm_graph_context::build_attn_inp_base(
-        bool causal,
-        bool swa) const {
-    auto inp = std::make_unique<llm_graph_input_attn_base>(hparams, cparams);
+llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
+    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    GGML_UNUSED(causal);
-    GGML_UNUSED(swa);
-
     inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     // cb(inp_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->kq_mask);
 
     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
 
-    return (llm_graph_input_attn_base *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_base * inp,
+        llm_graph_input_attn_no_cache * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
@@ -1324,12 +1311,12 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_self * llm_graph_context::build_attn_inp_kv_self(
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
         bool causal,
         bool swa) const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_self>(hparams, cparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     const auto n_kv = kv_self->n;
 
@@ -1353,11 +1340,11 @@ llm_graph_input_attn_kv_self * llm_graph_context::build_attn_inp_kv_self(
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_self *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_self * inp,
+        llm_graph_input_attn_kv_unified * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
@@ -1490,12 +1477,8 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_dec * llm_graph_context::build_attn_inp_dec(
-        bool causal,
-        bool swa) const {
-    auto * inp_kv_self = build_attn_inp_kv_self(causal, swa);
-
-    auto inp = std::make_unique<llm_graph_input_attn_dec>(inp_kv_self, cross);
+llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
+    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
 
     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
@@ -1504,11 +1487,11 @@ llm_graph_input_attn_dec * llm_graph_context::build_attn_inp_dec(
 
     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
 
-    return (llm_graph_input_attn_dec *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_dec * inp,
+        llm_graph_input_attn_cross * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
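The hunks above all follow the same graph-input pattern: a build_attn_inp_* helper creates a typed input object, hands ownership to the graph result via res->add_input(std::move(inp)), and returns the raw pointer cast back to the concrete type so the matching build_attn() overload can be called per layer; later, each registered input fills its tensors for the current batch through set_input(ubatch). Below is a minimal standalone sketch of that ownership/callback pattern only; it is not llama.cpp code, and every name in it is invented for illustration.

#include <cstdio>
#include <memory>
#include <vector>

struct ubatch { int n_tokens; };

// common interface: every registered graph input knows how to fill itself per batch
struct graph_input_i {
    virtual ~graph_input_i() = default;
    virtual void set_input(const ubatch * ub) = 0;
};

// stand-in for an attention-mask input with no KV cache (mask is n_tokens x n_tokens)
struct input_attn_no_cache : graph_input_i {
    void set_input(const ubatch * ub) override {
        std::printf("fill %d x %d KQ mask (no KV cache)\n", ub->n_tokens, ub->n_tokens);
    }
};

// stand-in for the graph result: owns the inputs, hands back a raw pointer
struct graph_result {
    std::vector<std::unique_ptr<graph_input_i>> inputs;

    graph_input_i * add_input(std::unique_ptr<graph_input_i> inp) {
        inputs.push_back(std::move(inp));
        return inputs.back().get();
    }

    void set_inputs(const ubatch * ub) {
        for (auto & inp : inputs) {
            inp->set_input(ub);
        }
    }
};

int main() {
    graph_result res;

    // build time: create the typed input, transfer ownership, keep the concrete pointer
    auto * inp = static_cast<input_attn_no_cache *>(res.add_input(std::make_unique<input_attn_no_cache>()));
    (void) inp; // the real builders would stash their mask tensors on it here

    // eval time: every registered input fills its tensors for the current batch
    ubatch ub = { 8 };
    res.set_inputs(&ub);

    return 0;
}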