
Commit f6ead2e

Modifying decoders and attention for vLLM.

- Removing calls into specialized attention modules.
- Adding a vllm_rpa unit test and fixing additional unit tests.
- Adding validation support for vllm_rpa.
- Rebasing deepseek and gpt-oss.
- Adding a skip for the vllm-tpu test.
- Addressing comments on lazy init.
- Adding a check for kv_cache and attention_metadata.
- Adding a comment on vllm_rpa.
- Adding pyconfig deprecated validation.
- Fixing pytype errors.
- Adding a new output type to the Qwen3-Omni vision encoder.
- Fixing deepseek batchsplit.
1 parent 8104f65 commit f6ead2e

20 files changed: +378 −87 lines
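
A minimal runnable sketch of the calling convention this commit introduces (names are illustrative, not MaxText's exact signatures): attention layers and decoder layers now accept an optional kv_cache / attention_metadata pair and return an (output, kv_cache) tuple instead of a bare array.

import jax.numpy as jnp

def attention_layer_sketch(x, kv_cache=None, attention_metadata=None):
  out = x  # stand-in for the real attention computation
  return out, kv_cache  # the cache passes through unchanged outside of vLLM serving

out, kv_cache = attention_layer_sketch(jnp.ones((2, 8, 512)))  # training/prefill path
assert kv_cache is None  # callers that don't serve through vLLM simply discard it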

src/MaxText/layers/attention_mla.py

Lines changed: 6 additions & 2 deletions
@@ -672,7 +672,9 @@ def __call__(
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Optional[Any] = None,
       rope_kwargs: dict | None = None,
-  ) -> Array:
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
+  ) -> tuple[Array, Optional[Array]]:
     """Forward pass for MLA, reusing `AttentionOp` for the actual attention.

     Args:
@@ -686,6 +688,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional key-value cache used when serving models with vLLM.
+      attention_metadata: Optional attention-related metadata used when serving models with vLLM.

     Returns:
       A tensor of shape [batch, length, embed_dim] containing the
@@ -726,4 +730,4 @@ def __call__(

     out = self.out_projection(out)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache

src/MaxText/layers/attention_op.py

Lines changed: 3 additions & 0 deletions
@@ -847,11 +847,14 @@ def apply_attention(
         raise NotImplementedError(target_hardware)
       return impl(query, key, value, lengths, self.ragged_block_size)

+    # 'vllm_rpa' uses the same dot-attention wrapper but routes to the vLLM
+    # ragged paged attention kernel in `Attention.__call__`.
     elif (
         self.attention_kernel == "dot_product"
         or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
        or (self.attention_kernel == "autoselected" and length < 128)
        or (self.attention_kernel == "paged")
+        or (self.attention_kernel == "vllm_rpa")
     ):
       return self.apply_attention_dot(
           query,
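
A runnable sketch of the kernel dispatch after this change (simplified; the string constants stand in for MaxText's). Accepting "vllm_rpa" in the dot-product branch keeps non-serving code paths (e.g. training) working, while serving is intercepted earlier in Attention.__call__ and sent to the vLLM RPA kernel.

MODEL_MODE_AUTOREGRESSIVE = "autoregressive"

def pick_attention_impl(attention_kernel: str, model_mode: str, length: int) -> str:
  if attention_kernel in ("dot_product", "paged", "vllm_rpa") or (
      attention_kernel == "autoselected"
      and (model_mode == MODEL_MODE_AUTOREGRESSIVE or length < 128)
  ):
    return "dot_product"
  return "flash_or_ragged"  # placeholder for the remaining kernel choices

assert pick_attention_impl("vllm_rpa", "train", 2048) == "dot_product"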

src/MaxText/layers/attentions.py

Lines changed: 59 additions & 1 deletion
@@ -889,6 +889,51 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
     )
     return [prefill_kv_cache, ar_kv_cache]

+  def forward_serve_vllm(
+      self,
+      query: Array,
+      key: Array,
+      value: Array,
+      rpa_kv_cache: list[Array] | None = None,
+      rpa_metadata: dict[str, Any] | None = None,
+  ) -> tuple[list[Array], Array]:
+    """Forward function for vLLM serving with RPA attention."""
+    try:
+      # pylint: disable=import-outside-toplevel
+      # pytype: disable=import-error
+      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+    except ImportError as e:
+      raise ImportError(
+          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+      ) from e
+
+    if self.config.attention_sink:
+      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
+
+    if rpa_kv_cache is None or rpa_metadata is None:
+      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
+
+    query = query.reshape(-1, query.shape[2], query.shape[3])
+    key = key.reshape(-1, key.shape[2], key.shape[3])
+    value = value.reshape(-1, value.shape[2], value.shape[3])
+
+    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    q_scale, k_scale, v_scale = None, None, None
+
+    md = rpa_metadata
+
+    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+        query,
+        key,
+        value,
+        rpa_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    return kv_cache, output
+
   def __call__(
       self,
       inputs_q: Array,
@@ -904,6 +949,8 @@ def __call__(
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
       rope_kwargs: dict | None = None,
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
   ):
     """Applies Attention on the input data.

@@ -931,6 +978,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional KV cache input, used when invoking from vLLM.
+      attention_metadata: Optional mapping to store attention metadata, used when invoking from vLLM.

     Returns:
       output of shape `[batch, length, q_features]`.
@@ -1026,6 +1075,15 @@ def __call__(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
       )
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+
+    elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
+      batch, seq_len, num_heads, head_dim = query.shape
+      updated_kv, attn_out = self.forward_serve_vllm(
+          query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
+      kv_cache = updated_kv
+
     else:
       cached_values = [None, None]
       if model_mode != MODEL_MODE_TRAIN:
@@ -1054,4 +1112,4 @@ def __call__(
     out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache
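
The new forward_serve_vllm flattens [batch, seq, heads, head_dim] activations to [tokens, heads, head_dim] before calling the RPA kernel, and reads four fields off the metadata object supplied by the vLLM/tpu_inference runner. A hedged sketch of what that metadata needs to expose; the field descriptions are inferred from common vLLM ragged-paged-attention conventions, not from this diff, and the class name is hypothetical.

from dataclasses import dataclass
import jax
import jax.numpy as jnp

@dataclass
class RPAMetadataSketch:
  # Fields forward_serve_vllm reads; the real object comes from the serving runner.
  seq_lens: jax.Array              # per-request sequence lengths (assumed meaning)
  block_tables: jax.Array          # per-request page indices into the paged KV cache (assumed)
  query_start_loc: jax.Array       # cumulative token offsets of each request in the flat batch (assumed)
  request_distribution: jax.Array  # prefill/decode split of the batch (assumed)

# Shape bookkeeping around the kernel call: [batch, seq, heads, head_dim] -> [tokens, heads, head_dim].
q = jnp.zeros((2, 8, 4, 128))
q_flat = q.reshape(-1, q.shape[2], q.shape[3])  # (16, 4, 128), what the RPA kernel consumes
out = q_flat.reshape(2, 8, 4, 128)              # the caller reshapes the kernel output back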

src/MaxText/layers/decoders.py

Lines changed: 25 additions & 7 deletions
@@ -87,6 +87,8 @@ def __call__(
       previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
+      kv_cache: jax.Array | None = None,
+      attention_metadata: dict[str, Any] | None = None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -149,13 +151,15 @@ def __call__(
         model_mode=model_mode,
     )

-    attention_lnx = attention_layer(
+    attention_lnx, kv_cache = attention_layer(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     if model_mode == MODEL_MODE_PREFILL:
@@ -209,7 +213,10 @@ def __call__(
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )

-    return layer_output, None if cfg.scan_layers else layer_output
+    if cfg.scan_layers:
+      return layer_output, None
+    else:
+      return layer_output, kv_cache


 class SequentialBlockDecoderLayers(nn.Module):
@@ -691,6 +698,8 @@ def __call__(
       bidirectional_mask: None | Any = None,
       image_embeddings: None | jnp.ndarray = None,
       image_masks: None | jnp.ndarray = None,
+      kv_caches: list[jax.Array] | None = None,
+      attention_metadata=None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -844,7 +853,8 @@ def __call__(
       # Iterate over the two layer groups (dense and MoE) and apply layer transformation
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
         for index in range(num_layers):
-          y = layer(
+          kv_cache = kv_caches[index] if kv_caches is not None else None
+          y, kv_cache = layer(
               config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
           )(
               y,
@@ -855,7 +865,11 @@ def __call__(
               previous_chunk=previous_chunk,
               page_state=page_state,
               slot=slot,
+              kv_cache=kv_cache,
+              attention_metadata=attention_metadata,
           )
+          if kv_caches is not None and kv_cache is not None:
+            kv_caches[index] = kv_cache
     else:
       for lyr in range(cfg.num_decoder_layers):
         RemattedBlockLayer = RemattedBlockLayers[0]
@@ -877,7 +891,8 @@ def __call__(
         layer = RemattedBlockLayer(
             config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
         )
-        y = layer(
+        kv_cache = kv_caches[lyr] if kv_caches is not None else None
+        y, kv_cache = layer(
            y,
            decoder_segment_ids,
            decoder_positions,
@@ -886,8 +901,12 @@ def __call__(
            previous_chunk=previous_chunk,
            page_state=page_state,
            slot=slot,
+            kv_cache=kv_cache,
+            attention_metadata=attention_metadata,
            **layer_call_kwargs,
         )
+        if kv_caches is not None and kv_cache is not None:
+          kv_caches[lyr] = kv_cache

     assert isinstance(y, jax.Array)

@@ -904,7 +923,7 @@ def __call__(

     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.
-    return logits, hidden_state
+    return logits, hidden_state, kv_caches

   def _apply_gemma3_scanned_blocks(
       self,
@@ -957,10 +976,9 @@ def _apply_gemma3_scanned_blocks(
     if num_remaining_layers > 0:
       # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions
       rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
-      # pytype: disable=wrong-keyword-args
       layer = RemattedGemma3Block(
           config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
-      )
+      )  # pytype: disable=wrong-keyword-args
       y, _ = layer(
           y,
           decoder_segment_ids,
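
A runnable sketch of the per-layer cache threading the decoder loop now performs (hypothetical standalone version; the real call also passes positions, segment ids, page state, etc.).

def layer_sketch(y, kv_cache=None, attention_metadata=None):
  return y + 1, kv_cache  # stand-ins for the layer output and the (possibly updated) cache

def run_layers(num_layers, y, kv_caches=None, attention_metadata=None):
  for i in range(num_layers):
    kv_cache = kv_caches[i] if kv_caches is not None else None
    y, kv_cache = layer_sketch(y, kv_cache=kv_cache, attention_metadata=attention_metadata)
    if kv_caches is not None and kv_cache is not None:
      kv_caches[i] = kv_cache  # hand the updated cache back so the serving runner sees it
  return y, kv_caches

out, caches = run_layers(4, 0, kv_caches=[1, 2, 3, 4])
assert out == 4 and caches == [1, 2, 3, 4]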

src/MaxText/layers/deepseek.py

Lines changed: 7 additions & 3 deletions
@@ -99,7 +99,7 @@ def self_attention_with_norm(
       model_mode=model_mode,
   )

-  attention_lnx = attention_layer(
+  attention_lnx, _ = attention_layer(
       lnx,
       lnx,
       decoder_positions,
@@ -127,7 +127,7 @@ def self_attention_with_norm(
   return hidden_states, intermediate_inputs


-def post_process(cfg, layer_output, sow):
+def post_process(cfg, layer_output, sow, kv_cache=None):
   """postprocessing."""
   if cfg.record_internal_nn_metrics:
     sow("intermediates", "activation_mean", jnp.mean(layer_output))
@@ -141,7 +141,7 @@ def post_process(cfg, layer_output, sow):
   if cfg.scan_layers:
     return layer_output, None
   else:
-    return layer_output
+    return layer_output, kv_cache


 class DeepSeekDenseLayer(nn.Module):
@@ -163,6 +163,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
@@ -230,6 +232,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
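
A runnable sketch of the updated post_process return convention: under scan_layers the second element stays None (it is the per-layer output slot of the scanned block, by my reading of the surrounding code), while the unscanned path forwards the kv_cache so the decoder can hand it back to the caller.

def post_process_sketch(scan_layers: bool, layer_output, kv_cache=None):
  if scan_layers:
    return layer_output, None  # keep the scanned signature unchanged
  return layer_output, kv_cache

assert post_process_sketch(False, "y", kv_cache="cache") == ("y", "cache")
assert post_process_sketch(True, "y", kv_cache="cache") == ("y", None)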

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 9 additions & 5 deletions
@@ -58,6 +58,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     x = self.with_logical_constraint(inputs)
     x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")
@@ -74,7 +76,7 @@ def __call__(

     x += self.mlp(self.post_attention_norm(x), deterministic)
     x = self.dropout(x, deterministic)
-    return self.post_process(x)
+    return self.post_process(x, kv_cache)

   def setup(self):
     self.pre_attention_norm_op = self.rms_norm_layer("pre_attention_layer_norm")
@@ -177,7 +179,7 @@ def attention(
            previous_chunk=previous_chunk,
            page_state=page_state,
            slot=slot,
-        )
+        )[0]
     )

   def mlp_layer(self):
@@ -194,7 +196,7 @@ def dropout(self, x, deterministic):
        self.dropout_op(x, deterministic=deterministic)
     )

-  def post_process(self, x):
+  def post_process(self, x, kv_cache=None):
     """Collect statistics about the output of the layer."""
     if self.config.record_internal_nn_metrics:
       self.sow("intermediates", "activation_mean", jnp.mean(x))
@@ -208,7 +210,7 @@ def post_process(self, x):
     if self.config.scan_layers:
       return x, None
     else:
-      return x
+      return x, kv_cache


 class DeepSeekDenseLayer(DeepSeekGenericLayer):
@@ -245,6 +247,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
       split_factor: int = 2,
   ):
     x = self.with_logical_constraint(inputs)
@@ -289,7 +293,7 @@ def _moe(x):
     x = _merge(x)

     x = self.dropout(x, deterministic)
-    return self.post_process(x)
+    return self.post_process(x, kv_cache)

   def init(self, *args, **kwargs):
     # Calls the parent init method for testing parity.

src/MaxText/layers/gemma.py

Lines changed: 6 additions & 2 deletions
@@ -129,6 +129,8 @@ def __call__(
       page_manager=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -137,13 +139,15 @@ def __call__(

     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -177,7 +181,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 GemmaDecoderLayerToLinen = nnx_wrappers.to_linen_class(
