3 files changed: +10 −1 lines changed

@@ -64,6 +64,10 @@ llama_pos llama_context::pos_max() const {
     return kv_self.pos_max();
 }
 
+uint32_t llama_context::get_ctx_padding(const llama_cparams & cparams) const {
+    return kv_self.get_padding(cparams);
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;

@@ -84,8 +84,11 @@ struct llama_context {
             ggml_cgraph * graph,
             bool batched);
 
+    // max token position across all sequences in the current context
     llama_pos pos_max() const;
 
+    uint32_t get_ctx_padding(const llama_cparams & cparams) const;
+
     void reset();
 
     void prepare_k_shift();

@@ -7820,6 +7820,7 @@ static int llama_decode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -8154,6 +8155,7 @@ static int llama_encode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
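
The TODO added in both llama_decode_impl and llama_encode_impl flags the same problem: when the input batch carries no explicit positions, every token is assigned lctx.pos_max() + 1, which is the maximum position across all sequences rather than the maximum of the token's own sequence. A minimal standalone sketch of the mismatch (not part of this diff; the seq_pos_max map and its values are hypothetical) is:

#include <algorithm>
#include <cstdio>
#include <map>

int main() {
    // hypothetical per-sequence max positions already stored in the KV cache
    std::map<int, int> seq_pos_max = { {0, 9}, {1, 3} };

    // global maximum, which is what pos_max() currently reports
    int global_max = 0;
    for (const auto & kv : seq_pos_max) {
        global_max = std::max(global_max, kv.second);
    }

    // a positionless batch appends one token to each sequence
    for (const auto & kv : seq_pos_max) {
        const int assigned = global_max + 1;  // what the current fallback produces
        const int expected = kv.second + 1;   // what this sequence actually needs
        std::printf("seq %d: assigned pos %d, expected pos %d\n", kv.first, assigned, expected);
    }
    return 0;
}

Here sequence 1 would be assigned position 10 instead of 4, which is why the fallback would need per-sequence position tracking to be correct.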
@@ -8660,7 +8662,7 @@ struct llama_context * llama_init_from_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->kv_self.get_padding(cparams));
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
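
For context on the line being refactored: GGML_PAD rounds cparams.n_ctx up to a multiple of the KV-cache padding, so the new ctx->get_ctx_padding(cparams) accessor only changes where the padding value is obtained (through the context rather than by reaching into kv_self), not the arithmetic. A minimal sketch of that rounding, using a stand-in pad_up helper and a hypothetical padding value of 256 (the real value comes from kv_self.get_padding(cparams)):

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// stand-in for GGML_PAD: round x up to the next multiple of n
static uint32_t pad_up(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const uint32_t padding = 256;  // hypothetical KV-cache padding value
    for (uint32_t n_ctx : { 2048u, 2049u, 4000u }) {
        std::printf("n_ctx %u -> padded n_ctx %u\n", n_ctx, pad_up(n_ctx, padding));
    }
    return 0;
}

Padding n_ctx up front keeps it consistent with kv_self.n, which, per the comment in the diff above, is padded the same way later during inference.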