Skip to content

Commit a9705a2

Browse files
authored
[Model][QwenVL] Replace torch.repeat_interleave with faster np.repeat (#28964)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
1 parent 64192d5 commit a9705a2

File tree

3 files changed

18 additions and 23 deletions

3 files changed

18 additions and 23 deletions

tests/models/multimodal/generation/test_qwen2_vl.py

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -128,12 +128,7 @@ def get_image_embeds(model):
128128
visual = model.visual
129129

130130
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
131-
image_grid_thw_on_device = image_grid_thw.to(
132-
visual.device, dtype=torch.int64
133-
)
134-
return visual(
135-
pixel_values_on_device, grid_thw=image_grid_thw_on_device
136-
).cpu()
131+
return visual(pixel_values_on_device, grid_thw=image_grid_thw).cpu()
137132

138133
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
139134

@@ -217,12 +212,7 @@ def get_image_embeds(model):
217212
visual = model.visual
218213

219214
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
220-
video_grid_thw_on_device = video_grid_thw.to(
221-
visual.device, dtype=torch.int64
222-
)
223-
return visual(
224-
pixel_values_on_device, grid_thw=video_grid_thw_on_device
225-
).cpu()
215+
return visual(pixel_values_on_device, grid_thw=video_grid_thw).cpu()
226216

227217
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
228218

vllm/model_executor/models/qwen2_vl.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
from functools import partial
3030
from typing import Annotated, Any, Literal, TypeAlias
3131

32+
import numpy as np
3233
import torch
3334
import torch.nn as nn
3435
import torch.nn.functional as F
@@ -751,25 +752,27 @@ def forward(
751752

752753
if isinstance(grid_thw, list):
753754
grid_thw_list = grid_thw
754-
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
755+
grid_thw = np.array(grid_thw, dtype=np.int32)
755756
else:
756757
grid_thw_list = grid_thw.tolist()
758+
grid_thw = grid_thw.numpy()
757759

758760
# compute position embedding
759761
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
760762

761763
# compute cu_seqlens
762-
cu_seqlens = torch.repeat_interleave(
763-
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
764-
).cumsum(dim=0, dtype=torch.int32)
765-
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
766-
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
764+
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
765+
axis=0, dtype=np.int32
766+
)
767+
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
768+
cu_seqlens = torch.from_numpy(cu_seqlens)
767769

768770
# transformers
769771
x = x.unsqueeze(1)
770772

771773
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
772774
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
775+
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
773776
for blk in self.blocks:
774777
x = blk(
775778
x,

vllm/model_executor/models/qwen3_vl.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -553,18 +553,20 @@ def forward(
553553

554554
if isinstance(grid_thw, list):
555555
grid_thw_list = grid_thw
556-
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
556+
grid_thw = np.array(grid_thw, dtype=np.int32)
557557
else:
558558
grid_thw_list = grid_thw.tolist()
559+
grid_thw = grid_thw.numpy()
559560

560561
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
561562
hidden_states = hidden_states + pos_embeds
562563
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
563564

564-
cu_seqlens = torch.repeat_interleave(
565-
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
566-
).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
567-
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
565+
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
566+
axis=0, dtype=np.int32
567+
)
568+
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
569+
cu_seqlens = torch.from_numpy(cu_seqlens)
568570

569571
hidden_states = hidden_states.unsqueeze(1)
570572
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

0 commit comments

Comments
 (0)