@@ -864,6 +864,45 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
       )
     return [prefill_kv_cache, ar_kv_cache]

+  def forward_serve_vllm(
+      self, query: Array, key: Array, value: Array, rpa_kv_cache: list[Array], rpa_metadata: dict[str, Any]
+  ) -> tuple[list[Array], Array]:
+    """Forward function for vLLM serving with RPA attention."""
+    try:
+      # pylint: disable=import-outside-toplevel
+      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+    except ImportError as e:
+      raise ImportError(
+          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+      ) from e
+
+    if self.config.attention_sink:
+      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
+
+    if rpa_kv_cache is None or rpa_metadata is None:
+      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
+
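+    # Flatten [batch, seq, heads, head_dim] -> [num_tokens, heads, head_dim]; the RPA kernel
+    # works on a flat token dimension indexed by the paged-attention metadata.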
+    query = query.reshape(-1, query.shape[2], query.shape[3])
+    key = key.reshape(-1, key.shape[2], key.shape[3])
+    value = value.reshape(-1, value.shape[2], value.shape[3])
+
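+    # A positive chunk_attn_window_size enables chunked (local-window) attention in the kernel;
+    # otherwise chunking is disabled. The q/k/v quantization scales are left unset here.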
+    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    q_scale, k_scale, v_scale = None, None, None
+
+    md = rpa_metadata
+
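+    # sharded_ragged_paged_attention returns a callable; the metadata supplies per-request
+    # sequence lengths, block tables, query start offsets, and the request distribution.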
+    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+        query,
+        key,
+        value,
+        rpa_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    return kv_cache, output
+
   def __call__(
       self,
       inputs_q: Array,
@@ -878,6 +917,8 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
   ):
     """Applies Attention on the input data.

@@ -905,6 +946,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional KV cache input, used when invoking from vLLM.
+      attention_metadata: Optional attention metadata mapping, used when invoking from vLLM.

     Returns:
       output of shape `[batch, length, q_features]`.
@@ -1000,6 +1043,15 @@ def __call__(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
       )
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+
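+    # vLLM serving path: delegate to the ragged-paged-attention kernel and thread the
+    # updated KV cache back to the caller alongside the attention output.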
+    elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
+      batch, seq_len, num_heads, head_dim = query.shape
+      updated_kv, attn_out = self.forward_serve_vllm(
+          query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
+      kv_cache = updated_kv
+
     else:
       cached_values = [None, None]
       if model_mode != MODEL_MODE_TRAIN:
@@ -1028,4 +1080,4 @@ def __call__(
     out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
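+    # The layer now also returns the (possibly updated) KV cache so the caller
+    # (e.g. the vLLM serving path) can carry it across decode steps.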
-    return out
+    return out, kv_cache