Skip to content

Commit 7354e16

Browse files
committed
Sync llama: introduce support for model-embedded sampling parameters
1 parent a81b6dc commit 7354e16

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

llama_cpp/llama.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,17 +273,17 @@ def __init__(
273273
if isinstance(v, bool):
274274
self._kv_overrides_array[
275275
i
276-
].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
276+
].tag = llama_cpp.LlamaModelKVOverrideType.LLAMA_KV_OVERRIDE_TYPE_BOOL.value
277277
self._kv_overrides_array[i].value.val_bool = v
278278
elif isinstance(v, int):
279279
self._kv_overrides_array[
280280
i
281-
].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
281+
].tag = llama_cpp.LlamaModelKVOverrideType.LLAMA_KV_OVERRIDE_TYPE_INT.value
282282
self._kv_overrides_array[i].value.val_i64 = v
283283
elif isinstance(v, float):
284284
self._kv_overrides_array[
285285
i
286-
].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
286+
].tag = llama_cpp.LlamaModelKVOverrideType.LLAMA_KV_OVERRIDE_TYPE_FLOAT.value
287287
self._kv_overrides_array[i].value.val_f64 = v
288288
elif isinstance(v, str): # type: ignore
289289
v_bytes = v.encode("utf-8")
@@ -292,7 +292,7 @@ def __init__(
292292
v_bytes = v_bytes.ljust(128, b"\0")
293293
self._kv_overrides_array[
294294
i
295-
].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
295+
].tag = llama_cpp.LlamaModelKVOverrideType.LLAMA_KV_OVERRIDE_TYPE_STR.value
296296
# copy min(v_bytes, 128) to str_value
297297
address = typing.cast(
298298
int,

llama_cpp/llama_cpp.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -584,10 +584,40 @@ class llama_batch(ctypes.Structure):
584584
# LLAMA_KV_OVERRIDE_TYPE_BOOL,
585585
# LLAMA_KV_OVERRIDE_TYPE_STR,
586586
# };
587-
LLAMA_KV_OVERRIDE_TYPE_INT = 0
588-
LLAMA_KV_OVERRIDE_TYPE_FLOAT = 1
589-
LLAMA_KV_OVERRIDE_TYPE_BOOL = 2
590-
LLAMA_KV_OVERRIDE_TYPE_STR = 3
587+
class LlamaModelKVOverrideType(enum.IntEnum):
588+
LLAMA_KV_OVERRIDE_TYPE_INT = 0
589+
LLAMA_KV_OVERRIDE_TYPE_FLOAT = 1
590+
LLAMA_KV_OVERRIDE_TYPE_BOOL = 2
591+
LLAMA_KV_OVERRIDE_TYPE_STR = 3
592+
593+
594+
# enum llama_model_meta_key {
595+
# LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE,
596+
# LLAMA_MODEL_META_KEY_SAMPLING_TOP_K,
597+
# LLAMA_MODEL_META_KEY_SAMPLING_TOP_P,
598+
# LLAMA_MODEL_META_KEY_SAMPLING_MIN_P,
599+
# LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY,
600+
# LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD,
601+
# LLAMA_MODEL_META_KEY_SAMPLING_TEMP,
602+
# LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N,
603+
# LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT,
604+
# LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT,
605+
# LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU,
606+
# LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA,
607+
# };
608+
class LlamaModelMetaKey(enum.IntEnum):
609+
LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE = 0
610+
LLAMA_MODEL_META_KEY_SAMPLING_TOP_K = 1
611+
LLAMA_MODEL_META_KEY_SAMPLING_TOP_P = 2
612+
LLAMA_MODEL_META_KEY_SAMPLING_MIN_P = 3
613+
LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY = 4
614+
LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD = 5
615+
LLAMA_MODEL_META_KEY_SAMPLING_TEMP = 6
616+
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N = 7
617+
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT = 8
618+
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT = 9
619+
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU = 10
620+
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA = 11
591621

592622

593623
# struct llama_model_kv_override {
@@ -1511,6 +1541,14 @@ def llama_model_meta_count(model: llama_model_p, /) -> int:
15111541
...
15121542

15131543

1544+
# // Get sampling metadata key name. Returns nullptr if the key is invalid
1545+
# LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key);
1546+
@ctypes_function("llama_model_meta_key_str", [ctypes.c_int], ctypes.c_char_p)
1547+
def llama_model_meta_key_str(key: int, /) -> ctypes.c_char_p:
1548+
"""Get sampling metadata key name. Returns nullptr if the key is invalid"""
1549+
...
1550+
1551+
15141552
# // Get metadata key name by index
15151553
# LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
15161554
@ctypes_function(

0 commit comments

Comments
 (0)