
Commit 429dcd9

[gaudi] Gemma3 sliding window support (#3280)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 9f38d93 commit 429dcd9

File tree

12 files changed: +389 -98 lines changed


backends/gaudi/server/text_generation_server/layers/attention/common.py

Lines changed: 57 additions & 0 deletions
@@ -2,6 +2,7 @@
 import torch
 from typing import Optional, List, Dict
 import collections
+import torch.nn.functional as F

 _TYPE_CACHE = {}

@@ -15,6 +16,12 @@ class HPUPagedAttentionMetadata:
     block_usage: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
     attn_bias: Optional[torch.Tensor]
+    slots_in_window_mask: Optional[torch.Tensor] = None
+    block_list_in_window: Optional[torch.Tensor] = None
+    block_mapping_in_window: Optional[torch.Tensor] = None
+    block_usage_in_window: Optional[torch.Tensor] = None
+    block_groups_in_window: Optional[torch.Tensor] = None
+    attn_bias_in_window: Optional[torch.Tensor] = None


 def subtuple(
@@ -67,6 +74,12 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
             "block_usage",
             "block_groups",
             "attn_bias",
+            "slots_in_window_mask",
+            "block_list_in_window",
+            "block_mapping_in_window",
+            "block_usage_in_window",
+            "block_groups_in_window",
+            "attn_bias_in_window",
         ],
     )
     return attention_metadata
@@ -75,6 +88,7 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
 @dataclass
 class Seqlen:
     input_lengths: torch.Tensor
+    attn_mask: Optional[torch.Tensor] = None

     def __init__(
         self,
@@ -86,6 +100,48 @@ def clamp(self, max):
         # Flash decoding doesn't need to clamp
         return self

+    def make_sliding_window_bias(
+        self,
+        seq_lens: List[int],
+        window_size: Optional[int],
+        dtype: torch.dtype,
+        padded_input_len: Optional[int],
+        padded_bs: Optional[int],
+    ) -> List[torch.Tensor]:
+        attn_biases = []
+        for seq_len in seq_lens:
+            if seq_len != 0:
+                tensor = torch.full(
+                    (1, seq_len, seq_len),
+                    dtype=dtype,
+                    fill_value=1,
+                )
+                shift = 0
+                mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
+                if window_size is not None:
+                    mask = torch.triu(mask, diagonal=shift - window_size + 1)
+                mask = F.pad(
+                    mask,
+                    (
+                        padded_input_len - seq_len,
+                        0,
+                        padded_input_len - seq_len,
+                        0,
+                        0,
+                        0,
+                    ),
+                    value=0,
+                )
+            else:
+                mask = torch.full(
+                    (1, padded_input_len, padded_input_len),
+                    dtype=dtype,
+                    fill_value=0,
+                )
+            attn_biases.append(mask)
+        attn_biases = torch.stack(attn_biases, dim=0)
+        return attn_biases.to(torch.bool)
+

 def _async_h2d_tensor_copy(source, device="hpu"):
     if source is None:
@@ -124,6 +180,7 @@ def trim_seqlen_metadata(metadata: Seqlen) -> object:
         "TrimmedSeqlen",
         [
             "input_lengths",
+            "attn_mask",
         ],
     )
     return attention_metadata
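
Note on the new Seqlen.make_sliding_window_bias helper above: for each sequence it builds a boolean causal mask, restricts it to a band covering the last window_size positions, and left-pads it to the bucketed input length (empty slots get an all-zero mask). A minimal standalone sketch of that mask construction, not taken from the commit, using only plain PyTorch:

    # Illustrative sketch of the banded (causal + sliding-window) mask logic.
    import torch
    import torch.nn.functional as F

    def sliding_window_band(seq_len: int, window_size: int, padded_len: int) -> torch.Tensor:
        # Causal lower-triangular mask over the real tokens...
        ones = torch.full((1, seq_len, seq_len), fill_value=1, dtype=torch.float32)
        mask = torch.tril(ones)
        # ...restricted so each query only sees its last `window_size` keys.
        mask = torch.triu(mask, diagonal=-window_size + 1)
        # Left-pad rows and columns up to the padded (bucketed) length.
        pad = padded_len - seq_len
        mask = F.pad(mask, (pad, 0, pad, 0, 0, 0), value=0)
        return mask.to(torch.bool)

    print(sliding_window_band(seq_len=4, window_size=2, padded_len=5).int())
    # tensor([[[0, 0, 0, 0, 0],
    #          [0, 1, 0, 0, 0],
    #          [0, 1, 1, 0, 0],
    #          [0, 0, 1, 1, 0],
    #          [0, 0, 0, 1, 1]]], dtype=torch.int32)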

backends/gaudi/server/text_generation_server/layers/attention/hpu.py

Lines changed: 33 additions & 7 deletions
@@ -94,13 +94,13 @@ def attention(
         query,
         key,
         value,
-        attn_mask=None,
+        attn_mask=seqlen.attn_mask if window_size_left != -1 else None,
         dropout_p=0.0,
-        is_causal=causal,
+        is_causal=causal if window_size_left == -1 else False,
         scale=softmax_scale,
         softmax_mode="None",
         recompute_mode=None,
-        valid_sequence_lengths=seqlen.input_lengths,
+        valid_sequence_lengths=seqlen.input_lengths if window_size_left == -1 else None,
         padding_side="left",
     )
     attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
@@ -119,6 +119,15 @@ def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size)
     hpu_attention_meta = hpu_attention_meta._replace(
         attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
     )
+    if hpu_attention_meta.block_groups_in_window is not None:
+        block_mapping = torch.nn.functional.one_hot(
+            hpu_attention_meta.block_groups_in_window, num_classes=batch_size
+        )
+        attn_bias = torch.log(hpu_attention_meta.slots_in_window_mask.float())
+        hpu_attention_meta = hpu_attention_meta._replace(
+            attn_bias_in_window=attn_bias,
+            block_mapping_in_window=block_mapping.to(dtype),
+        )
     return hpu_attention_meta


@@ -132,17 +141,34 @@ def paged_attention(
     kv_scales: KVScales,
     softcap: Optional[float] = None,
     hpu_attention_meta: HPUPagedAttentionMetadata,
+    window_size_left: int = -1,
 ):
     batch_size, head_num, head_size = query.shape
     fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
     output = ops.flat_pa(
         query=query.view(batch_size, 1, head_num * head_size),
         key_cache=kv_cache.key,
         value_cache=kv_cache.value,
-        block_list=hpu_attention_meta.block_list,
-        block_mapping=hpu_attention_meta.block_mapping,
-        block_bias=hpu_attention_meta.attn_bias,
-        block_groups=hpu_attention_meta.block_groups,
+        block_list=(
+            hpu_attention_meta.block_list
+            if window_size_left == -1
+            else hpu_attention_meta.block_list_in_window
+        ),
+        block_mapping=(
+            hpu_attention_meta.block_mapping
+            if window_size_left == -1
+            else hpu_attention_meta.block_mapping_in_window
+        ),
+        block_bias=(
+            hpu_attention_meta.attn_bias
+            if window_size_left == -1
+            else hpu_attention_meta.attn_bias_in_window
+        ),
+        block_groups=(
+            hpu_attention_meta.block_groups
+            if window_size_left == -1
+            else hpu_attention_meta.block_groups_in_window
+        ),
         block_size=BLOCK_SIZE,
         scale=softmax_scale,
         matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
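
Two details of the hpu.py changes above are easy to miss. In set_block_mapping, the boolean slots_in_window_mask is converted into an additive bias with torch.log, so in-window slots contribute 0 and out-of-window slots contribute -inf to the attention scores, mirroring how the existing attn_bias is built; paged_attention then simply swaps in the *_in_window block metadata whenever window_size_left != -1. A small illustrative sketch of the log-of-mask trick (not part of the commit):

    import torch

    # True -> log(1.0) = 0.0 (keep), False -> log(0.0) = -inf (mask out).
    slots_in_window_mask = torch.tensor([[True, True, False, False]])
    attn_bias = torch.log(slots_in_window_mask.float())
    print(attn_bias)  # tensor([[0., 0., -inf, -inf]])

    # Adding the bias before softmax zeroes the probability of masked slots.
    scores = torch.randn(1, 4)
    probs = torch.softmax(scores + attn_bias, dim=-1)
    print(probs)  # all probability mass lands on the first two (in-window) slots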

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py

Lines changed: 1 addition & 0 deletions
@@ -288,6 +288,7 @@ def forward(
                 softcap=self.softcap,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                window_size_left=self.window_size,
             )

         return self.o_proj(

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py

Lines changed: 4 additions & 4 deletions
@@ -135,9 +135,6 @@ def __init__(
         self.causal = causal
         if is_sliding:
             self.window_size = config.sliding_window
-            # TODO: remove this hack to support local sliding window
-            config = copy.deepcopy(config)
-            config.rope_scaling = dict(rope_type="default")
             self.rotary_emb = local_rotary_emb
         else:
             self.window_size = -1
@@ -267,6 +264,7 @@ def forward(
                 softcap=self.softcap,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                window_size_left=self.window_size,
             )

         return self.o_proj(
@@ -425,8 +423,10 @@ def __init__(self, prefix: str, config, weights, causal: bool):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
+        local_config = copy.deepcopy(config)
+        local_config.rope_scaling = dict(rope_type="default")
         local_rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
+            config=local_config,
             dim=config.head_dim,
             base=config.rope_local_base_freq,
             device=weights.device,
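
The Gemma3 change above drops the per-layer hack of deep-copying and mutating the shared config inside every sliding-window attention layer; instead, the model constructor builds one local_config copy with rope_scaling forced back to "default", and the local rotary embedding for sliding layers is created from that copy while the global-attention config stays untouched. A hedged sketch of the copy-then-override pattern; SimpleNamespace and the field values are stand-ins, not the real Gemma3 config class:

    import copy
    from types import SimpleNamespace

    # Stand-in for the shared model config (values are made up for illustration).
    config = SimpleNamespace(
        rope_scaling={"rope_type": "linear", "factor": 8.0},
        rope_local_base_freq=10000.0,
        head_dim=256,
    )

    # Sliding-window (local) layers use plain RoPE: copy first so the
    # global-attention settings are not mutated.
    local_config = copy.deepcopy(config)
    local_config.rope_scaling = dict(rope_type="default")

    assert config.rope_scaling["rope_type"] == "linear"          # original untouched
    assert local_config.rope_scaling == {"rope_type": "default"}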

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

Lines changed: 1 addition & 0 deletions
@@ -224,6 +224,7 @@ def forward(
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                window_size_left=self.max_past,
             )

         return self.o_proj(

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py

Lines changed: 4 additions & 1 deletion
@@ -62,7 +62,9 @@ def __init__(
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
+            config.sliding_window
+            if config.use_sliding_window and config.sliding_window is not None
+            else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -150,6 +152,7 @@ def forward(
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                window_size_left=self.max_past,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
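
The Qwen2 change above makes the sliding window opt-in: max_past (later forwarded to attention as window_size_left) only picks up config.sliding_window when config.use_sliding_window is also set, and falls back to -1 (no window) otherwise. A small sketch of that selection logic with stand-in config objects (SimpleNamespace, not the real Qwen2 config class):

    from types import SimpleNamespace

    def resolve_max_past(config) -> int:
        # -1 means "no sliding window" for window_size_left downstream.
        return (
            config.sliding_window
            if config.use_sliding_window and config.sliding_window is not None
            else -1
        )

    print(resolve_max_past(SimpleNamespace(use_sliding_window=True, sliding_window=4096)))   # 4096
    print(resolve_max_past(SimpleNamespace(use_sliding_window=False, sliding_window=4096)))  # -1
    print(resolve_max_past(SimpleNamespace(use_sliding_window=True, sliding_window=None)))   # -1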

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ def forward(
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                window_size_left=self.max_past,
             )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py

Lines changed: 1 addition & 0 deletions
@@ -190,6 +190,7 @@ def forward(
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                window_size_left=self.max_past,
             )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py

Lines changed: 1 addition & 0 deletions
@@ -280,6 +280,7 @@ def forward(
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                window_size_left=self.max_past,
             )

         return self.o_proj(
