Commit fb933d8

jiminha authored and mhelf-intel committed
Fix compile error for Gemma3 multimodal inputs (vllm-project#671)

Due to the latest changes from upstream (vllm-project/vllm#27772, vllm-project/vllm#28842), gemma3 fails to compile on HPU:
- replace unfold with view/reshape
- replace the text-embedding path to avoid dynamic shapes
- remove the merge_multimodal replacement, since the masked_scatter issue is fixed
- re-enable the gemma3 model test

---------

Signed-off-by: Jimin Ha <jimin.ha@intel.com>
1 parent 6fc04ba commit fb933d8
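
Background for the dynamic-shape items above, as a minimal CPU-runnable sketch (not code from this commit; the token id 900 and the sizes are made up): boolean indexing yields a result whose shape depends on tensor values, which is what trips up HPU graph compilation, while a torch.where over the full tensor keeps every shape static.

import torch

ids = torch.tensor([[11, 12, 900, 13]])     # pretend 900 is an image token
is_multimodal = ids == 900
is_text = ~is_multimodal

# Value-dependent shape: the size of the result changes with the mask contents.
dynamic = ids[is_text]
print(dynamic.shape)                         # torch.Size([3]) here, different for other inputs

# Value-independent shape: the output layout matches the input regardless of the mask.
embeds = torch.nn.functional.one_hot(ids, num_classes=1024).float()   # stand-in for an embedding lookup
static = torch.where(is_text.unsqueeze(-1), embeds, torch.zeros_like(embeds))
print(static.shape)                          # always torch.Size([1, 4, 1024])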

File tree

4 files changed: +80 -4 lines changed

tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 2 additions & 3 deletions

@@ -13,12 +13,11 @@ echo $VLLM_GAUDI_PREFIX
 # Gemma3 with image input
 run_gemma3_test() {
     echo "➡️ Testing gemma-3-4b-it..."
-    #VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/gemma-3-4b-it.yaml"
+    VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/gemma-3-4b-it.yaml"
     echo "✅ Test with multimodal-support with gemma-3-4b-it passed."
     echo "➡️ Testing gemma-3-4b-it with multiple images(applying sliding_window)..."
-    #VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm_multi.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/gemma-3-27b-it.yaml"
+    VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm_multi.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/gemma-3-27b-it.yaml"
     echo "✅ Test with multimodal-support with multiple images gemma-3-27b-it passed."
-    #Test cases are commented because of PR27772
 }

 # Basic model test

vllm_gaudi/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -20,9 +20,10 @@ def register_ops():
     import vllm_gaudi.ops.hpu_gptq  # noqa: F401
     import vllm_gaudi.ops.hpu_awq  # noqa: F401
     import vllm_gaudi.ops.hpu_multihead_attn  # noqa: F401
+    import vllm_gaudi.ops.hpu_conv  # noqa: F401


 def register_models():
-    import vllm_gaudi.models.utils  # noqa: F401
+    import vllm_gaudi.models.interfaces  # noqa: F401
     from .models import register_model
     register_model()
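
The two new imports matter only for their side effects: importing vllm_gaudi.ops.hpu_conv runs the @Conv2dLayer.register_oot decorator, and importing vllm_gaudi.models.interfaces executes the module-level assignment that patches SupportsMultiModal._embed_text_input_ids. A minimal, framework-free sketch of that import-for-side-effect pattern (class and function names below are made up, not vLLM code):

class Upstream:
    def embed(self, x):
        return f"upstream({x})"

# In a hypothetical hpu_patch module: merely importing the module applies the override,
# which is why register_models() only needs the bare import statement.
def _hpu_embed(self, x):
    return f"hpu({x})"

Upstream.embed = _hpu_embed

print(Upstream().embed("ids"))   # -> hpu(ids)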

vllm_gaudi/models/interfaces.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+from collections.abc import Callable
+import torch
+from torch import Tensor
+from vllm.model_executor.models.interfaces import SupportsMultiModal
+
+
+def _embed_text_input_ids(
+    self,
+    input_ids: Tensor,
+    embed_input_ids: Callable[[Tensor], Tensor],
+    *,
+    is_multimodal: Tensor | None,
+    handle_oov_mm_token: bool,
+) -> Tensor:
+    if handle_oov_mm_token and is_multimodal is not None:
+        is_text = ~is_multimodal
+
+        # The original implementation uses dynamic indexing.
+        # Replace it with a fixed-shape computation for HPU, then fill in the text positions.
+        '''
+        text_embeds = embed_input_ids(input_ids[is_text])
+
+        return torch.empty(
+            (input_ids.shape[0], text_embeds.shape[1]),
+            dtype=text_embeds.dtype,
+            device=text_embeds.device,
+        ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
+        '''
+        all_text_embeds = embed_input_ids(input_ids)
+        result = torch.zeros_like(all_text_embeds)
+
+        return torch.where(
+            is_text.unsqueeze(-1),  # [batch, seq_len, 1]
+            all_text_embeds,        # [batch, seq_len, embed_dim]
+            result,                 # [batch, seq_len, embed_dim]
+        )
+
+    return embed_input_ids(input_ids)
+
+
+SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
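
As a standalone sanity check (not part of the commit; the embedding table and token ids below are made up), the torch.where formulation above can be compared on CPU against the commented-out masked-scatter formulation. Zeros are used for the reference buffer so that the non-text rows, which the original left uninitialized via torch.empty, compare equal:

import torch

torch.manual_seed(0)
table = torch.randn(32, 8)                        # stand-in embedding table

def embed(ids):                                   # stand-in for embed_input_ids
    return table[ids]

input_ids = torch.tensor([3, 7, 7, 5, 1])
is_multimodal = torch.tensor([False, True, True, False, False])
is_text = ~is_multimodal

# Original (dynamic-shape) path: embed only the text tokens, scatter them back.
text_embeds = embed(input_ids[is_text])
ref = torch.zeros(input_ids.shape[0], text_embeds.shape[1]).masked_scatter_(
    is_text.unsqueeze(-1), text_embeds)

# HPU-friendly (static-shape) path: embed everything, zero out the non-text positions.
all_text_embeds = embed(input_ids)
out = torch.where(is_text.unsqueeze(-1), all_text_embeds, torch.zeros_like(all_text_embeds))

assert torch.allclose(ref, out)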

vllm_gaudi/ops/hpu_conv.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import torch
+import torch.nn.functional as F
+from vllm.model_executor.layers.conv import Conv2dLayer
+
+
+@Conv2dLayer.register_oot
+class HPUConv2dLayer(Conv2dLayer):
+
+    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 4
+        B, C, H, W = x.shape
+        K1, K2 = self.kernel_size
+        H, W = H // K1, W // K2
+
+        # NOTE: HPU doesn't support unfold; extract the patches with view/reshape instead.
+        # x = x.unfold(2, K1, K1).unfold(3, K2, K2)
+        # x = x.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size)
+        x = x.view(B, C, H, K1, W, K2)
+        x = x.permute(0, 2, 4, 1, 3, 5).reshape(-1, self.input_size)  # [B*H*W, C*K1*K2]
+
+        x = F.linear(
+            x,
+            self.weight.view(self.out_channels, self.input_size),
+            self.bias,
+        )
+        x = x.view(B, H, W, self.out_channels).permute(0, 3, 1, 2)
+        return x
+
+    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+        """Expected input shape: (batch_size, in_channels, height, width)"""
+        assert x.dim() == 4
+        if self.enable_linear:
+            return self._forward_mulmat(x)
+        else:
+            return self._forward_conv(x)