@@ -13,6 +13,7 @@ pub mod kv_overrides;
1313pub struct LlamaModelParams {
1414 pub ( crate ) params : llama_cpp_sys_2:: llama_model_params ,
1515 kv_overrides : Vec < llama_cpp_sys_2:: llama_model_kv_override > ,
16+ buft_overrides : Vec < llama_cpp_sys_2:: llama_model_tensor_buft_override > ,
1617}
1718
1819impl Debug for LlamaModelParams {
@@ -107,6 +108,51 @@ impl LlamaModelParams {
107108 }
108109}
109110
111+ impl LlamaModelParams {
112+
113+ /// Adds buffer type overides to move all mixture-of-experts layers to CPU.
114+ pub fn add_cpu_moe_override ( self : Pin < & mut Self > ) {
115+ self . add_cpu_buft_override ( c"\\ .ffn_(up|down|gate)_(ch|)exps" ) ;
116+ }
117+
118+ /// Appends a buffer type override to the model parameters, to move layers matching pattern to CPU.
119+ /// It must be pinned as this creates a self-referential struct.
120+ pub fn add_cpu_buft_override (
121+ mut self : Pin < & mut Self > ,
122+ key : & CStr ,
123+ ) {
124+ let buft_override = self
125+ . buft_overrides
126+ . get_mut ( 0 )
127+ . expect ( "buft_overrides did not have a next allocated" ) ;
128+
129+ assert ! ( buft_override. pattern. is_null( ) , "last buft_override was not empty" ) ;
130+
131+ // There should be some way to do this without iterating over everything.
132+ for ( _i, & c) in key. to_bytes_with_nul ( ) . iter ( ) . enumerate ( ) {
133+ c_char:: try_from ( c) . expect ( "invalid character in key" ) ;
134+ }
135+
136+ buft_override. pattern = key. as_ptr ( ) ;
137+ buft_override. buft = unsafe { llama_cpp_sys_2:: ggml_backend_cpu_buffer_type ( ) } ;
138+
139+ // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
140+ self . params . kv_overrides = null ( ) ;
141+
142+ // push the next one to ensure we maintain the iterator invariant of ending with a 0
143+ self . buft_overrides
144+ . push ( llama_cpp_sys_2:: llama_model_tensor_buft_override {
145+ pattern : std:: ptr:: null ( ) ,
146+ buft : std:: ptr:: null_mut ( ) ,
147+ } ) ;
148+
149+ // set the pointer to the (potentially) new vector
150+ self . params . tensor_buft_overrides = self . buft_overrides . as_ptr ( ) ;
151+
152+ eprintln ! ( "saved ptr: {:?}" , self . params. tensor_buft_overrides) ;
153+ }
154+ }
155+
110156impl LlamaModelParams {
111157 /// Get the number of layers to offload to the GPU.
112158 #[ must_use]
@@ -199,6 +245,10 @@ impl Default for LlamaModelParams {
199245 val_i64: 0 ,
200246 } ,
201247 } ] ,
248+ buft_overrides : vec ! [ llama_cpp_sys_2:: llama_model_tensor_buft_override {
249+ pattern: std:: ptr:: null( ) ,
250+ buft: std:: ptr:: null_mut( ) ,
251+ } ] ,
202252 }
203253 }
204254}
0 commit comments