@@ -7807,8 +7807,6 @@ static int llama_decode_impl(
     uint32_t n_outputs = 0;
     uint32_t n_outputs_prev = 0;

-    const auto n_ubatch = cparams.n_ubatch;
-
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

@@ -7832,27 +7830,19 @@ static int llama_decode_impl(
         return -2;
     };

-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+    const bool logits_all = n_outputs == n_tokens_all;
+
+    // auto & kv_self = lctx.kv_self;
+    // llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    // lctx.sbatch.from_batch(batch, n_embd,
+    //         /* simple_split */ !kv_self.recurrent,
+    //         /* logits_all   */ logits_all);

-    lctx.sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self.recurrent,
-            /* logits_all   */ n_outputs == n_tokens_all);
+    auto batch_manager = lctx.prepare_batch(batch, logits_all);

     while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
+        llama_ubatch ubatch = batch_manager->next();

         const uint32_t n_tokens = ubatch.n_tokens;

@@ -7873,32 +7863,10 @@ static int llama_decode_impl(
             lctx.n_outputs = n_outputs_new;
         }

-        lctx.prepare_decode(ubatch);
-
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_self_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            const auto slot_info = kv_self.find_slot(ubatch);
-            if (!slot_info) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot_info);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self.get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
-                // kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
+        if (!batch_manager->prepare()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+            batch_manager->restore();
+            return -3;
         }

         // reserve a worst case graph if needed
@@ -7963,7 +7931,7 @@ static int llama_decode_impl(

         const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            kv_slot_restorer.restore(kv_self);
+            batch_manager->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -7975,15 +7943,7 @@ static int llama_decode_impl(
             }
         }

-        // update the kv ring buffer
-        {
-            kv_self.head += n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self.head >= kv_self.size) {
-                kv_self.head = 0;
-            }
-        }
+        batch_manager->update();

         // plot the computation graph in dot format (for debugging purposes)
         // if (n_past%100 == 0) {
@@ -8061,6 +8021,7 @@ static int llama_decode_impl(
                     }
             }
         }
+
         n_outputs_prev += lctx.n_outputs;
     }

@@ -8089,17 +8050,7 @@ static int llama_decode_impl(
     // wait for the computation to finish (automatically done when obtaining the model output)
     // llama_synchronize(&lctx);

-    // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
-
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            // LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
-
-            kv_self.defrag();
-        }
-    }
+    batch_manager->finalize();

     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
@@ -8178,7 +8129,7 @@ static int llama_encode_impl(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;

-    lctx.prepare_decode(ubatch);
+    // batch_manager->prepare(ubatch);

     // reserve a worst case graph if needed
     // TODO: extract to a function
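
For context, here is a minimal sketch of the batch-manager interface implied by the calls in this diff (next, prepare, restore, update, finalize, and lctx.prepare_batch). The type name llama_batch_manager_i, the use of virtual methods, and the exact signatures are assumptions for illustration; the per-method comments only restate the KV-cache bookkeeping that each call replaces in the removed code above.

// Sketch only: llama_ubatch and the context come from llama.cpp internals.
// The real declaration in this branch may differ in name, ownership, and signatures.
struct llama_batch_manager_i {
    virtual ~llama_batch_manager_i() = default;

    // split the next ubatch off lctx.sbatch
    // (replaces the removed split_seq/split_equal/split_simple branching)
    virtual llama_ubatch next() = 0;

    // per-ubatch KV cache preparation that used to live in llama_decode_impl:
    // llama_kv_self_update(), kv_self.find_slot(), saving slot-restore info,
    // clamping kv_self.n; returns false if no slot could be found
    virtual bool prepare() = 0;

    // roll back the KV cache changes made by prepare() when graph computation fails
    virtual void restore() = 0;

    // advance the KV cache ring buffer head past the tokens that were just decoded
    virtual void update() = 0;

    // end-of-decode work, e.g. deciding whether to queue KV cache defragmentation
    virtual void finalize() = 0;
};

// the context hands one out per decode call:
//
//   auto batch_manager = lctx.prepare_batch(batch, logits_all);

After this change, llama_decode_impl no longer touches kv_self directly: the slot search, the head/ring-buffer update, and the defrag decision all move behind this object, which is why the failure path can simply call batch_manager->restore() instead of threading a llama_kv_slot_restorer through the function.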