Conversation

@NicoGrande NicoGrande commented Nov 6, 2025

Description

This PR introduces new optional arguments, kv_cache and attention_metadata, for model decoder blocks and for the attentions.py module. vLLM provides these arguments when executing a MaxText model through the vLLM engine, and they are used when calling the ragged paged attention kernel in tpu-inference.

This PR builds on the work started in #2612.

Note: This PR changes the expected method signature of Attention layers and decoder layers.
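
To illustrate the shape of the change, here is a minimal, self-contained sketch. These are not MaxText's actual classes; ToyAttention, ToyDecoderLayer, and the AttentionMetadata fields are hypothetical stand-ins. The sketch only shows the pattern: the new arguments are optional, default to None, and are threaded from the decoder layer down to the attention call, with a fallback to the standard path when they are absent.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class AttentionMetadata:
  # Hypothetical container; in the real flow vLLM supplies this object.
  seq_lens: Any
  block_tables: Any


class ToyAttention:
  """Illustrative stand-in for an attention layer, not MaxText's implementation."""

  def __call__(self, x, *, kv_cache: Optional[Any] = None,
               attention_metadata: Optional[AttentionMetadata] = None):
    if kv_cache is not None and attention_metadata is not None:
      # vLLM serving path: use the externally managed cache and metadata.
      # This is where the ragged paged attention kernel would be invoked.
      return self._ragged_paged_attention(x, kv_cache, attention_metadata)
    # Default path: attention with internally managed state.
    return self._standard_attention(x)

  def _ragged_paged_attention(self, x, kv_cache, metadata):
    return x  # placeholder for the tpu-inference kernel call

  def _standard_attention(self, x):
    return x  # placeholder for the existing attention computation


class ToyDecoderLayer:
  """Decoder layers simply forward the new optional arguments to attention."""

  def __init__(self):
    self.attention = ToyAttention()

  def __call__(self, x, *, kv_cache=None, attention_metadata=None):
    return self.attention(x, kv_cache=kv_cache,
                          attention_metadata=attention_metadata)

Because both arguments default to None, existing call sites that do not pass them continue to work unchanged.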

Tests

This PR includes a new unit test in attention_test.py.

Additionally, end-to-end tests were run locally on a v6e VM using the following steps:

In one process, start the vLLM server. This requires a local config.json file for the HuggingFace model you want to test. Modify this file so that its "architectures" field is set to ["MaxTextForCausalLM"].
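
A small helper like the following shows one way to patch the local config file; the path is a hypothetical example, and the key fact is that "architectures" in a HuggingFace config.json is a list of class names.

import json

# Hypothetical local path; point this at the config.json downloaded from HuggingFace.
config_path = "./config.json"

with open(config_path) as f:
  config = json.load(f)

# Replace the original architecture entry with the MaxText one.
config["architectures"] = ["MaxTextForCausalLM"]

with open(config_path, "w") as f:
  json.dump(config, f, indent=2)

With config.json updated, start the vLLM server: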

HF_TOKEN=<YOUR_HF_TOKEN> TPU_BACKEND_TYPE=jax \
  python -m vllm.entrypoints.cli.main serve \
  <HF_MODEL> \
  --max-num-batched-tokens=32 \
  --max-model-len=32 \
  --max-num-seqs=1 \
  --tensor-parallel-size=4 \
  --hf_config_path=<PATH_TO_LOCAL_HF_CONFIG_JSON> \
  --additional-config='{"maxtext_config": {"model_name": "<MAXTEXT_MODEL_NAME>", "max_prefill_predict_length": 28, "max_target_length": 32, "ici_tensor_parallelism": 4, "load_parameters_path": "<MAXTEXT_CHECKPOINT_PATH>"}}'

In a second process, issue a query to the model:

curl http://localhost:8000/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
       "model": "<HF_MODEL>",
       "prompt": ["Seattle is a"],
       "max_tokens": 16,
       "temperature": 0
   }'

Results for the tested models are shown below:

llama3.1-8b

output: " city that is known for its coffee culture, and it's not hard to see"

gemma3-4b

output: " vibrant city with a lot to offer, and it's a great place to"

qwen3-8b

output: " city in the state of Washington, in the Pacific Northwest region of the United States"

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NicoGrande NicoGrande force-pushed the nicogrande/update-decoders-attention-vllm branch 9 times, most recently from 655760f to d0c2503, on November 7, 2025 19:53
@NicoGrande NicoGrande marked this pull request as ready for review on November 10, 2025 18:13
@NicoGrande NicoGrande force-pushed the nicogrande/update-decoders-attention-vllm branch from d0c2503 to 91c199b on November 10, 2025 23:08

@gagika gagika left a comment

Thanks Nico,
Feel free to address the comments in a follow-up PR.

@NicoGrande NicoGrande force-pushed the nicogrande/update-decoders-attention-vllm branch from 91c199b to 8f3305d on November 13, 2025 00:50
@NicoGrande NicoGrande force-pushed the nicogrande/update-decoders-attention-vllm branch from 8f3305d to fc3db23 on November 13, 2025 13:10
@NicoGrande NicoGrande force-pushed the nicogrande/update-decoders-attention-vllm branch 3 times, most recently from dbd531b to a9a75c5, on November 15, 2025 19:34
@NicoGrande NicoGrande force-pushed the nicogrande/update-decoders-attention-vllm branch 8 times, most recently from 2adf1bb to 575ca28, on November 18, 2025 01:47
removing calls into specialized attention modules.

adding vllm_rpa unit test.

fixing additional unit tests.

adding validation support for vllm_rpa.

rebasing deepseek and gpt-oss.

adding skip for vllm-tpu test.

addressing comments on lazy init.

adding check for kv_cache and attention_metadata.

adding comment on vllm_rpa.

adding pyconfig deprecated validation.

fixing pytype errors.

adding new output type to Qwen3-Omni vision encoder.

fixing deepseek batchsplit.
@NicoGrande NicoGrande force-pushed the nicogrande/update-decoders-attention-vllm branch from 575ca28 to f6ead2e on November 18, 2025 02:09
copybara-service bot pushed a commit that referenced this pull request Nov 18, 2025
FUTURE_COPYBARA_INTEGRATE_REVIEW=#2616 from AI-Hypercomputer:nicogrande/update-decoders-attention-vllm f6ead2e
PiperOrigin-RevId: 829115621
copybara-service bot pushed a commit that referenced this pull request Nov 18, 2025
…te-decoders-attention-vllm f6ead2e

PiperOrigin-RevId: 833932020