Skip to content

Commit 83085b7

Browse files
committed
Add support for adding tensor buffer type overrides
1 parent cb5060a commit 83085b7

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

examples/simple/src/main.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ struct Args {
4848
#[cfg(any(feature = "cuda", feature = "vulkan"))]
4949
#[clap(long)]
5050
disable_gpu: bool,
51+
#[cfg(any(feature = "cuda", feature = "vulkan"))]
52+
#[arg(long, help = "Keep MoE layers on CPU")]
53+
cmoe: bool,
5154
#[arg(short = 's', long, help = "RNG seed (default: 1234)")]
5255
seed: Option<u32>,
5356
#[arg(
@@ -129,6 +132,8 @@ fn main() -> Result<()> {
129132
file,
130133
#[cfg(any(feature = "cuda", feature = "vulkan"))]
131134
disable_gpu,
135+
#[cfg(any(feature = "cuda", feature = "vulkan"))]
136+
cmoe,
132137
key_value_overrides,
133138
seed,
134139
threads,
@@ -176,6 +181,13 @@ fn main() -> Result<()> {
176181
model_params.as_mut().append_kv_override(k.as_c_str(), *v);
177182
}
178183

184+
#[cfg(any(feature = "cuda", feature = "vulkan"))]
185+
{
186+
if !disable_gpu && cmoe {
187+
model_params.as_mut().add_cpu_moe_override();
188+
}
189+
}
190+
179191
let model_path = model
180192
.get_or_load()
181193
.with_context(|| "failed to get model from args")?;

llama-cpp-2/src/model/params.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub mod kv_overrides;
1313
pub struct LlamaModelParams {
1414
pub(crate) params: llama_cpp_sys_2::llama_model_params,
1515
kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
16+
buft_overrides: Vec<llama_cpp_sys_2::llama_model_tensor_buft_override>,
1617
}
1718

1819
impl Debug for LlamaModelParams {
@@ -107,6 +108,51 @@ impl LlamaModelParams {
107108
}
108109
}
109110

111+
impl LlamaModelParams {
112+
113+
/// Adds buffer type overides to move all mixture-of-experts layers to CPU.
114+
pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
115+
self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
116+
}
117+
118+
/// Appends a buffer type override to the model parameters, to move layers matching pattern to CPU.
119+
/// It must be pinned as this creates a self-referential struct.
120+
pub fn add_cpu_buft_override(
121+
mut self: Pin<&mut Self>,
122+
key: &CStr,
123+
) {
124+
let buft_override = self
125+
.buft_overrides
126+
.get_mut(0)
127+
.expect("buft_overrides did not have a next allocated");
128+
129+
assert!(buft_override.pattern.is_null(), "last buft_override was not empty");
130+
131+
// There should be some way to do this without iterating over everything.
132+
for (_i, &c) in key.to_bytes_with_nul().iter().enumerate() {
133+
c_char::try_from(c).expect("invalid character in key");
134+
}
135+
136+
buft_override.pattern = key.as_ptr();
137+
buft_override.buft = unsafe { llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };
138+
139+
// set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
140+
self.params.kv_overrides = null();
141+
142+
// push the next one to ensure we maintain the iterator invariant of ending with a 0
143+
self.buft_overrides
144+
.push(llama_cpp_sys_2::llama_model_tensor_buft_override {
145+
pattern: std::ptr::null(),
146+
buft: std::ptr::null_mut(),
147+
});
148+
149+
// set the pointer to the (potentially) new vector
150+
self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
151+
152+
eprintln!("saved ptr: {:?}", self.params.tensor_buft_overrides);
153+
}
154+
}
155+
110156
impl LlamaModelParams {
111157
/// Get the number of layers to offload to the GPU.
112158
#[must_use]
@@ -199,6 +245,10 @@ impl Default for LlamaModelParams {
199245
val_i64: 0,
200246
},
201247
}],
248+
buft_overrides: vec![llama_cpp_sys_2::llama_model_tensor_buft_override {
249+
pattern: std::ptr::null(),
250+
buft: std::ptr::null_mut(),
251+
}],
202252
}
203253
}
204254
}

0 commit comments

Comments
 (0)