8 changes: 6 additions & 2 deletions src/MaxText/layers/attention_mla.py
@@ -672,7 +672,9 @@ def __call__(
page_state: Optional[page_manager.PageState] = None,
bidirectional_mask: Optional[Any] = None,
rope_kwargs: dict | None = None,
) -> Array:
kv_cache: Optional[Array] = None,
attention_metadata: Optional[dict[str, Any]] = None,
) -> tuple[Array, Optional[Array]]:
"""Forward pass for MLA, reusing `AttentionOp` for the actual attention.

Args:
@@ -686,6 +688,8 @@ def __call__(
slot: The batch slot index for paged attention.
page_state: The current state of the paged attention manager.
bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
kv_cache: Optional key-value cache used when serving models with vLLM.
attention_metadata: Optional attention-related metadata used when serving models with vLLM.

Returns:
A tensor of shape [batch, length, embed_dim] containing the
@@ -726,4 +730,4 @@ def __call__(

out = self.out_projection(out)
out = checkpoint_name(out, "out_proj")
return out
return out, kv_cache
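Note: with this change the MLA layer returns an (output, kv_cache) tuple instead of a bare array. A minimal sketch of the updated call contract, using a hypothetical caller (mla_layer, hidden, positions, and segment_ids are assumed names, not part of this PR); kv_cache and attention_metadata stay None outside vLLM serving:

    def call_mla(mla_layer, hidden, positions, segment_ids, model_mode,
                 kv_cache=None, attention_metadata=None):
      # The layer now always returns a tuple; non-vLLM callers can ignore the cache.
      out, kv_cache = mla_layer(
          hidden,
          hidden,
          positions,
          decoder_segment_ids=segment_ids,
          model_mode=model_mode,
          kv_cache=kv_cache,                      # None outside vLLM serving
          attention_metadata=attention_metadata,  # None outside vLLM serving
      )
      return out, kv_cache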
3 changes: 3 additions & 0 deletions src/MaxText/layers/attention_op.py
@@ -847,11 +847,14 @@ def apply_attention(
raise NotImplementedError(target_hardware)
return impl(query, key, value, lengths, self.ragged_block_size)

# 'vllm_rpa' uses the same dot-attention wrapper but routes to the vLLM
# ragged paged attention kernel in `Attention.__call__`.
elif (
self.attention_kernel == "dot_product"
or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
or (self.attention_kernel == "autoselected" and length < 128)
or (self.attention_kernel == "paged")
or (self.attention_kernel == "vllm_rpa")
):
return self.apply_attention_dot(
query,
60 changes: 59 additions & 1 deletion src/MaxText/layers/attentions.py
@@ -889,6 +889,51 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
)
return [prefill_kv_cache, ar_kv_cache]

def forward_serve_vllm(
self,
query: Array,
key: Array,
value: Array,
rpa_kv_cache: list[Array] | None = None,
rpa_metadata: dict[str, Any] | None = None,
) -> tuple[list[Array], Array]:
"""Forward function for vLLM serving with RPA attention."""
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
except ImportError as e:
raise ImportError(
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
) from e

if self.config.attention_sink:
raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")

if rpa_kv_cache is None or rpa_metadata is None:
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")

query = query.reshape(-1, query.shape[2], query.shape[3])
key = key.reshape(-1, key.shape[2], key.shape[3])
value = value.reshape(-1, value.shape[2], value.shape[3])

attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
q_scale, k_scale, v_scale = None, None, None

md = rpa_metadata

output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
query,
key,
value,
rpa_kv_cache,
md.seq_lens,
md.block_tables,
md.query_start_loc,
md.request_distribution,
)
return kv_cache, output

def __call__(
self,
inputs_q: Array,
@@ -904,6 +949,8 @@ def __call__(
page_state: Optional[page_manager.PageState] = None,
bidirectional_mask: Any = None,
rope_kwargs: dict | None = None,
kv_cache: Optional[Array] = None,
attention_metadata: Optional[dict[str, Any]] = None,
):
"""Applies Attention on the input data.

@@ -931,6 +978,8 @@ def __call__(
slot: The batch slot index for paged attention.
page_state: The current state of the paged attention manager.
bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
kv_cache: Optional KV cache input, used when invoking from vLLM.
attention_metadata: Optional mapping to store attention metadata, used when invoking from vLLM.

Returns:
output of shape `[batch, length, q_features]`.
@@ -1026,6 +1075,15 @@ def __call__(
query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
)
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out

elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
batch, seq_len, num_heads, head_dim = query.shape
updated_kv, attn_out = self.forward_serve_vllm(
query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
)
out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
kv_cache = updated_kv

else:
cached_values = [None, None]
if model_mode != MODEL_MODE_TRAIN:
@@ -1054,4 +1112,4 @@ def __call__(
out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
out = self.out_projection(out, out_sharding=out_sharding)
out = checkpoint_name(out, "out_proj")
return out
return out, kv_cache
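The forward_serve_vllm helper above reads four fields off the metadata object (seq_lens, block_tables, query_start_loc, request_distribution). A minimal, runnable sketch of a stand-in container for local testing; the field names come from the diff, but the shapes, dtypes, and example values are assumptions, not the authoritative tpu_inference definition:

    from dataclasses import dataclass

    import jax.numpy as jnp


    @dataclass
    class FakeRpaMetadata:
      seq_lens: jnp.ndarray              # [num_requests] tokens held per request
      block_tables: jnp.ndarray          # [num_requests, max_blocks] KV-cache page indices
      query_start_loc: jnp.ndarray       # [num_requests + 1] cumulative query-token offsets
      request_distribution: jnp.ndarray  # scheduling info consumed by the ragged kernel


    # Example: two requests with 3 and 5 cached tokens, one new query token each.
    metadata = FakeRpaMetadata(
        seq_lens=jnp.array([3, 5], dtype=jnp.int32),
        block_tables=jnp.array([[0, 1], [2, 3]], dtype=jnp.int32),
        query_start_loc=jnp.array([0, 1, 2], dtype=jnp.int32),
        request_distribution=jnp.array([0, 0, 2], dtype=jnp.int32),
    )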
32 changes: 25 additions & 7 deletions src/MaxText/layers/decoders.py
@@ -87,6 +87,8 @@ def __call__(
previous_chunk=None,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
kv_cache: jax.Array | None = None,
attention_metadata: dict[str, Any] | None = None,
):
cfg = self.config
mesh = self.mesh
@@ -149,13 +151,15 @@ def __call__(
model_mode=model_mode,
)

attention_lnx = attention_layer(
attention_lnx, kv_cache = attention_layer(
lnx,
lnx,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)

if model_mode == MODEL_MODE_PREFILL:
@@ -209,7 +213,10 @@ def __call__(
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)

return layer_output, None if cfg.scan_layers else layer_output
if cfg.scan_layers:
return layer_output, None
else:
return layer_output, kv_cache


class SequentialBlockDecoderLayers(nn.Module):
@@ -691,6 +698,8 @@ def __call__(
bidirectional_mask: None | Any = None,
image_embeddings: None | jnp.ndarray = None,
image_masks: None | jnp.ndarray = None,
kv_caches: list[jax.Array] | None = None,
attention_metadata=None,
):
cfg = self.config
mesh = self.mesh
@@ -844,7 +853,8 @@ def __call__(
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
for index in range(num_layers):
y = layer(
kv_cache = kv_caches[index] if kv_caches is not None else None
y, kv_cache = layer(
config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
)(
y,
@@ -855,7 +865,11 @@ def __call__(
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
if kv_caches is not None and kv_cache is not None:
kv_caches[index] = kv_cache
else:
for lyr in range(cfg.num_decoder_layers):
RemattedBlockLayer = RemattedBlockLayers[0]
@@ -877,7 +891,8 @@ def __call__(
layer = RemattedBlockLayer(
config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
)
y = layer(
kv_cache = kv_caches[lyr] if kv_caches is not None else None
y, kv_cache = layer(
y,
decoder_segment_ids,
decoder_positions,
@@ -886,8 +901,12 @@ def __call__(
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
**layer_call_kwargs,
)
if kv_caches is not None and kv_cache is not None:
kv_caches[lyr] = kv_cache

assert isinstance(y, jax.Array)

@@ -904,7 +923,7 @@ def __call__(

# The API of the Decoder is now a tuple, providing both the main output
# and the raw hidden state needed for auxiliary tasks.
return logits, hidden_state
return logits, hidden_state, kv_caches

def _apply_gemma3_scanned_blocks(
self,
@@ -957,10 +976,9 @@ def _apply_gemma3_scanned_blocks(
if num_remaining_layers > 0:
# We name the remainder block with a 'remainder' suffix to avoid parameter name collisions
rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
# pytype: disable=wrong-keyword-args
layer = RemattedGemma3Block(
config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
)
) # pytype: disable=wrong-keyword-args
y, _ = layer(
y,
decoder_segment_ids,
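Both decoder loops above follow the same cache-threading pattern; a condensed sketch of it, assuming a hypothetical list of layer callables (the real code also passes positions, segment ids, and the other keyword arguments shown in the diff):

    def run_layers(layers, y, kv_caches=None, attention_metadata=None):
      for index, layer in enumerate(layers):
        kv_cache = kv_caches[index] if kv_caches is not None else None
        y, kv_cache = layer(y, kv_cache=kv_cache, attention_metadata=attention_metadata)
        if kv_caches is not None and kv_cache is not None:
          kv_caches[index] = kv_cache  # write the updated page cache back for vLLM
      return y, kv_caches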
10 changes: 7 additions & 3 deletions src/MaxText/layers/deepseek.py
@@ -99,7 +99,7 @@ def self_attention_with_norm(
model_mode=model_mode,
)

attention_lnx = attention_layer(
attention_lnx, _ = attention_layer(
lnx,
lnx,
decoder_positions,
@@ -127,7 +127,7 @@ def self_attention_with_norm(
return hidden_states, intermediate_inputs


def post_process(cfg, layer_output, sow):
def post_process(cfg, layer_output, sow, kv_cache=None):
"""postprocessing."""
if cfg.record_internal_nn_metrics:
sow("intermediates", "activation_mean", jnp.mean(layer_output))
@@ -141,7 +141,7 @@ def post_process(cfg, layer_output, sow):
if cfg.scan_layers:
return layer_output, None
else:
return layer_output
return layer_output, kv_cache


class DeepSeekDenseLayer(nn.Module):
@@ -163,6 +163,8 @@ def __call__(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache=None,
attention_metadata=None,
):
cfg = self.config
if model_mode == MODEL_MODE_PREFILL:
@@ -230,6 +232,8 @@ def __call__(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache=None,
attention_metadata=None,
):
cfg = self.config
if model_mode == MODEL_MODE_PREFILL:
14 changes: 9 additions & 5 deletions src/MaxText/layers/deepseek_batchsplit.py
@@ -58,6 +58,8 @@ def __call__(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache=None,
attention_metadata=None,
):
x = self.with_logical_constraint(inputs)
x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")
@@ -74,7 +76,7 @@ def __call__(

x += self.mlp(self.post_attention_norm(x), deterministic)
x = self.dropout(x, deterministic)
return self.post_process(x)
return self.post_process(x, kv_cache)

def setup(self):
self.pre_attention_norm_op = self.rms_norm_layer("pre_attention_layer_norm")
@@ -177,7 +179,7 @@ def attention(
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
)
)[0]
)

def mlp_layer(self):
@@ -194,7 +196,7 @@ def dropout(self, x, deterministic):
self.dropout_op(x, deterministic=deterministic)
)

def post_process(self, x):
def post_process(self, x, kv_cache=None):
"""Collect statistics about the output of the layer."""
if self.config.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(x))
@@ -208,7 +210,7 @@ def post_process(self, x):
if self.config.scan_layers:
return x, None
else:
return x
return x, kv_cache


class DeepSeekDenseLayer(DeepSeekGenericLayer):
@@ -245,6 +247,8 @@ def __call__(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache=None,
attention_metadata=None,
split_factor: int = 2,
):
x = self.with_logical_constraint(inputs)
@@ -289,7 +293,7 @@ def _moe(x):
x = _merge(x)

x = self.dropout(x, deterministic)
return self.post_process(x)
return self.post_process(x, kv_cache)

def init(self, *args, **kwargs):
# Calls the parent init method for testing parity.
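The batch-split attention helper keeps returning a single tensor by indexing the new tuple; a minimal sketch of the two equivalent styles (attention_op here is a hypothetical stand-in for the wrapped attention layer):

    attn_out = attention_op(x, x, positions)[0]  # keep the output, drop the returned cache
    attn_out, _ = attention_op(x, x, positions)  # equivalent tuple unpacking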
8 changes: 6 additions & 2 deletions src/MaxText/layers/gemma.py
@@ -129,6 +129,8 @@ def __call__(
page_manager=None,
page_state=None,
slot=None,
kv_cache=None,
attention_metadata=None,
):
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -137,13 +139,15 @@ def __call__(

lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

attention_lnx = self.self_attention(
attention_lnx, kv_cache = self.self_attention(
lnx,
lnx,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)

attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -177,7 +181,7 @@ def __call__(
if self.config.scan_layers:
return layer_output, None
else:
return layer_output
return layer_output, kv_cache


GemmaDecoderLayerToLinen = nnx_wrappers.to_linen_class(