
Commit 2adf1bb

Modifying decoders and attention for vLLM.

- removing calls into specialized attention modules
- adding vllm_rpa unit test; fixing additional unit tests
- adding validation support for vllm_rpa
- rebasing DeepSeek and GPT-OSS
- adding skip for vllm-tpu test
- addressing comments on lazy init
- adding check for kv_cache and attention_metadata
- adding comment on vllm_rpa
- adding pyconfig deprecated validation
- fixing pytype errors
- adding new output type to the Qwen3-Omni vision encoder
- fixing DeepSeek batchsplit
1 parent 8104f65 commit 2adf1bb

20 files changed: +370 -83 lines changed

src/MaxText/layers/attention_mla.py

Lines changed: 6 additions & 2 deletions
@@ -672,7 +672,9 @@ def __call__(
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Optional[Any] = None,
       rope_kwargs: dict | None = None,
-  ) -> Array:
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
+  ) -> tuple[Array, Optional[Array]]:
     """Forward pass for MLA, reusing `AttentionOp` for the actual attention.

     Args:
@@ -686,6 +688,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional key-value cache used when serving models with vLLM.
+      attention_metadata: Optional attention-related metadata used when serving models with vLLM.

     Returns:
       A tensor of shape [batch, length, embed_dim] containing the
@@ -726,4 +730,4 @@ def __call__(

     out = self.out_projection(out)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache

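Note on the new contract: attention layers now return an (output, kv_cache) tuple instead of a bare array, so every caller unpacks two values even when it discards the cache. A minimal sketch of the calling convention, using a hypothetical stub layer rather than the real MLA/Attention modules:

    # Sketch only: `StubAttention` is a hypothetical stand-in for the MLA/Attention
    # layers, illustrating the (output, kv_cache) return contract added in this commit.
    from typing import Any, Optional

    import jax.numpy as jnp


    class StubAttention:

      def __call__(self, inputs, kv_cache: Optional[Any] = None, attention_metadata: Optional[dict] = None):
        out = inputs  # real layers compute attention here
        # Outside the vLLM path the cache is simply passed back unchanged.
        return out, kv_cache


    layer = StubAttention()
    x = jnp.ones((2, 8, 16))

    # Training / non-vLLM callers discard the cache.
    out, _ = layer(x)

    # vLLM serving callers thread the cache and metadata in and take the updated cache back out.
    out, updated_cache = layer(x, kv_cache=None, attention_metadata=None)
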
src/MaxText/layers/attention_op.py

Lines changed: 3 additions & 0 deletions
@@ -847,11 +847,14 @@ def apply_attention(
         raise NotImplementedError(target_hardware)
       return impl(query, key, value, lengths, self.ragged_block_size)

+    # 'vllm_rpa' uses the same dot-attention wrapper but routes to the vLLM
+    # ragged paged attention kernel in `Attention.__call__`.
     elif (
         self.attention_kernel == "dot_product"
         or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
         or (self.attention_kernel == "autoselected" and length < 128)
         or (self.attention_kernel == "paged")
+        or (self.attention_kernel == "vllm_rpa")
     ):
       return self.apply_attention_dot(
           query,

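For reference, a standalone sketch of the kernel-selection predicate after this change. The helper function and the literal value of the mode constant are assumptions for illustration; the condition itself mirrors the branch in `AttentionOp.apply_attention`:

    # Hypothetical helper mirroring the dispatch condition above: "vllm_rpa" joins the
    # kernels that fall through to the dot-attention wrapper, and the actual ragged
    # paged attention call happens later in Attention.__call__.
    MODEL_MODE_AUTOREGRESSIVE = "autoregressive"  # assumed stand-in for the MaxText constant


    def uses_dot_attention_wrapper(attention_kernel: str, model_mode: str, length: int) -> bool:
      return (
          attention_kernel == "dot_product"
          or (attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
          or (attention_kernel == "autoselected" and length < 128)
          or attention_kernel == "paged"
          or attention_kernel == "vllm_rpa"
      )


    assert uses_dot_attention_wrapper("vllm_rpa", "prefill", 2048)
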
src/MaxText/layers/attentions.py

Lines changed: 59 additions & 1 deletion
@@ -889,6 +889,51 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
     )
     return [prefill_kv_cache, ar_kv_cache]

+  def forward_serve_vllm(
+      self,
+      query: Array,
+      key: Array,
+      value: Array,
+      rpa_kv_cache: list[Array] | None = None,
+      rpa_metadata: dict[str, Any] | None = None,
+  ) -> tuple[list[Array], Array]:
+    """Forward function for vLLM serving with RPA attention."""
+    try:
+      # pylint: disable=import-outside-toplevel
+      # pytype: disable=import-error
+      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+    except ImportError as e:
+      raise ImportError(
+          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+      ) from e
+
+    if self.config.attention_sink:
+      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
+
+    if rpa_kv_cache is None or rpa_metadata is None:
+      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
+
+    query = query.reshape(-1, query.shape[2], query.shape[3])
+    key = key.reshape(-1, key.shape[2], key.shape[3])
+    value = value.reshape(-1, value.shape[2], value.shape[3])
+
+    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    q_scale, k_scale, v_scale = None, None, None
+
+    md = rpa_metadata
+
+    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+        query,
+        key,
+        value,
+        rpa_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    return kv_cache, output
+
   def __call__(
       self,
       inputs_q: Array,
@@ -904,6 +949,8 @@ def __call__(
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
       rope_kwargs: dict | None = None,
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
   ):
     """Applies Attention on the input data.

@@ -931,6 +978,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional KV cache input, used when invoking from vLLM.
+      attention_metadata: Optional mapping to store attention metadata, used when invoking from vLLM.

     Returns:
       output of shape `[batch, length, q_features]`.
@@ -1026,6 +1075,15 @@ def __call__(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
       )
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+
+    elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
+      batch, seq_len, num_heads, head_dim = query.shape
+      updated_kv, attn_out = self.forward_serve_vllm(
+          query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
+      kv_cache = updated_kv
+
     else:
       cached_values = [None, None]
       if model_mode != MODEL_MODE_TRAIN:
@@ -1054,4 +1112,4 @@ def __call__(
     out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache

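`forward_serve_vllm` flattens the `[batch, seq, heads, head_dim]` activations into the token-major `[tokens, heads, head_dim]` layout expected by the ragged paged attention kernel, and `Attention.__call__` reshapes the flat output back. A small sketch of that reshape round-trip; the metadata dataclass is a hypothetical stand-in for the object supplied by tpu_inference, and only its attribute names (`seq_lens`, `block_tables`, `query_start_loc`, `request_distribution`) are taken from the diff, with made-up placeholder values:

    # Sketch, not the real kernel call: the flattening done in forward_serve_vllm and
    # the reshape back performed by Attention.__call__ on the "vllm_rpa" path.
    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp


    @dataclass
    class FakeRpaMetadata:  # hypothetical stand-in; field names match the kernel call in the diff
      seq_lens: jax.Array
      block_tables: jax.Array
      query_start_loc: jax.Array
      request_distribution: jax.Array


    md = FakeRpaMetadata(  # placeholder values only, not meaningful vLLM state
        seq_lens=jnp.zeros((2,), dtype=jnp.int32),
        block_tables=jnp.zeros((2, 1), dtype=jnp.int32),
        query_start_loc=jnp.zeros((3,), dtype=jnp.int32),
        request_distribution=jnp.zeros((3,), dtype=jnp.int32),
    )

    batch, seq_len, num_heads, head_dim = 2, 4, 8, 64
    query = jnp.ones((batch, seq_len, num_heads, head_dim))

    # Token-major layout handed to the ragged paged attention kernel.
    flat_query = query.reshape(-1, query.shape[2], query.shape[3])
    assert flat_query.shape == (batch * seq_len, num_heads, head_dim)

    # The kernel returns a flat [tokens, heads, head_dim] output ...
    flat_out = flat_query  # placeholder for the attention output
    # ... which the caller restores to the decoder's [batch, seq, heads, head_dim] layout.
    out = flat_out.reshape(batch, seq_len, num_heads, head_dim)
    assert out.shape == query.shape
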
src/MaxText/layers/decoders.py

Lines changed: 25 additions & 7 deletions
@@ -87,6 +87,8 @@ def __call__(
       previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
+      kv_cache: jax.Array | None = None,
+      attention_metadata: dict[str, Any] | None = None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -149,13 +151,15 @@ def __call__(
         model_mode=model_mode,
     )

-    attention_lnx = attention_layer(
+    attention_lnx, kv_cache = attention_layer(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     if model_mode == MODEL_MODE_PREFILL:
@@ -209,7 +213,10 @@ def __call__(
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )

-    return layer_output, None if cfg.scan_layers else layer_output
+    if cfg.scan_layers:
+      return layer_output, None
+    else:
+      return layer_output, kv_cache


 class SequentialBlockDecoderLayers(nn.Module):
@@ -691,6 +698,8 @@ def __call__(
       bidirectional_mask: None | Any = None,
       image_embeddings: None | jnp.ndarray = None,
       image_masks: None | jnp.ndarray = None,
+      kv_caches: list[jax.Array] | None = None,
+      attention_metadata=None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -844,7 +853,8 @@ def __call__(
       # Iterate over the two layer groups (dense and MoE) and apply layer transformation
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
         for index in range(num_layers):
-          y = layer(
+          kv_cache = kv_caches[index] if kv_caches is not None else None
+          y, kv_cache = layer(
               config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
           )(
               y,
@@ -855,7 +865,11 @@ def __call__(
               previous_chunk=previous_chunk,
               page_state=page_state,
               slot=slot,
+              kv_cache=kv_cache,
+              attention_metadata=attention_metadata,
           )
+          if kv_caches is not None and kv_cache is not None:
+            kv_caches[index] = kv_cache
     else:
       for lyr in range(cfg.num_decoder_layers):
         RemattedBlockLayer = RemattedBlockLayers[0]
@@ -877,7 +891,8 @@ def __call__(
         layer = RemattedBlockLayer(
             config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
         )
-        y = layer(
+        kv_cache = kv_caches[lyr] if kv_caches is not None else None
+        y, kv_cache = layer(
            y,
            decoder_segment_ids,
            decoder_positions,
@@ -886,8 +901,12 @@ def __call__(
            previous_chunk=previous_chunk,
            page_state=page_state,
            slot=slot,
+           kv_cache=kv_cache,
+           attention_metadata=attention_metadata,
            **layer_call_kwargs,
        )
+        if kv_caches is not None and kv_cache is not None:
+          kv_caches[lyr] = kv_cache

     assert isinstance(y, jax.Array)

@@ -904,7 +923,7 @@ def __call__(

     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.
-    return logits, hidden_state
+    return logits, hidden_state, kv_caches

   def _apply_gemma3_scanned_blocks(
       self,
@@ -957,10 +976,9 @@ def _apply_gemma3_scanned_blocks(
     if num_remaining_layers > 0:
       # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions
       rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
-      # pytype: disable=wrong-keyword-args
       layer = RemattedGemma3Block(
           config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
-      )
+      )  # pytype: disable=wrong-keyword-args
       y, _ = layer(
           y,
           decoder_segment_ids,

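The unscanned decoder path now carries a list of per-layer caches: each layer receives `kv_caches[i]`, returns a possibly updated cache, and the list entry is written back only when both sides are non-None. A minimal sketch of that loop with a hypothetical stub layer:

    # Hypothetical stub illustrating how Decoder threads kv_caches through unscanned layers.
    from typing import Any, Optional

    import jax.numpy as jnp


    def stub_layer(y, kv_cache: Optional[Any] = None, attention_metadata: Optional[dict] = None):
      """Stand-in for a decoder layer: returns (hidden_state, possibly updated kv_cache)."""
      return y + 1.0, kv_cache


    num_decoder_layers = 4
    kv_caches = [jnp.zeros((1,)) for _ in range(num_decoder_layers)]  # would be None outside vLLM serving
    y = jnp.zeros((2, 8, 16))

    for lyr in range(num_decoder_layers):
      kv_cache = kv_caches[lyr] if kv_caches is not None else None
      y, kv_cache = stub_layer(y, kv_cache=kv_cache, attention_metadata=None)
      if kv_caches is not None and kv_cache is not None:
        kv_caches[lyr] = kv_cache

    # The decoder then returns the (possibly updated) cache list alongside the logits
    # and hidden state, matching `return logits, hidden_state, kv_caches` above.
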
src/MaxText/layers/deepseek.py

Lines changed: 7 additions & 3 deletions
@@ -99,7 +99,7 @@ def self_attention_with_norm(
       model_mode=model_mode,
   )

-  attention_lnx = attention_layer(
+  attention_lnx, _ = attention_layer(
       lnx,
       lnx,
       decoder_positions,
@@ -127,7 +127,7 @@ def self_attention_with_norm(
   return hidden_states, intermediate_inputs


-def post_process(cfg, layer_output, sow):
+def post_process(cfg, layer_output, sow, kv_cache=None):
   """postprocessing."""
   if cfg.record_internal_nn_metrics:
     sow("intermediates", "activation_mean", jnp.mean(layer_output))
@@ -141,7 +141,7 @@ def post_process(cfg, layer_output, sow):
   if cfg.scan_layers:
     return layer_output, None
   else:
-    return layer_output
+    return layer_output, kv_cache


 class DeepSeekDenseLayer(nn.Module):
@@ -163,6 +163,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
@@ -230,6 +232,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:

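`post_process` now mirrors the layer return contract: under `scan_layers` the second element stays None (presumably to keep the scanned carry/output shape unchanged), otherwise the layer's cache is passed through. A tiny sketch of that contract, with the function name marked as hypothetical:

    # Sketch of the updated post_process contract in deepseek.py.
    def post_process_sketch(scan_layers: bool, layer_output, kv_cache=None):
      if scan_layers:
        return layer_output, None
      return layer_output, kv_cache


    assert post_process_sketch(True, "out", "cache") == ("out", None)
    assert post_process_sketch(False, "out", "cache") == ("out", "cache")
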
src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def attention(
             previous_chunk=previous_chunk,
             page_state=page_state,
             slot=slot,
-        )
+        )[0]
     )

   def mlp_layer(self):

src/MaxText/layers/gemma.py

Lines changed: 6 additions & 2 deletions
@@ -129,6 +129,8 @@ def __call__(
       page_manager=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -137,13 +139,15 @@ def __call__(

     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -177,7 +181,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 GemmaDecoderLayerToLinen = nnx_wrappers.to_linen_class(

src/MaxText/layers/gemma2.py

Lines changed: 7 additions & 3 deletions
@@ -223,20 +223,24 @@ def __call__(
       previous_chunk=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
     lnx = self.pre_self_attention_norm_local(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention_local(
+    attention_lnx, kv_cache = self.self_attention_local(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if self.config.use_post_attn_norm:
       attention_lnx = self.post_self_attention_norm_local(attention_lnx)
@@ -268,7 +272,7 @@ def __call__(
     lnx = self.pre_self_attention_norm_global(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention_global(
+    attention_lnx, kv_cache = self.self_attention_global(
         lnx,
         lnx,
         decoder_positions,
@@ -311,7 +315,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 Gemma2DecoderLayerToLinen = nnx_wrappers.to_linen_class(

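In the Gemma2 layer both the local and the global attention now return a cache into the same `kv_cache` name, so the value the layer finally returns is whatever the second (global) attention produced. A toy sketch of that rebinding, using hypothetical stubs:

    # Hypothetical stubs showing how successive attention calls rebind kv_cache in the
    # Gemma2 decoder layer; the layer returns the cache from the last (global) call.
    def self_attention_local(x, kv_cache=None, attention_metadata=None):
      return x, {"which": "local"}


    def self_attention_global(x, kv_cache=None, attention_metadata=None):
      return x, {"which": "global"}


    def gemma2_layer_sketch(x, kv_cache=None, attention_metadata=None, scan_layers=False):
      x, kv_cache = self_attention_local(x, kv_cache=kv_cache, attention_metadata=attention_metadata)
      x, kv_cache = self_attention_global(x, kv_cache=kv_cache, attention_metadata=attention_metadata)
      if scan_layers:
        return x, None
      return x, kv_cache


    _, cache = gemma2_layer_sketch(0.0)
    assert cache == {"which": "global"}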