@@ -5,6 +5,16 @@ enum LlamaError: Error {
55 case couldNotInitializeContext
66}
77
/// User-tunable knobs that configure a llama.cpp context and its sampler chain.
///
/// Instances are passed across the `LlamaContext` actor boundary (in `create_context`
/// and `updateSampler`), so the type is declared `Sendable`; every stored property is
/// a simple value type, making the conformance trivially correct.
struct LlamaRuntimeOptions: Sendable {
    /// Context window size (n_ctx), in tokens; also used as a floor for batch capacity.
    var contextLength: Int32
    /// Layers to offload to the GPU; negative values leave the library default untouched.
    var nGpuLayers: Int32
    /// RNG seed for the dist sampler; 0 is remapped to `UInt32.max` before use.
    var seed: UInt32
    /// Sampling temperature; clamped to be non-negative before use.
    var temperature: Float
    /// Nucleus (top-p) threshold; clamped into [0, 1] before use.
    var topP: Float
    /// Top-k cutoff; values <= 0 disable the top-k sampler entirely.
    var topK: Int32
    /// Whether flash attention is enabled on the context params.
    var flashAttention: Bool
}
17+
/// Marks `batch` as empty so its buffers can be reused for the next decode pass.
/// Only the token count is reset; the underlying allocations are left in place.
func llama_batch_clear(_ batch: inout llama_batch) {
    batch.n_tokens = 0
}
@@ -25,9 +35,10 @@ actor LlamaContext {
    // Opaque llama.cpp handles owned by this actor; released in deinit.
    private var model: OpaquePointer
    private var context: OpaquePointer
    private var vocab: OpaquePointer
    // Sampler chain; optional so it can be torn down and rebuilt when options change.
    private var sampling: UnsafeMutablePointer<llama_sampler>?
    // Reusable batch allocated once in init (llama_batch_init) and freed in deinit.
    private var batch: llama_batch
    // Token buffer — presumably the tokenized prompt; tokenization is not shown in this view.
    private var tokens_list: [llama_token]
    // Last options applied; consulted whenever the sampler chain is rebuilt.
    private var runtimeOptions: LlamaRuntimeOptions
    var is_done: Bool = false
3243
3344 /// This variable is used to store temporarily invalid cchars
@@ -38,34 +49,80 @@ actor LlamaContext {
3849
    // Running count of decode steps — TODO confirm against the decode loop (not visible here).
    var n_decode: Int32 = 0
4051
41- init ( model: OpaquePointer , context: OpaquePointer ) {
52+ init ( model: OpaquePointer , context: OpaquePointer , options : LlamaRuntimeOptions ) {
4253 self . model = model
4354 self . context = context
4455 self . tokens_list = [ ]
45- self . batch = llama_batch_init ( 512 , 0 , 1 )
56+ self . batch = llama_batch_init ( max ( Int32 ( 512 ) , options . contextLength ) , 0 , 1 )
4657 self . temporary_invalid_cchars = [ ]
47- let sparams = llama_sampler_chain_default_params ( )
48- self . sampling = llama_sampler_chain_init ( sparams)
49- llama_sampler_chain_add ( self . sampling, llama_sampler_init_temp ( 0.4 ) )
50- llama_sampler_chain_add ( self . sampling, llama_sampler_init_dist ( 1234 ) )
58+ self . runtimeOptions = options
59+ self . n_len = options. contextLength
5160 vocab = llama_model_get_vocab ( model)
61+
62+ let chainParams = llama_sampler_chain_default_params ( )
63+ let initialChain = llama_sampler_chain_init ( chainParams)
64+
65+ if options. topK > 0 {
66+ llama_sampler_chain_add ( initialChain, llama_sampler_init_top_k ( options. topK) )
67+ }
68+
69+ let clampedTopP = max ( 0.0 , min ( Double ( options. topP) , 1.0 ) )
70+ llama_sampler_chain_add ( initialChain, llama_sampler_init_top_p ( Float ( clampedTopP) , 1 ) )
71+
72+ let clampedTemp = max ( 0.0 , Double ( options. temperature) )
73+ llama_sampler_chain_add ( initialChain, llama_sampler_init_temp ( Float ( clampedTemp) ) )
74+
75+ let seed = options. seed == 0 ? UInt32 . max : options. seed
76+ llama_sampler_chain_add ( initialChain, llama_sampler_init_dist ( seed) )
77+
78+ sampling = initialChain
5279 }
5380
5481 deinit {
55- llama_sampler_free ( sampling)
82+ if let sampling {
83+ llama_sampler_free ( sampling)
84+ }
5685 llama_batch_free ( batch)
5786 llama_model_free ( model)
5887 llama_free ( context)
5988 llama_backend_free ( )
6089 }
6190
62- static func create_context( path: String ) throws -> LlamaContext {
91+ private func rebuildSamplerChain( ) {
92+ let chainParams = llama_sampler_chain_default_params ( )
93+ let newChain = llama_sampler_chain_init ( chainParams)
94+
95+ if runtimeOptions. topK > 0 {
96+ llama_sampler_chain_add ( newChain, llama_sampler_init_top_k ( runtimeOptions. topK) )
97+ }
98+
99+ let clampedTopP = max ( 0.0 , min ( runtimeOptions. topP, 1.0 ) )
100+ llama_sampler_chain_add ( newChain, llama_sampler_init_top_p ( clampedTopP, 1 ) )
101+
102+ let clampedTemp = max ( 0.0 , Double ( runtimeOptions. temperature) )
103+ llama_sampler_chain_add ( newChain, llama_sampler_init_temp ( Float ( clampedTemp) ) )
104+
105+ let seed = runtimeOptions. seed == 0 ? UInt32 . max : runtimeOptions. seed
106+ llama_sampler_chain_add ( newChain, llama_sampler_init_dist ( seed) )
107+
108+ if let sampling {
109+ llama_sampler_free ( sampling)
110+ }
111+
112+ sampling = newChain
113+ }
114+
115+ static func create_context( path: String , options: LlamaRuntimeOptions ) throws -> LlamaContext {
63116 llama_backend_init ( )
64117 var model_params = llama_model_default_params ( )
65118
66119#if targetEnvironment(simulator)
67120 model_params. n_gpu_layers = 0
68121 print ( " Running on simulator, force use n_gpu_layers = 0 " )
122+ #else
123+ if options. nGpuLayers >= 0 {
124+ model_params. n_gpu_layers = options. nGpuLayers
125+ }
69126#endif
70127 let model = llama_model_load_from_file ( path, model_params)
71128 guard let model else {
@@ -77,17 +134,24 @@ actor LlamaContext {
77134 print ( " Using \( n_threads) threads " )
78135
79136 var ctx_params = llama_context_default_params ( )
80- ctx_params. n_ctx = 2048
137+ ctx_params. n_ctx = UInt32 ( options . contextLength )
81138 ctx_params. n_threads = Int32 ( n_threads)
82139 ctx_params. n_threads_batch = Int32 ( n_threads)
140+ ctx_params. flash_attn_type = options. flashAttention ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED
83141
84142 let context = llama_init_from_model ( model, ctx_params)
85143 guard let context else {
86144 print ( " Could not load context! " )
87145 throw LlamaError . couldNotInitializeContext
88146 }
89147
90- return LlamaContext ( model: model, context: context)
148+ return LlamaContext ( model: model, context: context, options: options)
149+ }
150+
151+ func updateSampler( options: LlamaRuntimeOptions ) {
152+ runtimeOptions = options
153+ n_len = options. contextLength
154+ rebuildSamplerChain ( )
91155 }
92156
93157 func model_info( ) -> String {
@@ -151,6 +215,10 @@ actor LlamaContext {
151215 func completion_loop( ) -> String {
152216 var new_token_id : llama_token = 0
153217
218+ guard let sampling else {
219+ return " "
220+ }
221+
154222 new_token_id = llama_sampler_sample ( sampling, context, batch. n_tokens - 1 )
155223
156224 if llama_vocab_is_eog ( vocab, new_token_id) || n_cur == n_len {
0 commit comments