Commit c607174

Fix mask passed to flashinfer (#3324)
Custom masks are padded to the shape `[batch_size, max_len, max_len]`. However, flashinfer expects an unpadded mask of the shape `[sum(q_len[i] * k_len[i] for i in range(batch_size))]`. This change unpads the custom mask (currently only used by Gemma 3) to that shape, assuming `q_len == k_len`, since the custom mask is only used during prefill.
1 parent 4f067c2 commit c607174
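
For concreteness, a minimal sketch of the two layouts the description contrasts, using hypothetical sequence lengths (not taken from the commit):

```python
# Hypothetical prefill batch: two sequences of lengths 2 and 3, padded to max_len = 3.
q_len = k_len = [2, 3]
batch_size, max_len = len(q_len), max(q_len)

padded_numel = batch_size * max_len * max_len          # 18 elements in the padded layout
flat_numel = sum(q * k for q, k in zip(q_len, k_len))  # 13 elements in the layout flashinfer expects
print(padded_numel, flat_numel)
```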

1 file changed: +24 −0

server/text_generation_server/layers/attention/flashinfer.py

Lines changed: 24 additions & 0 deletions
@@ -1,6 +1,7 @@
 from typing import Optional
 from contextvars import ContextVar
 from contextlib import contextmanager
+import math

 import flashinfer
 import torch
@@ -20,6 +21,20 @@
 workspace: Optional[torch.Tensor] = None


+def unpad_2d_mask(
+    attention_mask: torch.Tensor, seq_lengths: torch.Tensor
+) -> torch.Tensor:
+    # Like torch unpad_sequence, but for 2D masks.
+    unpadded_tensors = []
+    for i, length in enumerate(seq_lengths):
+        unpadded_matrix = attention_mask[i, :length, :length]
+        unpadded_tensors.append(unpadded_matrix.flatten())
+
+    packed_tensor = torch.cat(unpadded_tensors)
+
+    return packed_tensor
+
+
 def get_workspace(device):
     """Get shared flashinfer workspace."""
     global workspace
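
A usage sketch of the new helper on toy sizes (the function below is an equivalent standalone copy of `unpad_2d_mask` from the hunk above, so the snippet runs on its own):

```python
import torch


def unpad_2d_mask(attention_mask: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor:
    # Equivalent to the helper added above: keep the top-left [length, length]
    # block of each per-sequence mask, flatten each block, and concatenate.
    return torch.cat(
        [attention_mask[i, :length, :length].flatten() for i, length in enumerate(seq_lengths)]
    )


seq_lengths = torch.tensor([2, 3])               # per-sequence lengths
padded = torch.zeros(2, 3, 3, dtype=torch.bool)  # [batch_size, max_len, max_len]
padded[0, :2, :2] = True                         # valid region of sequence 0
padded[1, :3, :3] = True                         # valid region of sequence 1

flat = unpad_2d_mask(padded, seq_lengths)
print(flat.shape)         # torch.Size([13]) == 2*2 + 3*3
print(flat.all().item())  # True: only entries from the valid regions survive
```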
@@ -83,6 +98,15 @@ def use_prefill_with_paged_kv_state(
         last_page_len += 1

     token = prefill_with_paged_kv_state.set(state)
+
+    # Attention masks are padded, unpad.
+    if custom_mask is not None:
+        bs = input_lengths.shape[0]
+        seq_len = math.isqrt(custom_mask.numel() // bs)
+        custom_mask = unpad_2d_mask(
+            custom_mask.reshape(bs, seq_len, seq_len), input_lengths
+        )
+
     try:
         state.plan(
             qo_indptr=cu_seqlens,
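
And a small sanity check on the `math.isqrt` step (made-up sizes): the custom mask reaches this code padded, with `bs * seq_len * seq_len` elements in total, so the padded `seq_len` can be recovered before reshaping and unpadding.

```python
import math

import torch

bs, seq_len = 2, 3
custom_mask = torch.zeros(bs * seq_len * seq_len, dtype=torch.bool)  # flat, padded buffer
recovered = math.isqrt(custom_mask.numel() // bs)                    # 18 // 2 = 9 -> isqrt(9) = 3
assert recovered == seq_len
```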
