
Commit e3c2913
Commit message: fixes
1 parent: d7de15a

2 files changed: +25, -24 lines

llama_cpp/llama.py

Lines changed: 6 additions & 5 deletions
@@ -872,7 +872,7 @@ def generate(
     penalize_nl=penalize_nl,
     idx=sample_idx,
 )
-
+
 sample_idx += 1
 if stopping_criteria is not None and stopping_criteria(
     self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :]
@@ -982,7 +982,7 @@ def embed(
 data: Union[List[List[float]], List[List[List[float]]]] = []

 def decode_batch(seq_sizes: List[int]):
-    llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+    self._ctx.kv_cache_clear()
     self._ctx.decode(self._batch)
     self._batch.reset()

@@ -1053,7 +1053,7 @@ def decode_batch(seq_sizes: List[int]):

 output = data[0] if isinstance(input, str) else data

-llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+self._ctx.kv_cache_clear()
 self.reset()

 if return_count:
@@ -1350,7 +1350,7 @@ def logit_bias_processor(
     text = all_text[: all_text.index(first_stop)]
     finish_reason = "stop"
     break
-
+
 if stream:
     remaining_tokens = completion_tokens[returned_tokens:]
     remaining_text = self.detokenize(
@@ -2435,6 +2435,7 @@ def _create_context(
     yarn_beta_slow if yarn_beta_slow != 0.0 else 0
 )
 self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
+
 self.context_params.logits_all = (
     logits_all if self.draft_model is None else True
 )  # Must be set to True for speculative decoding
@@ -2479,7 +2480,7 @@ def _create_context(

 self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
 self.scores: npt.NDArray[np.single] = np.ndarray(
-    (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
+    (n_ctx if logits_all else n_batch, self._n_vocab), dtype=np.single
 )

 self._batch = self._stack.enter_context(
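
The two embed()-related hunks in llama.py make the same substitution: instead of reaching into the raw ctypes bindings with llama_cpp.llama_kv_cache_clear(self._ctx.ctx), the code now asks the internal context wrapper to clear the KV cache via self._ctx.kv_cache_clear(). A minimal sketch of that wrapper pattern, using a stand-in class that owns the raw context pointer (the class name and the exact binding it forwards to are assumptions, not taken from this diff):

import llama_cpp  # low-level ctypes bindings


class ContextHandle:
    """Illustrative stand-in for the library's internal context wrapper."""

    def __init__(self, ctx) -> None:
        self.ctx = ctx  # raw llama_context pointer obtained from the C API

    def kv_cache_clear(self) -> None:
        # Assumed to forward to the bindings' KV-cache clear call; the old code
        # path invoked llama_cpp.llama_kv_cache_clear(self._ctx.ctx) directly.
        llama_cpp.llama_kv_cache_clear(self.ctx)

Routing the call through the wrapper keeps raw pointer handling in one place, so if the underlying C function is ever renamed only the wrapper method has to change, not every call site in embed().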

llama_cpp/llama_cpp.py

Lines changed: 19 additions & 19 deletions
@@ -805,8 +805,7 @@ class llama_model_params(ctypes.Structure):
 # // ref: https://github.com/ggml-org/llama.cpp/pull/14363
 # };
 class llama_context_params(ctypes.Structure):
-    """Parameters for llama_context. NOTE: changing the default values of parameters marked as [EXPERIMENTAL]
-    may cause crashes or incorrect results in certain configurations.
+    """Parameters for llama_context_params, matching the C struct for context creation.

     Attributes:
         n_ctx (int): text context, 0 = from model
@@ -815,27 +814,27 @@ class llama_context_params(ctypes.Structure):
         n_seq_max (int): max number of sequences (i.e. distinct states for recurrent models)
         n_threads (int): number of threads to use for generation
         n_threads_batch (int): number of threads to use for batch processing
-        rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
+        rope_scaling_type (int): RoPE scaling type, from enum llama_rope_scaling_type
         pooling_type (int): whether to pool (sum) embedding results by sequence id
         attention_type (int): attention type to use for embeddings
+        flash_attn_type (int): when to enable Flash Attention
         rope_freq_base (float): RoPE base frequency, 0 = from model
         rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
         yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
         yarn_attn_factor (float): YaRN magnitude scaling factor
         yarn_beta_fast (float): YaRN low correction dim
         yarn_beta_slow (float): YaRN high correction dim
         yarn_orig_ctx (int): YaRN original context size
-        defrag_thold (float): defragment the KV cache if holes/size > thold, <= 0 disabled (default)
+        defrag_thold (float): [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default)
         cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval
-        cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval
-        type_k (int): data type for K cache
-        type_v (int): data type for V cache
-        abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
-        abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
+        cb_eval_user_data (ctypes.c_void_p): user data for cb_eval
+        type_k (int): data type for K cache [EXPERIMENTAL]
+        type_v (int): data type for V cache [EXPERIMENTAL]
+        abort_callback (ggml_abort_callback): abort callback for llama_decode
+        abort_callback_data (ctypes.c_void_p): user data for abort_callback
         embeddings (bool): if true, extract embeddings (together with logits)
-        offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
-        flash_attn (bool): whether to use flash attention
-        no_perf (bool): whether to measure performance timings
+        offload_kqv (bool): offload the KQV ops (including the KV cache) to GPU
+        no_perf (bool): measure performance timings
         op_offload (bool): offload host tensor operations to device
         swa_full (bool): use full-size SWA cache
         kv_unified (bool): use a unified buffer across the input sequences when computing the attention
@@ -851,6 +850,7 @@ class llama_context_params(ctypes.Structure):
     rope_scaling_type: int
     pooling_type: int
     attention_type: int
+    flash_attn_type: int
     rope_freq_base: float
     rope_freq_scale: float
     yarn_ext_factor: float
@@ -867,7 +867,6 @@ class llama_context_params(ctypes.Structure):
     abort_callback_data: ctypes.c_void_p
     embeddings: bool
     offload_kqv: bool
-    flash_attn: bool
     no_perf: bool
     op_offload: bool
     swa_full: bool
@@ -880,9 +879,10 @@ class llama_context_params(ctypes.Structure):
         ("n_seq_max", ctypes.c_uint32),
         ("n_threads", ctypes.c_int32),
         ("n_threads_batch", ctypes.c_int32),
-        ("rope_scaling_type", ctypes.c_int),
-        ("pooling_type", ctypes.c_int),
-        ("attention_type", ctypes.c_int),
+        ("rope_scaling_type", ctypes.c_int),  # enum llama_rope_scaling_type
+        ("pooling_type", ctypes.c_int),  # enum llama_pooling_type
+        ("attention_type", ctypes.c_int),  # enum llama_attention_type
+        ("flash_attn_type", ctypes.c_int),  # enum llama_flash_attn_type
         ("rope_freq_base", ctypes.c_float),
         ("rope_freq_scale", ctypes.c_float),
         ("yarn_ext_factor", ctypes.c_float),
@@ -893,13 +893,13 @@ class llama_context_params(ctypes.Structure):
         ("defrag_thold", ctypes.c_float),
         ("cb_eval", ggml_backend_sched_eval_callback),
         ("cb_eval_user_data", ctypes.c_void_p),
-        ("type_k", ctypes.c_int),
-        ("type_v", ctypes.c_int),
+        ("type_k", ctypes.c_int),  # enum ggml_type
+        ("type_v", ctypes.c_int),  # enum ggml_type
         ("abort_callback", ggml_abort_callback),
         ("abort_callback_data", ctypes.c_void_p),
+        # Booleans at the end for alignment
         ("embeddings", ctypes.c_bool),
         ("offload_kqv", ctypes.c_bool),
-        ("flash_attn", ctypes.c_bool),
         ("no_perf", ctypes.c_bool),
         ("op_offload", ctypes.c_bool),
         ("swa_full", ctypes.c_bool),
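
On the bindings side, the boolean flash_attn field is removed from llama_context_params and flash_attn_type (an int backed by enum llama_flash_attn_type) is added next to attention_type, so callers that fill in the struct by hand now set an enum value instead of a flag. A minimal sketch of the adjustment, assuming the binding follows llama.cpp's auto / disabled / enabled convention of -1 / 0 / 1; the constant name used below is an assumption, not something this diff defines:

import llama_cpp

# Start from the library's default context parameters.
params = llama_cpp.llama_context_default_params()

# Before this commit the switch was a plain bool:
#     params.flash_attn = True        # field no longer exists after this change
#
# After this commit the choice is an enum-valued int. llama.cpp's convention is
# auto = -1, disabled = 0, enabled = 1; the named constant is assumed here, so
# fall back to the raw integer if the binding does not export it.
params.flash_attn_type = getattr(llama_cpp, "LLAMA_FLASH_ATTN_TYPE_ENABLED", 1)

# params is then passed to the usual context-creation call, unchanged otherwise.

Since the _fields_ order also changed (flash_attn_type inserted after attention_type and the flash_attn bool dropped from the trailing booleans), any code that constructs llama_context_params positionally rather than by field name needs to follow the updated layout shown in the last two hunks.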
