
Commit c261237

[Model] Add Gemma3 GGUF multimodal support (#27772)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent 49a986e commit c261237

14 files changed: +751 additions, -85 deletions

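For orientation, here is a minimal offline-inference sketch of what this commit enables. It is not part of the diff: the GGUF repo and file names are taken from the new test added below, while the `LLM` call itself assumes vLLM's standard multimodal offline API.

```python
# Hedged sketch, not from this commit: load the Gemma3 GGUF backbone plus its
# mmproj and run one image prompt. Repo/file names mirror the new test config;
# everything else is assumed standard vLLM usage.
from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

repo = "google/gemma-3-4b-it-qat-q4_0-gguf"
# Fetch the mmproj so it sits next to the backbone in the HF cache,
# as the new test's gguf_model property does.
hf_hub_download(repo, filename="mmproj-model-f16-4B.gguf")
backbone = hf_hub_download(repo, filename="gemma-3-4b-it-q4_0.gguf")

# An explicit tokenizer is now mandatory for GGUF models (see vllm/config/model.py below).
llm = LLM(model=backbone, tokenizer="google/gemma-3-4b-it", max_model_len=4096)

outputs = llm.generate(
    {
        "prompt": "<start_of_image>Describe this image in detail:",
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```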

requirements/common.txt

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
 partial-json-parser # used for parsing partial JSON outputs
 pyzmq >= 25.0.0
 msgspec
-gguf >= 0.13.0
+gguf >= 0.17.0
 mistral_common[image] >= 1.8.5
 opencv-python-headless >= 4.11.0 # required for video IO
 pyyaml

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Literal, NamedTuple
+
+import pytest
+from huggingface_hub import hf_hub_download
+from pytest import MarkDecorator
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm.assets.image import ImageAsset
+from vllm.utils.torch_utils import set_default_torch_num_threads
+
+from ....conftest import PromptImageInput, VllmRunner
+from ...utils import check_logprobs_close
+
+
+class GGUFMMTestConfig(NamedTuple):
+    original_model: str
+    gguf_repo: str
+    gguf_backbone: str
+    gguf_mmproj: str
+    prompt: list[str]
+    mm_data: dict[Literal["images"], PromptImageInput]
+    max_model_len: int = 4096
+    marks: list[MarkDecorator] = []
+
+    @property
+    def gguf_model(self):
+        hf_hub_download(self.gguf_repo, filename=self.gguf_mmproj)
+        return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
+
+
+GEMMA3_CONFIG = GGUFMMTestConfig(
+    original_model="google/gemma-3-4b-it",
+    gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
+    gguf_backbone="gemma-3-4b-it-q4_0.gguf",
+    gguf_mmproj="mmproj-model-f16-4B.gguf",
+    prompt=["<start_of_image>Describe this image in detail:"],
+    mm_data={"images": [ImageAsset("stop_sign").pil_image]},
+    marks=[pytest.mark.core_model],
+)
+
+MODELS_TO_TEST = [GEMMA3_CONFIG]
+
+
+def run_multimodal_gguf_test(
+    vllm_runner: type[VllmRunner],
+    model: GGUFMMTestConfig,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+):
+    # Run gguf model.
+    with (
+        set_default_torch_num_threads(1),
+        vllm_runner(
+            model_name=model.gguf_model,
+            enforce_eager=True,
+            tokenizer_name=model.original_model,
+            dtype=dtype,
+            max_model_len=model.max_model_len,
+        ) as gguf_model,
+    ):
+        gguf_outputs = gguf_model.generate_greedy_logprobs(
+            prompts=model.prompt,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            **model.mm_data,
+        )
+
+    # Run unquantized model.
+    with vllm_runner(
+        model_name=model.original_model,
+        enforce_eager=True, # faster tests
+        dtype=dtype,
+        max_model_len=model.max_model_len,
+    ) as original_model:
+        original_outputs = original_model.generate_greedy_logprobs(
+            prompts=model.prompt,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            **model.mm_data,
+        )
+
+    check_logprobs_close(
+        outputs_0_lst=original_outputs,
+        outputs_1_lst=gguf_outputs,
+        name_0="original",
+        name_1="gguf",
+    )
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("gguf"),
+    reason="gguf is not supported on this GPU type.",
+)
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param(test_config, marks=test_config.marks)
+        for test_config in MODELS_TO_TEST
+    ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_models(
+    vllm_runner: type[VllmRunner],
+    model: GGUFMMTestConfig,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)

tests/models/quantization/test_gguf.py

Lines changed: 8 additions & 1 deletion
@@ -78,13 +78,20 @@ def gguf_model(self):
     gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
 )

+GEMMA3_CONFIG = GGUFTestConfig(
+    original_model="google/gemma-3-270m-it",
+    gguf_repo="ggml-org/gemma-3-270m-it-qat-GGUF",
+    gguf_filename="gemma-3-270m-it-qat-Q4_0.gguf",
+)
+
 MODELS = [
     # LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458
     QWEN2_CONFIG,
     PHI3_CONFIG,
     GPT2_CONFIG,
     STABLELM_CONFIG,
     DOLPHIN_CONFIG,
+    GEMMA3_CONFIG,
     # STARCODER_CONFIG, # broken
 ]

@@ -148,7 +155,7 @@ def check_model_outputs(
     "model",
     [pytest.param(test_config, marks=test_config.marks) for test_config in MODELS],
 )
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("tp_size", [1])

vllm/config/model.py

Lines changed: 19 additions & 1 deletion
@@ -33,10 +33,14 @@
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
+    uses_custom_attention_masks,
     uses_mrope,
 )
+from vllm.transformers_utils.gguf_utils import (
+    maybe_patch_hf_config_from_gguf,
+)
 from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
-from vllm.transformers_utils.utils import maybe_model_redirect
+from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
 from vllm.utils.import_utils import LazyLoader
 from vllm.utils.torch_utils import common_broadcastable_dtype

@@ -450,6 +454,12 @@ def __post_init__(
         self.model = maybe_model_redirect(self.model)
         # The tokenizer is consistent with the model by default.
         if self.tokenizer is None:
+            if check_gguf_file(self.model):
+                raise ValueError(
+                    "Using a tokenizer is mandatory when loading a GGUF model. "
+                    "Please specify the tokenizer path or name using the "
+                    "--tokenizer argument."
+                )
             self.tokenizer = self.model
         if self.tokenizer_revision is None:
             self.tokenizer_revision = self.revision

@@ -508,6 +518,10 @@ def __post_init__(
             hf_overrides_kw=hf_overrides_kw,
             hf_overrides_fn=hf_overrides_fn,
         )
+        hf_config = maybe_patch_hf_config_from_gguf(
+            self.model,
+            hf_config,
+        )

         self.hf_config = hf_config
         if dict_overrides:

@@ -1605,6 +1619,10 @@ def uses_alibi(self) -> bool:
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)

+    @property
+    def uses_custom_attention_masks(self) -> bool:
+        return uses_custom_attention_masks(self.hf_config)
+
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
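A small illustration of the new GGUF/tokenizer validation above. This is a sketch, not part of the commit; it assumes a GGUF checkpoint already downloaded to a local path, and the error text is quoted from the hunk.

```python
# Sketch only: ModelConfig.__post_init__ now rejects a GGUF checkpoint that is
# passed without an explicit tokenizer. The local path is assumed for illustration.
from vllm import LLM

gguf_path = "./gemma-3-270m-it-qat-Q4_0.gguf"  # assumed local file

try:
    LLM(model=gguf_path)  # no tokenizer given
except ValueError as err:
    print(err)  # "Using a tokenizer is mandatory when loading a GGUF model. ..."

# Supplying the original model's tokenizer satisfies the check.
llm = LLM(model=gguf_path, tokenizer="google/gemma-3-270m-it")
```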

vllm/model_executor/layers/quantization/gguf.py

Lines changed: 62 additions & 5 deletions
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Callable
+from collections.abc import Callable, Mapping
+from types import MappingProxyType
 from typing import Any, Optional

 import gguf

@@ -26,7 +27,11 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils.torch_utils import direct_register_custom_op

@@ -65,18 +70,70 @@ def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
+            if is_layer_skipped_gguf(
+                prefix, self.unquantized_modules, self.packed_modules_mapping
+            ):
                 return UnquantizedLinearMethod()
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
+            if is_layer_skipped_gguf(
+                prefix, self.unquantized_modules, self.packed_modules_mapping
+            ):
+                return UnquantizedEmbeddingMethod()
             return GGUFEmbeddingMethod(self)
         elif isinstance(layer, FusedMoE):
             return GGUFMoEMethod(self, layer.moe_config)
         return None

+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure
+
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        if self.unquantized_modules is not None:
+            self.unquantized_modules = hf_to_vllm_mapper.apply_list(
+                self.unquantized_modules
+            )
+
+
+def is_layer_skipped_gguf(
+    prefix: str,
+    unquantized_modules: list[str],
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+):
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    proj_name = prefix.split(".")[-1]
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = any(
+                shard_prefix in module_name for module_name in unquantized_modules
+            )
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+                )
+    else:
+        is_skipped = any(module_name in prefix for module_name in unquantized_modules)

-def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
-    return any(module_name in prefix for module_name in unquantized_modules)
+    assert is_skipped is not None
+    return is_skipped


 UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
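To make the fused-layer handling in `is_layer_skipped_gguf` concrete, here is a small sketch. The prefixes, the unquantized module list, and the packed mapping below are illustrative assumptions, not values read from a real model config; mixed precision across shards of one fused layer would raise the ValueError shown in the hunk.

```python
# Illustrative sketch of the new fused-layer check; all names are assumptions.
from types import MappingProxyType

from vllm.model_executor.layers.quantization.gguf import is_layer_skipped_gguf

packed = MappingProxyType({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
unquantized = [
    "vision_tower.encoder.layers.0.self_attn.q_proj",
    "vision_tower.encoder.layers.0.self_attn.k_proj",
    "vision_tower.encoder.layers.0.self_attn.v_proj",
]

# The fused qkv_proj prefix is expanded to its q/k/v shards before matching;
# all three shards are listed as unquantized, so the fused layer is skipped.
assert is_layer_skipped_gguf(
    "vision_tower.encoder.layers.0.self_attn.qkv_proj", unquantized, packed
)

# A non-fused prefix falls back to the plain substring check and stays quantized.
assert not is_layer_skipped_gguf(
    "language_model.layers.0.mlp.down_proj", unquantized, packed
)
```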
