@@ -172,7 +172,9 @@ def llama_free(ctx: llama_context_p):
 # TODO: not great API - very likely to change
 # Returns 0 on success
 # nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
-def llama_model_quantize(fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int) -> c_int:
+def llama_model_quantize(
+    fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
+) -> c_int:
     return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)


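For reference, a hypothetical call might look like the sketch below; the file paths and the ftype value (assumed here to correspond to LLAMA_FTYPE_MOSTLY_Q4_0 = 2) are illustrative assumptions, not part of this change.

# Usage sketch (paths and ftype value are assumptions):
ret = llama_model_quantize(
    b"./models/7B/ggml-model-f16.bin",   # fname_inp: source model (assumed path)
    b"./models/7B/ggml-model-q4_0.bin",  # fname_out: quantized output (assumed path)
    2,                                   # ftype: assumed LLAMA_FTYPE_MOSTLY_Q4_0
    0,                                   # nthread <= 0 -> std::thread::hardware_concurrency()
)
assert ret == 0  # returns 0 on success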
@@ -187,7 +189,10 @@ def llama_model_quantize(fname_inp: bytes, fname_out: bytes, ftype: c_int, nthre
 # will be applied on top of the previous one
 # Returns 0 on success
 def llama_apply_lora_from_file(
-    ctx: llama_context_p, path_lora: ctypes.c_char_p, path_base_model: ctypes.c_char_p, n_threads: c_int
+    ctx: llama_context_p,
+    path_lora: ctypes.c_char_p,
+    path_base_model: ctypes.c_char_p,
+    n_threads: c_int,
 ) -> c_int:
     return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)

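As a usage sketch (the ctx variable and the file paths are assumptions for illustration), an adapter is applied to a freshly loaded context, optionally passing the original higher-precision model as the base:

# Usage sketch (ctx and paths are assumptions):
err = llama_apply_lora_from_file(
    ctx,                                # an initialized llama_context_p
    b"./lora/ggml-adapter-model.bin",   # path_lora (assumed path)
    b"./models/7B/ggml-model-f16.bin",  # path_base_model (assumed; may also be None)
    4,                                  # n_threads
)
assert err == 0  # returns 0 on success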
@@ -235,6 +240,36 @@ def llama_set_kv_cache(
 _lib.llama_set_kv_cache.restype = None


+# Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
+    return _lib.llama_get_state_size(ctx)
+
+
+_lib.llama_get_state_size.argtypes = [llama_context_p]
+_lib.llama_get_state_size.restype = c_size_t
+
+
+# Copies the state to the specified destination address.
+# Destination needs to have allocated enough memory.
+# Returns the number of bytes copied
+def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
+    return _lib.llama_copy_state_data(ctx, dest)
+
+
+_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
+_lib.llama_copy_state_data.restype = c_size_t
+
+
+# Set the state reading from the specified address
+# Returns the number of bytes read
+def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
+    return _lib.llama_set_state_data(ctx, src)
+
+
+_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
+_lib.llama_set_state_data.restype = c_size_t
+
+
 # Run the llama inference to obtain the logits and probabilities for the next token.
 # tokens + n_tokens is the provided batch of new tokens to process
 # n_past is the number of tokens to use from previous eval calls
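Taken together, the three new state functions let a caller snapshot a context and later rewind it. A minimal sketch, assuming an already-initialized ctx and that the buffer stays alive between the save and restore calls:

from ctypes import POINTER, c_uint8, cast

# Allocate a buffer large enough for the full state (rng, logits, embedding, kv_cache)
state_size = llama_get_state_size(ctx)
buf = (c_uint8 * state_size)()

# Save: copy the current state into the buffer
n_copied = llama_copy_state_data(ctx, cast(buf, POINTER(c_uint8)))
assert n_copied <= state_size

# ... evaluate more tokens here ...

# Restore: rewind the context to the saved state
n_read = llama_set_state_data(ctx, cast(buf, POINTER(c_uint8)))
assert n_read == n_copied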