Skip to content

Commit a2ede37

Browse files
committed
Load logits directly into scores buffer
1 parent b95b0ff commit a2ede37

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

llama_cpp/llama.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -442,11 +442,8 @@ def eval(self, tokens: Sequence[int]):
442442
self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
443443
# Save logits
444444
rows = n_tokens if self.params.logits_all else 1
445-
n_vocab = self._n_vocab
446-
cols = n_vocab
447-
logits_view = llama_cpp.llama_get_logits(self.ctx)
448-
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
449-
self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits
445+
cols = self._n_vocab
446+
self.scores[self.n_tokens : self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
450447
# Update n_tokens
451448
self.n_tokens += n_tokens
452449

0 commit comments

Comments
 (0)