Commit fc3db23

Modify decoders and attention for vLLM.

Remove calls into specialized attention modules, add a vllm_rpa unit test, fix additional unit tests, add validation support for vllm_rpa, rebase DeepSeek and GPT-OSS, add a skip for the vllm-tpu test, address review comments on lazy init, and add a check for kv_cache and attention_metadata.
1 parent 188c426 commit fc3db23

File tree

18 files changed: +341, -69 lines


src/MaxText/layers/attention_op.py

Lines changed: 1 addition & 0 deletions
@@ -852,6 +852,7 @@ def apply_attention(
         or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
         or (self.attention_kernel == "autoselected" and length < 128)
         or (self.attention_kernel == "paged")
+        or (self.attention_kernel == "vllm_rpa")
     ):
       return self.apply_attention_dot(
           query,
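For readers skimming the dispatch logic, the sketch below restates the fallback predicate touched above (a simplified standalone version, not the MaxText source; the clauses preceding this hunk are elided and MODEL_MODE_AUTOREGRESSIVE is stubbed). The new "vllm_rpa" branch sends apply_attention to the plain dot-product path whenever that kernel is selected but the dedicated vLLM path in attentions.py is not taken, presumably so training with attention="vllm_rpa" still works.

# Standalone sketch of the fallback condition in the hunk above (illustrative only).
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"

def falls_back_to_dot_attention(attention_kernel: str, model_mode: str, length: int) -> bool:
  return (
      (attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
      or (attention_kernel == "autoselected" and length < 128)
      or (attention_kernel == "paged")
      or (attention_kernel == "vllm_rpa")  # new in this commit
  )

assert falls_back_to_dot_attention("vllm_rpa", model_mode="train", length=2048)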

src/MaxText/layers/attentions.py

Lines changed: 53 additions & 1 deletion
@@ -864,6 +864,45 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
     )
     return [prefill_kv_cache, ar_kv_cache]

+  def forward_serve_vllm(
+      self, query: Array, key: Array, value: Array, rpa_kv_cache: list[Array], rpa_metadata: dict[str, Any]
+  ) -> tuple[list[Array], Array]:
+    """Forward function for vLLM serving with RPA attention."""
+    try:
+      # pylint: disable=import-outside-toplevel
+      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+    except ImportError as e:
+      raise ImportError(
+          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+      ) from e
+
+    if self.config.attention_sink:
+      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
+
+    if rpa_kv_cache is None or rpa_metadata is None:
+      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
+
+    query = query.reshape(-1, query.shape[2], query.shape[3])
+    key = key.reshape(-1, key.shape[2], key.shape[3])
+    value = value.reshape(-1, value.shape[2], value.shape[3])
+
+    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    q_scale, k_scale, v_scale = None, None, None
+
+    md = rpa_metadata
+
+    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+        query,
+        key,
+        value,
+        rpa_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    return kv_cache, output
+
   def __call__(
       self,
       inputs_q: Array,
@@ -878,6 +917,8 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
   ):
     """Applies Attention on the input data.

@@ -905,6 +946,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional KV cache input, used when invoking from vLLM.
+      attention_metadata: Optional mapping to store attention metadata, used when invoking from vLLM.

     Returns:
       output of shape `[batch, length, q_features]`.
@@ -1000,6 +1043,15 @@ def __call__(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
       )
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+
+    elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
+      batch, seq_len, num_heads, head_dim = query.shape
+      updated_kv, attn_out = self.forward_serve_vllm(
+          query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
+      kv_cache = updated_kv
+
     else:
       cached_values = [None, None]
       if model_mode != MODEL_MODE_TRAIN:
@@ -1028,4 +1080,4 @@ def __call__(
     out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache
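For context, here is a minimal caller-side sketch of the new serving contract. It assumes attention_metadata is an object exposing the four attributes read by forward_serve_vllm (seq_lens, block_tables, query_start_loc, request_distribution); the dataclass and the array values below are illustrative placeholders, not the tpu_inference definitions, and the commented call mirrors the call sites updated in decoders.py. Note that __call__ now always returns an (output, kv_cache) pair; on non-vLLM paths the kv_cache argument is passed back unchanged.

# Illustrative-only metadata container; field names follow the attributes accessed in
# forward_serve_vllm above, but the shapes and dtypes are guesses for the sketch.
from dataclasses import dataclass
import jax.numpy as jnp

@dataclass
class RPAMetadataSketch:
  seq_lens: jnp.ndarray              # per-request sequence lengths
  block_tables: jnp.ndarray          # per-request page indices into the paged KV cache
  query_start_loc: jnp.ndarray       # cumulative query offsets for the ragged batch layout
  request_distribution: jnp.ndarray  # scheduler bookkeeping forwarded to the kernel

metadata = RPAMetadataSketch(
    seq_lens=jnp.array([7, 3], dtype=jnp.int32),
    block_tables=jnp.zeros((2, 4), dtype=jnp.int32),
    query_start_loc=jnp.array([0, 7, 10], dtype=jnp.int32),
    request_distribution=jnp.array([0, 0, 2], dtype=jnp.int32),
)

# Intended call shape (requires a constructed Attention layer and a vLLM-managed paged cache):
# out, kv_cache = attention_layer(
#     lnx, lnx, decoder_positions,
#     decoder_segment_ids=decoder_segment_ids,
#     deterministic=True,
#     model_mode=MODEL_MODE_AUTOREGRESSIVE,
#     kv_cache=kv_cache,
#     attention_metadata=metadata,
# )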

src/MaxText/layers/decoders.py

Lines changed: 21 additions & 5 deletions
@@ -87,6 +87,8 @@ def __call__(
       previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
+      kv_cache: jax.Array | None = None,
+      attention_metadata: dict[str, Any] | None = None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -149,13 +151,15 @@ def __call__(
         model_mode=model_mode,
     )

-    attention_lnx = attention_layer(
+    attention_lnx, kv_cache = attention_layer(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     if model_mode == MODEL_MODE_PREFILL:
@@ -209,7 +213,7 @@ def __call__(
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )

-    return layer_output, None if cfg.scan_layers else layer_output
+    return layer_output, None if cfg.scan_layers else layer_output, kv_cache


 class SequentialBlockDecoderLayers(nn.Module):
@@ -684,6 +688,8 @@ def __call__(
       bidirectional_mask: None | Any = None,
       image_embeddings: None | jnp.ndarray = None,
       image_masks: None | jnp.ndarray = None,
+      kv_caches: list[jax.Array] | None = None,
+      attention_metadata=None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -837,7 +843,8 @@ def __call__(
       # Iterate over the two layer groups (dense and MoE) and apply layer transformation
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
         for index in range(num_layers):
-          y = layer(
+          kv_cache = kv_caches[index] if kv_caches is not None else None
+          y, kv_cache = layer(
               config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
           )(
               y,
@@ -848,7 +855,11 @@ def __call__(
               previous_chunk=previous_chunk,
               page_state=page_state,
               slot=slot,
+              kv_cache=kv_cache,
+              attention_metadata=attention_metadata,
           )
+          if kv_caches is not None:
+            kv_caches[index] = kv_cache
     else:
       for lyr in range(cfg.num_decoder_layers):
         RemattedBlockLayer = RemattedBlockLayers[0]
@@ -870,7 +881,8 @@ def __call__(
         layer = RemattedBlockLayer(
            config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
         )
-        y = layer(
+        kv_cache = kv_caches[lyr] if kv_caches is not None else None
+        y, kv_cache = layer(
            y,
            decoder_segment_ids,
            decoder_positions,
@@ -879,8 +891,12 @@ def __call__(
            previous_chunk=previous_chunk,
            page_state=page_state,
            slot=slot,
+            kv_cache=kv_cache,
+            attention_metadata=attention_metadata,
            **layer_call_kwargs,
         )
+        if kv_caches is not None:
+          kv_caches[lyr] = kv_cache

     assert isinstance(y, jax.Array)

@@ -897,7 +913,7 @@ def __call__(

     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.
-    return logits, hidden_state
+    return logits, hidden_state, kv_caches

   def _apply_gemma3_scanned_blocks(
       self,
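The cache-threading pattern added to both the scanned-group loop and the plain layer loop above can be summarized with a small self-contained sketch (fake_layer is a stand-in, not MaxText code): when kv_caches is provided, each layer receives its own entry and the updated cache is written back into the list; when it is None, layers see kv_cache=None and nothing is stored, so training behavior is unchanged. The Decoder itself now returns that list as a third element alongside logits and the hidden state, so a serving caller can hand the updated caches back to the vLLM runner.

# Self-contained sketch of the per-layer KV-cache threading shown in the hunks above.
def fake_layer(y, kv_cache=None, attention_metadata=None):
  # Stand-in for a MaxText decoder layer: returns (output, possibly-updated cache).
  updated = (kv_cache or []) + ["page"]
  return y, updated

def run_layers(y, num_layers, kv_caches=None, attention_metadata=None):
  for index in range(num_layers):
    kv_cache = kv_caches[index] if kv_caches is not None else None
    y, kv_cache = fake_layer(y, kv_cache=kv_cache, attention_metadata=attention_metadata)
    if kv_caches is not None:
      kv_caches[index] = kv_cache  # write the updated cache back for the caller
  return y, kv_caches

_, caches = run_layers("tokens", num_layers=2, kv_caches=[[], []])
assert caches == [["page"], ["page"]]
_, no_caches = run_layers("tokens", num_layers=2)  # training path: caches stay None
assert no_caches is None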

src/MaxText/layers/deepseek.py

Lines changed: 6 additions & 2 deletions
@@ -127,7 +127,7 @@ def self_attention_with_norm(
   return hidden_states, intermediate_inputs


-def post_process(cfg, layer_output, sow):
+def post_process(cfg, layer_output, sow, kv_cache=None):
   """postprocessing."""
   if cfg.record_internal_nn_metrics:
     sow("intermediates", "activation_mean", jnp.mean(layer_output))
@@ -141,7 +141,7 @@ def post_process(cfg, layer_output, sow):
   if cfg.scan_layers:
     return layer_output, None
   else:
-    return layer_output
+    return layer_output, kv_cache


 class DeepSeekDenseLayer(nn.Module):
@@ -163,6 +163,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
@@ -230,6 +232,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
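A tiny standalone sketch of the updated post_process contract (not MaxText code): on the scan path the second element remains the scan carry (None), while the non-scan path now also returns the kv_cache so it can be threaded upward. This suggests the vLLM serving path is expected to run with scan_layers disabled, since the cache is not propagated through the scan branch.

# Standalone sketch of the post_process return contract shown in the hunk above.
from types import SimpleNamespace

def post_process_sketch(cfg, layer_output, kv_cache=None):
  if cfg.scan_layers:
    return layer_output, None      # scan carry, unchanged by this commit
  return layer_output, kv_cache    # non-scan path now also hands the cache back

serve_cfg = SimpleNamespace(scan_layers=False)
out, cache = post_process_sketch(serve_cfg, layer_output="activations", kv_cache={"pages": [0, 1]})
assert cache == {"pages": [0, 1]}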

src/MaxText/layers/gemma.py

Lines changed: 6 additions & 2 deletions
@@ -129,6 +129,8 @@ def __call__(
       page_manager=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -137,13 +139,15 @@ def __call__(

     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -177,7 +181,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 GemmaDecoderLayerToLinen = nnx_wrappers.to_linen_class(

src/MaxText/layers/gemma2.py

Lines changed: 6 additions & 2 deletions
@@ -223,20 +223,24 @@ def __call__(
       previous_chunk=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
     lnx = self.pre_self_attention_norm_local(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention_local(
+    attention_lnx, kv_cache = self.self_attention_local(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if self.config.use_post_attn_norm:
       attention_lnx = self.post_self_attention_norm_local(attention_lnx)
@@ -311,7 +315,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 Gemma2DecoderLayerToLinen = nnx_wrappers.to_linen_class(

src/MaxText/layers/gemma3.py

Lines changed: 6 additions & 2 deletions
@@ -189,6 +189,8 @@ def __call__(
       page_state=None,
       slot=None,
       bidirectional_mask=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
@@ -198,14 +200,16 @@ def __call__(
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

     # Self-attention block
-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
         bidirectional_mask=bidirectional_mask,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if cfg.use_post_attn_norm:
       attention_lnx = self.post_self_attention_norm(attention_lnx)
@@ -240,7 +244,7 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 Gemma3DecoderLayerToLinen = nnx_wrappers.to_linen_class(

src/MaxText/layers/gpt3.py

Lines changed: 13 additions & 4 deletions
@@ -271,6 +271,8 @@ def __call__(
       *,
       model_mode: str = MODEL_MODE_TRAIN,
       deterministic: bool = False,
+      kv_cache: Array = None,
+      attention_metadata: dict[str, Any] = None,
   ):
     inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
     if self.fused_qkv:
@@ -312,7 +314,7 @@ def __call__(
     # apply output projection, output dim is set to the input dim.
     out = self.out_projection(inputs_q.shape[-1], out)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache


 # -----------------------------------------
@@ -339,6 +341,8 @@ def __call__(
       previous_chunk=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -381,8 +385,13 @@ def __call__(
         kv_quant=quantizations.configure_kv_quant(cfg),
     )

-    attention_lnx = attention_layer(
-        lnx, decoder_segment_ids=decoder_segment_ids, model_mode=model_mode, deterministic=deterministic
+    attention_lnx, kv_cache = attention_layer(
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        model_mode=model_mode,
+        deterministic=deterministic,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(
@@ -428,4 +437,4 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache
