@@ -13,6 +13,7 @@ pub mod kv_overrides;
pub struct LlamaModelParams {
    /// The raw `llama_model_params` value from the sys crate that this type wraps.
    pub(crate) params: llama_cpp_sys_2::llama_model_params,
    /// Backing storage for key/value overrides.
    // NOTE(review): presumably `params` holds a raw pointer into this vector, the same way it
    // does for `buft_overrides` — confirm against the kv-override methods.
    kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
    /// Backing storage for tensor buffer type overrides.
    ///
    /// `add_cpu_buft_override` stores this vector's data pointer in
    /// `params.tensor_buft_overrides` and always leaves an entry with a null `pattern` as the
    /// final element, so the list is terminated for the C side.
    buft_overrides: Vec<llama_cpp_sys_2::llama_model_tensor_buft_override>,
}
1718
1819impl Debug for LlamaModelParams {
@@ -107,6 +108,48 @@ impl LlamaModelParams {
107108 }
108109}
109110
111+ impl LlamaModelParams {
112+ /// Adds buffer type overides to move all mixture-of-experts layers to CPU.
113+ pub fn add_cpu_moe_override ( self : Pin < & mut Self > ) {
114+ self . add_cpu_buft_override ( c"\\ .ffn_(up|down|gate)_(ch|)exps" ) ;
115+ }
116+
117+ /// Appends a buffer type override to the model parameters, to move layers matching pattern to CPU.
118+ /// It must be pinned as this creates a self-referential struct.
119+ pub fn add_cpu_buft_override ( mut self : Pin < & mut Self > , key : & CStr ) {
120+ let buft_override = self
121+ . buft_overrides
122+ . get_mut ( 0 )
123+ . expect ( "buft_overrides did not have a next allocated" ) ;
124+
125+ assert ! (
126+ buft_override. pattern. is_null( ) ,
127+ "last buft_override was not empty"
128+ ) ;
129+
130+ // There should be some way to do this without iterating over everything.
131+ for ( _i, & c) in key. to_bytes_with_nul ( ) . iter ( ) . enumerate ( ) {
132+ c_char:: try_from ( c) . expect ( "invalid character in key" ) ;
133+ }
134+
135+ buft_override. pattern = key. as_ptr ( ) ;
136+ buft_override. buft = unsafe { llama_cpp_sys_2:: ggml_backend_cpu_buffer_type ( ) } ;
137+
138+ // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
139+ self . params . tensor_buft_overrides = null ( ) ;
140+
141+ // push the next one to ensure we maintain the iterator invariant of ending with a 0
142+ self . buft_overrides
143+ . push ( llama_cpp_sys_2:: llama_model_tensor_buft_override {
144+ pattern : std:: ptr:: null ( ) ,
145+ buft : std:: ptr:: null_mut ( ) ,
146+ } ) ;
147+
148+ // set the pointer to the (potentially) new vector
149+ self . params . tensor_buft_overrides = self . buft_overrides . as_ptr ( ) ;
150+ }
151+ }
152+
110153impl LlamaModelParams {
111154 /// Get the number of layers to offload to the GPU.
112155 #[ must_use]
@@ -199,6 +242,10 @@ impl Default for LlamaModelParams {
199242 val_i64: 0 ,
200243 } ,
201244 } ] ,
245+ buft_overrides : vec ! [ llama_cpp_sys_2:: llama_model_tensor_buft_override {
246+ pattern: std:: ptr:: null( ) ,
247+ buft: std:: ptr:: null_mut( ) ,
248+ } ] ,
202249 }
203250 }
204251}
0 commit comments