
Commit 095775e

danieldk and drbh authored
launcher: correctly get the head dimension for VLMs (#3116)
* launcher: correctly get the head dimension for VLMs

  For most (?) VLMs, the head dimension is in the `text_config` configuration section. However, since we only queried the top-level `head_dim` (which typically doesn't exist in VLMs), we would never use flashinfer. This change adds a method that gets the head dimension from the top-level `Config` struct or `text_config` when that fails.

* fix: bump org name in gemma3 test

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
1 parent 0b3e3db commit 095775e

File tree

2 files changed: +23 −6 lines changed

integration-tests/models/test_flash_gemma3.py
launcher/src/main.rs

integration-tests/models/test_flash_gemma3.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 
 
 @pytest.fixture(scope="module")
 def flash_gemma3_handle(launcher):
-    with launcher("gg-hf-g/gemma-3-4b-it", num_shard=2) as handle:
+    with launcher("google/gemma-3-4b-it", num_shard=2) as handle:
         yield handle
 
launcher/src/main.rs

Lines changed: 22 additions & 5 deletions
@@ -152,7 +152,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
            "flashdecoding"
        };

-        match config.head_dim {
+        match config.get_head_dim() {
            Some(h) if h == 64 || h == 128 || h == 256 => {
                if lora_adapters.is_some() && prefix_caching.is_none() {
                    tracing::info!("Disabling prefix caching because of lora adapters");
@@ -214,6 +214,7 @@ struct RawConfig {
     num_key_value_heads: Option<usize>,
     num_hidden_layers: Option<usize>,
     head_dim: Option<usize>,
+    text_config: Option<TextConfig>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
     #[serde(rename = "num_experts_per_tok")]
@@ -233,6 +234,11 @@ struct QuantizationConfig {
 #[derive(Debug, Deserialize)]
 struct VisionConfig {}

+#[derive(Debug, Deserialize)]
+struct TextConfig {
+    head_dim: Option<usize>,
+}
+
 #[derive(Debug, Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
@@ -244,6 +250,7 @@ struct Config {
     intermediate_size: Option<usize>,
     hidden_size: Option<usize>,
     model_type: Option<String>,
+    text_config: Option<TextConfig>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
     num_experts_per_token: usize,
@@ -253,6 +260,14 @@ struct Config {
 }

 impl Config {
+    fn get_head_dim(&self) -> Option<usize> {
+        self.head_dim.or_else(|| {
+            self.text_config
+                .as_ref()
+                .and_then(|text_config| text_config.head_dim)
+        })
+    }
+
     fn flop(&self) -> Option<u64> {
         if self.vision_config.is_some() {
             // VLM are much harder to predict and VRAM requirements
@@ -261,7 +276,7 @@ impl Config {
         }
         let num_heads = self.num_heads? as u64;
         let num_kv_heads = self.num_kv_heads? as u64;
-        let head_dim = self.head_dim? as u64;
+        let head_dim = self.get_head_dim()? as u64;
         let hidden_size = self.hidden_size? as u64;
         let intermediate_size = (self.intermediate_size?
             * (self.num_experts_per_token + self.num_shared_experts))
@@ -289,7 +304,7 @@ impl Config {
         }
         // 2 for key and values
         // 2 for f16 dtype?
-        Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
+        Some(self.num_kv_heads? * 2 * self.get_head_dim()? * 2 * self.num_layers?)
     }

     fn mlp_vram_per_tok(&self) -> Option<usize> {
@@ -310,8 +325,8 @@ impl Config {
     }

     fn model_vram(&self) -> Option<usize> {
-        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
-        let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
+        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;
+        let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;
         // gate + up + down = 3
         let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
         let layer_vram = mlp_vram + attn_vram + o_vram;
@@ -349,6 +364,7 @@ impl From<RawConfig> for Config {
         let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
         let intermediate_size = other.intermediate_size;
         let model_type = other.model_type;
+        let text_config = other.text_config;
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
@@ -360,6 +376,7 @@ impl From<RawConfig> for Config {
             quantize,
             head_dim,
             model_type,
+            text_config,
             vision_config,
             is_encoder_decoder,
             hidden_size,
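For context, here is a minimal, self-contained sketch (not part of the commit) of how the new `get_head_dim` fallback behaves. The `TextConfig`/`Config` fields mirror the diff above, but the `main` harness, the inline JSON, and the example value 256 are illustrative assumptions only; it presumes `serde` (with the `derive` feature) and `serde_json` as dependencies.

use serde::Deserialize;

// Mirrors the `TextConfig` struct added in this commit: only the field the
// launcher cares about is deserialized; unknown fields are ignored by serde.
#[derive(Debug, Deserialize)]
struct TextConfig {
    head_dim: Option<usize>,
}

// Reduced stand-in for the launcher's `Config` struct (illustrative subset).
#[derive(Debug, Deserialize)]
struct Config {
    head_dim: Option<usize>,
    text_config: Option<TextConfig>,
}

impl Config {
    // Same fallback as the diff: prefer the top-level `head_dim`, otherwise
    // read it from `text_config`, which is where most VLM configs put it.
    fn get_head_dim(&self) -> Option<usize> {
        self.head_dim.or_else(|| {
            self.text_config
                .as_ref()
                .and_then(|text_config| text_config.head_dim)
        })
    }
}

fn main() {
    // VLM-style config.json: no top-level `head_dim`, only `text_config.head_dim`.
    // The JSON and the value 256 are made up for illustration.
    let raw = r#"{ "model_type": "gemma3", "text_config": { "head_dim": 256 } }"#;
    let config: Config = serde_json::from_str(raw).unwrap();
    assert_eq!(config.get_head_dim(), Some(256));
    println!("resolved head_dim: {:?}", config.get_head_dim());
}

With the fallback in place, a VLM config whose `head_dim` lives under `text_config` can still hit the `64 | 128 | 256` arm of the match in `resolve_attention`, so flashinfer becomes selectable for such models instead of silently falling back.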
