diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3152cd6488f3..7f841fbb7ce4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -450,6 +450,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_compile_ranges.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 @@ -471,8 +472,8 @@ steps: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph.py - # Limit to no custom ops to reduce running time + - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile' + # Limit to no custom ops to reduce running time # Wrap with quotes to escape yaml and avoid starting -k string with a - - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'" @@ -951,10 +952,13 @@ steps: - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py - tests/compile/test_fusions_e2e.py + - tests/compile/test_full_graph.py commands: - nvidia-smi # Run all e2e fusion tests - pytest -v -s tests/compile/test_fusions_e2e.py + # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) + - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 29ef6409bb16..074b7a440b61 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,97 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate +import itertools -import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope -from vllm.platforms import current_platform +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.triton_utils import triton from vllm.utils.argparse_utils import FlexibleArgumentParser +batch_size_range = [2**i for i in range(0, 8, 2)] +seq_len_range = [2**i for i in range(6, 10, 1)] +num_heads_range = [32, 48] +configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range)) -def benchmark_rope_kernels_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: int | None, - dtype: torch.dtype, - seed: int, - device: str, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - # silulating serving 4 LoRAs - scaling_factors = [1, 2, 4, 8] - # batched RoPE can take multiple scaling factors - batched_rope = get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": tuple(scaling_factors)}, + +def get_benchmark(head_size, rotary_dim, is_neox_style, device): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "num_heads"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["torch", "flashinfer", "vllm"], + line_names=["PyTorch", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}", + args={}, + ) ) - # non-batched RoPE takes only one scaling factor, we 
create multiple - # instances to simulate the same behavior - non_batched_ropes: list[RotaryEmbedding] = [] - for scaling_factor in scaling_factors: - non_batched_ropes.append( - get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": (scaling_factor,)}, - ) + def benchmark(batch_size, seq_len, num_heads, provider): + dtype = torch.bfloat16 + max_position = 8192 + base = 10000 + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) + rope = rope.to(dtype=dtype, device=device) + cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) + + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + query = torch.randn( + (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device ) + key = torch.randn_like(query) - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) - key = torch.randn_like(query) + quantiles = [0.5, 0.2, 0.8] - # create query offsets for batched RoPE, we concat multiple kv cache - # together and each query needs to find the right kv cache of its type - offset_map = torch.tensor( - list( - accumulate( - [0] - + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ] + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_native(positions, query.clone(), key.clone()), + quantiles=quantiles, ) - ) - ) - query_types = torch.randint( - 0, len(scaling_factors), (batch_size, seq_len), device=device - ) - # map query types to offsets - query_offsets = offset_map[query_types] - # the kernel takes flattened offsets - flatten_offsets = query_offsets.flatten() + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.ops.vllm.flashinfer_rotary_embedding( + positions, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + is_neox_style, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_cuda(positions, query.clone(), key.clone()), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms - # batched queries of the same type together for non-batched RoPE - queries = [query[query_types == i] for i in range(len(scaling_factors))] - keys = [key[query_types == i] for i in range(len(scaling_factors))] - packed_qkr = zip(queries, keys, non_batched_ropes) - # synchronize before start timing - torch.cuda.synchronize() - with nvtx.annotate("non-batched", color="yellow"): - for q, k, r in packed_qkr: - r.forward(positions, q, k) - torch.cuda.synchronize() - with nvtx.annotate("batched", color="green"): - batched_rope.forward(positions, query, key, flatten_offsets) - torch.cuda.synchronize() + return benchmark if __name__ == "__main__": @@ -116,17 +95,12 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument( "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" ) + parser.add_argument("--save-path", type=str, default="./configs/rope/") args = parser.parse_args() - print(args) - benchmark_rope_kernels_multi_lora( - is_neox_style=args.is_neox_style, - batch_size=args.batch_size, - seq_len=args.seq_len, - num_heads=args.num_heads, - head_size=args.head_size, - rotary_dim=args.rotary_dim, - dtype=getattr(torch, args.dtype), - seed=args.seed, - device=args.device, + # Get the benchmark function + benchmark = get_benchmark( + args.head_size, args.rotary_dim, 
args.is_neox_style, args.device ) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 931090db50e9..29db9fa273a4 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG a893712401d70362fbb299cd9c4b3476e8e9ed54 + GIT_TAG 8e1b01d56210dc72030a2d0d41c2d8d266ba6309 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 6fcd246f63c5..2521b2797e2c 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant( // This kernel currently only supports H % 128 == 0 and assumes a // fixed GROUP_SIZE of 128. + static constexpr int GROUP_SIZE = 128; + TORCH_CHECK(input.dtype() == torch::kBFloat16); TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || y_q.dtype() == torch::kFloat8_e4m3fnuz); TORCH_CHECK(y_s.dtype() == torch::kFloat32); - TORCH_CHECK(input.size(-1) % 256 == 0); + TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0); using Idx_t = int64_t; @@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant( Idx_t stride_counts_e = tokens_per_expert.stride(0); - static constexpr int GROUP_SIZE = 128; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ @@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant( static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + int const NUM_GROUPS = H / GROUP_SIZE; if (!use_ue8m0) { - if (H >= 4096) { + if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { + /* 8 warps config */ static constexpr int NUM_STAGES = 4; static constexpr int THREAD_COUNT = 256; KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); } else { + /* 1 warp config */ static constexpr int THREAD_COUNT = 32; KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2); } } else { - if (H >= 4096) { + if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { + /* 8 warps config */ static constexpr int NUM_STAGES = 4; static constexpr int THREAD_COUNT = 256; KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES); } else { + /* 1 warp config */ static constexpr int THREAD_COUNT = 32; KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2); } diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 633e23eea33e..ee224e6922fb 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | -| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] | +| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | | naive batched4 | batched | int8,
fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md index e7dd9fee12d3..edcbaa716447 100644 --- a/docs/features/sleep_mode.md +++ b/docs/features/sleep_mode.md @@ -13,6 +13,9 @@ Key benefits: !!! note This feature is only supported on CUDA platform. +!!! note + For more information, see this [Blog Post](https://blog.vllm.ai/2025/10/26/sleep-mode.html). + ## Sleep levels Level 1 sleep will offload the model weights and discard the KV cache. The content of KV cache is forgotten. Level 1 sleep is good for sleeping and waking up the engine to run the same model again. The model weights are backed up in CPU memory. Please make sure there's enough CPU memory to store the model weights. Level 2 sleep will discard both the model weights and the KV cache (while the model's buffers are kept in CPU, like rope scaling tensors). The content of both the model weights and KV cache is forgotten. Level 2 sleep is good for sleeping and waking up the engine to run a different model or update the model, where previous model weights are not needed, e.g. RLHF weight update. @@ -31,6 +34,7 @@ llm = LLM("Qwen/Qwen3-0.6B", enable_sleep_mode=True) #### Python API ```python +# Sleep level 1 # Put the engine to sleep (level=1: offload weights to CPU RAM, discard KV cache) llm.sleep(level=1) @@ -38,6 +42,21 @@ llm.sleep(level=1) llm.wake_up() ``` +```python +# Sleep level 2 +# Put the engine to sleep (level=2: discard both weights and KV cache) +llm.sleep(level=2) + +# Reallocate weights memory only +llm.wake_up(tags=["weights"]) + +# Load weights in-place +llm.collective_rpc("reload_weights") + +# Reallocate KV cache +llm.wake_up(tags=["kv_cache"]) +``` + #### RLHF weight updates During RLHF training, vLLM allows you to selectively wake up only the model weights or the KV cache using the tags argument in wake_up(). This fine-grained control is especially useful when updating model weights: by waking up just the weights (e.g., llm.wake_up(tags=["weights"])), you avoid allocating memory for the KV cache until after the weight update is complete. This approach helps prevent GPU out-of-memory (OOM) errors, particularly with large models, by minimizing peak memory usage during weight synchronization and update operations. @@ -69,10 +88,30 @@ VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \ --port 8000 ``` +Below is an example of how to sleep and wake up a model in level 1. + +```bash +curl -X POST 'http://localhost:8000/sleep?level=1' +curl -X POST 'http://localhost:8000/wake_up' +``` + +And this is an example of how to sleep and wake up a model in level 2. + +```bash +curl -X POST 'http://localhost:8000/sleep?level=2' +# Reallocate weights memory only +curl -X POST 'http://localhost:8000/wake_up?tags=weights' +# Load weights in-place +curl -X POST 'http://localhost:8000/collective_rpc' -H 'Content-Type: application/json' -d '{"method":"reload_weights"}' +# Reallocate KV cache +curl -X POST 'http://localhost:8000/wake_up?tags=kv_cache' +``` + #### HTTP endpoints - `POST /sleep?level=1` — Put the model to sleep (`level=1`). - `POST /wake_up` — Wake up the model. Supports optional `tags` query parameters for partial wake-up (e.g., `?tags=weights`). +- `POST /collective_rpc` — Perform a collective remote procedure call (RPC). - `GET /is_sleeping` — Check if the model is sleeping. !!! 
note diff --git a/requirements/common.txt b/requirements/common.txt index 8009581f62a4..90efb79a845d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -49,3 +49,4 @@ cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss anthropic == 0.71.0 +model-hosting-container-standards < 1.0.0 \ No newline at end of file diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py new file mode 100644 index 000000000000..bacaa48ae477 --- /dev/null +++ b/tests/compile/test_compile_ranges.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import fx as fx +from torch import nn + +# This import automatically registers `torch.ops.silly.attention` +import tests.compile.silly_attention # noqa +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.inductor_pass import ( + InductorPass, + get_pass_context, +) +from vllm.config import ( + VllmConfig, + set_current_vllm_config, +) +from vllm.config.compilation import CompilationConfig, CompilationMode +from vllm.config.scheduler import SchedulerConfig +from vllm.config.utils import Range +from vllm.forward_context import set_forward_context + +BATCH_SIZE = 64 +MLP_SIZE = 128 + + +@support_torch_compile +class TestModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): + with set_forward_context({}, vllm_config=vllm_config): + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + for batch_size in batch_sizes: + model(torch.randn(batch_size, MLP_SIZE).cuda()) + + +class PostGradPassManagerCheckRanges(InductorPass): + def __init__(self, ranges: list[Range]): + self.ranges = ranges + self.num_calls = 0 + + def __call__(self, graph: fx.Graph): + compile_range = get_pass_context().compile_range + assert compile_range in self.ranges, ( + f"Compile range {compile_range} not in {self.ranges}" + ) + self.num_calls += 1 + + def uuid(self) -> str: + state = { + "ranges": [str(range) for range in self.ranges], + "current_compile_range": str(get_pass_context().compile_range), + } + return InductorPass.hash_dict(state) + + +def test_compile_ranges(): + post_grad_pass_manager = PostGradPassManagerCheckRanges( + [ + Range(start=1, end=8), + Range(start=8, end=32), + Range(start=32, end=8193), + ] + ) + vllm_config = VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_ranges_split_points=[8, 32], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_pass_manager, + # Disable inductor cache to get the number of passes correctly + "force_disable_caches": True, + }, + ), + ) + + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() + batch_sizes = [1, 4, 16, 24, 48, 64] + # A has support_torch_compile + with compilation_counter.expect( + 
num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=3, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) + assert post_grad_pass_manager.num_calls == 3 diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 0ad8c17d8668..71f90f6d8d3e 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -183,8 +183,14 @@ def test_custom_compile_config( "compilation_mode", [CompilationMode.NONE, CompilationMode.VLLM_COMPILE], ) -def test_fp8_kv_scale_compile(compilation_mode: int): - model = "Qwen/Qwen2-0.5B" +@pytest.mark.parametrize( + "model", + [ + "Qwen/Qwen2-0.5B", # Standard attention model + "deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model + ], +) +def test_fp8_kv_scale_compile(compilation_mode: int, model: str): model_kwargs = { "quantization": "fp8", "kv_cache_dtype": "fp8_e4m3", diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 7f8e77a75621..3576efca591c 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -14,6 +14,7 @@ from typing import Literal, NamedTuple import pytest +import torch from vllm.config.model import RunnerOption from vllm.logger import init_logger @@ -254,6 +255,17 @@ def test_cp_generation( test_options: CPTestOptions, num_gpus_available, ): + if ( + model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat" + and torch.cuda.get_device_capability() < (9, 0) + ): + pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher") + if ( + model_id == "bigcode/gpt_bigcode-santacoder" + and torch.cuda.get_device_capability() != (9, 0) + ): + pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0") + _compare_cp_with_tp( model_id, parallel_setup, diff --git a/tests/entrypoints/sagemaker/__init__.py b/tests/entrypoints/sagemaker/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/entrypoints/sagemaker/conftest.py b/tests/entrypoints/sagemaker/conftest.py new file mode 100644 index 000000000000..4c859c2527d2 --- /dev/null +++ b/tests/entrypoints/sagemaker/conftest.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared fixtures and utilities for SageMaker tests.""" + +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# Model name constants used across tests +MODEL_NAME_ZEPHYR = "HuggingFaceH4/zephyr-7b-beta" +MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct" +LORA_ADAPTER_NAME_SMOLLM = "jekunz/smollm-135m-lora-fineweb-faroese" + +# SageMaker header constants +HEADER_SAGEMAKER_CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id" +HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id" +HEADER_SAGEMAKER_NEW_SESSION_ID = "X-Amzn-SageMaker-New-Session-Id" + + +@pytest.fixture(scope="session") +def smollm2_lora_files(): + """Download LoRA files once per test session.""" + from huggingface_hub import snapshot_download + + return snapshot_download(repo_id=LORA_ADAPTER_NAME_SMOLLM) + + +@pytest.fixture(scope="module") +def basic_server_with_lora(smollm2_lora_files): + """Basic server fixture with standard configuration.""" + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--max-lora-rank", + "256", + "--max-cpu-loras", + "2", + 
"--max-num-seqs", + "64", + ] + + envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"} + with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=envs) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def async_client(basic_server_with_lora: RemoteOpenAIServer): + """Async OpenAI client fixture for use with basic_server.""" + async with basic_server_with_lora.get_async_client() as async_client: + yield async_client diff --git a/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py b/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py new file mode 100644 index 000000000000..0d4f8e885824 --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py @@ -0,0 +1,734 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration tests for handler override functionality. + +Tests real customer usage scenarios: +- Using @custom_ping_handler and @custom_invocation_handler decorators + to override handlers +- Setting environment variables for handler specifications +- Writing customer scripts with custom_sagemaker_ping_handler() and + custom_sagemaker_invocation_handler() functions +- Priority: env vars > decorators > customer script files > framework + defaults + +Note: These tests focus on validating server responses rather than directly calling +get_ping_handler() and get_invoke_handler() to ensure full integration testing. +""" + +import os +import tempfile + +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + MODEL_NAME_SMOLLM, +) + + +class TestHandlerOverrideIntegration: + """Integration tests simulating real customer usage scenarios. + + Each test simulates a fresh server startup where customers: + - Use @custom_ping_handler and @custom_invocation_handler decorators + - Set environment variables (CUSTOM_FASTAPI_PING_HANDLER, etc.) 
+ - Write customer scripts with custom_sagemaker_ping_handler() and + custom_sagemaker_invocation_handler() functions + """ + + def setup_method(self): + """Setup for each test - simulate fresh server startup.""" + self._clear_caches() + self._clear_env_vars() + + def teardown_method(self): + """Cleanup after each test.""" + self._clear_env_vars() + + def _clear_caches(self): + """Clear handler registry and function loader cache.""" + try: + from model_hosting_container_standards.common.handler import ( + handler_registry, + ) + from model_hosting_container_standards.sagemaker.sagemaker_loader import ( + SageMakerFunctionLoader, + ) + + handler_registry.clear() + SageMakerFunctionLoader._default_function_loader = None + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + def _clear_env_vars(self): + """Clear SageMaker environment variables.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + + # Clear SageMaker env vars + for var in [ + SageMakerEnvVars.SAGEMAKER_MODEL_PATH, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME, + ]: + os.environ.pop(var, None) + + # Clear FastAPI env vars + for var in [ + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER, + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER, + ]: + os.environ.pop(var, None) + except ImportError: + pass + + @pytest.mark.asyncio + async def test_customer_script_functions_auto_loaded(self): + """Test customer scenario: script functions automatically override + framework defaults.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script file with ping() and invoke() functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "customer_override", + "message": "Custom ping from customer script" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Custom response from customer script"], + "source": "customer_override" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Customer sets SageMaker environment variables to point to their script + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Customer tests their server and sees their overrides work + # automatically + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Customer sees their functions are used + assert ping_data["source"] == "customer_override" + assert ping_data["message"] == "Custom ping from customer script" + 
assert invoke_data["source"] == "customer_override" + assert invoke_data["predictions"] == [ + "Custom response from customer script" + ] + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_customer_decorator_usage(self): + """Test customer scenario: using @custom_ping_handler and + @custom_invocation_handler decorators.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script file with decorators + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request + +@sagemaker_standards.custom_ping_handler +async def my_ping(): + return { + "type": "ping", + "source": "customer_decorator" + } + +@sagemaker_standards.custom_invocation_handler +async def my_invoke(request: Request): + return { + "type": "invoke", + "source": "customer_decorator" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Customer sees their handlers are used by the server + assert ping_data["source"] == "customer_decorator" + assert invoke_data["source"] == "customer_decorator" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_handler_priority_order(self): + """Test priority: @custom_ping_handler/@custom_invocation_handler + decorators vs script functions.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script with both decorator and regular functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request + +# Customer uses @custom_ping_handler decorator (higher priority than script functions) +@sagemaker_standards.custom_ping_handler +async def decorated_ping(): + return { + "status": "healthy", + "source": "ping_decorator_in_script", + "priority": "decorator" + } + +# Customer also has a regular function (lower priority than +# @custom_ping_handler decorator) +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script_function", + "priority": "function" + } + +# Customer has a regular invoke function +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script_invoke_function", + "priority": "function" + } +""" + ) + 
script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # @custom_ping_handler decorator has higher priority than + # script function + assert ping_data["source"] == "ping_decorator_in_script" + assert ping_data["priority"] == "decorator" + + # Script function is used for invoke + assert invoke_data["source"] == "script_invoke_function" + assert invoke_data["priority"] == "function" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_environment_variable_script_loading(self): + """Test that environment variables correctly specify script location + and loading.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script in a specific directory + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "env_loaded_script", + "method": "environment_variable_loading" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Loaded via environment variables"], + "source": "env_loaded_script", + "method": "environment_variable_loading" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Test environment variable script loading + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Verify that the script was loaded via environment variables + assert ping_data["source"] == "env_loaded_script" + assert ping_data["method"] == "environment_variable_loading" + assert invoke_data["source"] == "env_loaded_script" + assert invoke_data["method"] == "environment_variable_loading" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_framework_default_handlers(self): + """Test that framework default handlers work when no 
customer + overrides exist.""" + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + # Explicitly pass empty env_dict to ensure no SageMaker env vars are set + # This prevents pollution from previous tests + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + + env_dict = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: "", + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: "", + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: "", + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: "", + } + except ImportError: + env_dict = {} + + with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=env_dict) as server: + # Test that default ping works + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + + # Test that default invocations work + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + + @pytest.mark.asyncio + async def test_handler_env_var_override(self): + """Test CUSTOM_FASTAPI_PING_HANDLER and CUSTOM_FASTAPI_INVOCATION_HANDLER + environment variable overrides.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with both env var handlers and script functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request, Response +import json + +async def env_var_ping_handler(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "env_var_ping", + "method": "environment_variable" + }), + media_type="application/json" + ) + +async def env_var_invoke_handler(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Environment variable response"], + "source": "env_var_invoke", + "method": "environment_variable" + }), + media_type="application/json" + ) + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script_ping", + "method": "script_function" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script_invoke", + "method": "script_function" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to override both handlers + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: ( + f"{script_name}:env_var_ping_handler" + ), + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: ( + f"{script_name}:env_var_invoke_handler" + ), + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping handler override + ping_response = requests.get(server.url_for("ping")) + 
assert ping_response.status_code == 200 + ping_data = ping_response.json() + + # Environment variable should override script function + assert ping_data["method"] == "environment_variable" + assert ping_data["source"] == "env_var_ping" + + # Test invocation handler override + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Environment variable should override script function + assert invoke_data["method"] == "environment_variable" + assert invoke_data["source"] == "env_var_invoke" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_env_var_priority_over_decorator_and_script(self): + """Test that environment variables have highest priority over decorators + and script functions for both ping and invocation handlers.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with all three handler types for both ping and invocation + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request, Response +import json + +# Environment variable handlers (highest priority) +async def env_priority_ping(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "env_var", + "priority": "environment_variable" + }), + media_type="application/json" + ) + +async def env_priority_invoke(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Environment variable response"], + "source": "env_var", + "priority": "environment_variable" + }), + media_type="application/json" + ) + +# Decorator handlers (medium priority) +@sagemaker_standards.custom_ping_handler +async def decorator_ping(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "decorator", + "priority": "decorator" + }), + media_type="application/json" + ) + +@sagemaker_standards.custom_invocation_handler +async def decorator_invoke(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Decorator response"], + "source": "decorator", + "priority": "decorator" + }), + media_type="application/json" + ) + +# Script functions (lowest priority) +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script", + "priority": "script_function" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script", + "priority": "script_function" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to specify highest priority handlers + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: ( + f"{script_name}:env_priority_ping" + ), + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: ( + 
f"{script_name}:env_priority_invoke" + ), + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping handler priority + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + # Environment variable has highest priority and should be used + assert ping_data["priority"] == "environment_variable" + assert ping_data["source"] == "env_var" + + # Test invocation handler priority + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Environment variable has highest priority and should be used + assert invoke_data["priority"] == "environment_variable" + assert invoke_data["source"] == "env_var" + + finally: + os.unlink(script_path) diff --git a/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py b/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py new file mode 100644 index 000000000000..a2867efdc584 --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import openai # use the official async_client for correctness check +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import MODEL_NAME_SMOLLM + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_happy_path( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # The SageMaker standards library creates a POST /adapters endpoint + # that maps to the load_lora_adapter handler with request shape: + # {"lora_name": "body.name", "lora_path": "body.src"} + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "smollm2-lora-sagemaker", "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + models = await async_client.models.list() + models = models.data + dynamic_lora_model = models[-1] + assert dynamic_lora_model.root == smollm2_lora_files + assert dynamic_lora_model.parent == MODEL_NAME_SMOLLM + assert dynamic_lora_model.id == "smollm2-lora-sagemaker" + + +@pytest.mark.asyncio +async def test_sagemaker_unload_adapter_happy_path( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # First, load an adapter + adapter_name = "smollm2-lora-sagemaker-unload" + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Verify it's in the models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + assert adapter_name in adapter_ids + + # Now unload it using DELETE /adapters/{adapter_name} + # The SageMaker standards maps this to unload_lora_adapter with: + # {"lora_name": "path_params.adapter_name"} + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", adapter_name), + ) + unload_response.raise_for_status() + + # Verify it's no longer in the models list + models = await async_client.models.list() + adapter_ids = 
[model.id for model in models.data] + assert adapter_name not in adapter_ids + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_not_found( + basic_server_with_lora: RemoteOpenAIServer, +): + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "nonexistent-adapter", "src": "/path/does/not/exist"}, + ) + assert load_response.status_code == 404 + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_invalid_files( + basic_server_with_lora: RemoteOpenAIServer, + tmp_path, +): + invalid_files = tmp_path / "invalid_adapter" + invalid_files.mkdir() + (invalid_files / "adapter_config.json").write_text("not valid json") + + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "invalid-adapter", "src": str(invalid_files)}, + ) + assert load_response.status_code == 400 + + +@pytest.mark.asyncio +async def test_sagemaker_unload_nonexistent_adapter( + basic_server_with_lora: RemoteOpenAIServer, +): + # Attempt to unload an adapter that doesn't exist + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", "nonexistent-adapter-name"), + ) + assert unload_response.status_code in (400, 404) + + +@pytest.mark.asyncio +async def test_sagemaker_invocations_with_adapter( + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # First, load an adapter via SageMaker endpoint + adapter_name = "smollm2-lora-invoke-test" + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Now test the /invocations endpoint with the adapter + invocation_response = requests.post( + basic_server_with_lora.url_for("invocations"), + headers={ + "X-Amzn-SageMaker-Adapter-Identifier": adapter_name, + }, + json={ + "prompt": "Hello, how are you?", + "max_tokens": 10, + }, + ) + invocation_response.raise_for_status() + invocation_output = invocation_response.json() + + # Verify we got a valid completion response + assert "choices" in invocation_output + assert len(invocation_output["choices"]) > 0 + assert "text" in invocation_output["choices"][0] + + +@pytest.mark.asyncio +async def test_sagemaker_multiple_adapters_load_unload( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + adapter_names = [f"sagemaker-adapter-{i}" for i in range(5)] + + # Load all adapters + for adapter_name in adapter_names: + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Verify all are in the models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + for adapter_name in adapter_names: + assert adapter_name in adapter_ids + + # Unload all adapters + for adapter_name in adapter_names: + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", adapter_name), + ) + unload_response.raise_for_status() + + # Verify all are removed from models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + for adapter_name in adapter_names: + assert adapter_name not in adapter_ids diff --git a/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py b/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py new file mode 100644 index 000000000000..f1ed0c7e2897 --- /dev/null +++ 
b/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration test for middleware loader functionality. + +Tests that customer middlewares get called correctly with a vLLM server. +""" + +import os +import tempfile + +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + MODEL_NAME_SMOLLM, +) + + +class TestMiddlewareIntegration: + """Integration test for middleware with vLLM server.""" + + def setup_method(self): + """Setup for each test - simulate fresh server startup.""" + self._clear_caches() + + def _clear_caches(self): + """Clear middleware registry and function loader cache.""" + try: + from model_hosting_container_standards.common.fastapi.middleware import ( + middleware_registry, + ) + from model_hosting_container_standards.common.fastapi.middleware.source.decorator_loader import ( # noqa: E501 + decorator_loader, + ) + from model_hosting_container_standards.sagemaker.sagemaker_loader import ( + SageMakerFunctionLoader, + ) + + middleware_registry.clear_middlewares() + decorator_loader.clear() + SageMakerFunctionLoader._default_function_loader = None + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + @pytest.mark.asyncio + async def test_customer_middleware_with_vllm_server(self): + """Test that customer middlewares work with actual vLLM server. + + Tests decorator-based middlewares (@custom_middleware, @input_formatter, + @output_formatter) + on multiple endpoints (chat/completions, invocations). + """ + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a middleware script with multiple decorators + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from model_hosting_container_standards.common.fastapi.middleware import ( + custom_middleware, input_formatter, output_formatter +) + +# Global flag to track if input formatter was called +_input_formatter_called = False + +@input_formatter +async def customer_input_formatter(request): + # Process input - mark that input formatter was called + global _input_formatter_called + _input_formatter_called = True + return request + +@custom_middleware("throttle") +async def customer_throttle_middleware(request, call_next): + response = await call_next(request) + response.headers["X-Customer-Throttle"] = "applied" + order = response.headers.get("X-Middleware-Order", "") + response.headers["X-Middleware-Order"] = order + "throttle," + return response + +@output_formatter +async def customer_output_formatter(response): + global _input_formatter_called + response.headers["X-Customer-Processed"] = "true" + # Since input_formatter and output_formatter are combined into + # pre_post_process middleware, + # if output_formatter is called, input_formatter should have been called too + if _input_formatter_called: + response.headers["X-Input-Formatter-Called"] = "true" + order = response.headers.get("X-Middleware-Order", "") + response.headers["X-Middleware-Order"] = order + "output_formatter," + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to point to customer script + env_vars = { + 
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test 1: Middlewares applied to chat/completions endpoint + chat_response = requests.post( + server.url_for("v1/chat/completions"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "temperature": 0.0, + }, + ) + + assert chat_response.status_code == 200 + + # Verify all middlewares were executed + assert "X-Customer-Throttle" in chat_response.headers + assert chat_response.headers["X-Customer-Throttle"] == "applied" + assert "X-Customer-Processed" in chat_response.headers + assert chat_response.headers["X-Customer-Processed"] == "true" + + # Verify input formatter was called + assert "X-Input-Formatter-Called" in chat_response.headers + assert chat_response.headers["X-Input-Formatter-Called"] == "true" + + # Verify middleware execution order + execution_order = chat_response.headers.get( + "X-Middleware-Order", "" + ).rstrip(",") + order_parts = execution_order.split(",") if execution_order else [] + assert "throttle" in order_parts + assert "output_formatter" in order_parts + + # Test 2: Middlewares applied to invocations endpoint + invocations_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "temperature": 0.0, + }, + ) + + assert invocations_response.status_code == 200 + + # Verify all middlewares were executed + assert "X-Customer-Throttle" in invocations_response.headers + assert invocations_response.headers["X-Customer-Throttle"] == "applied" + assert "X-Customer-Processed" in invocations_response.headers + assert invocations_response.headers["X-Customer-Processed"] == "true" + + # Verify input formatter was called + assert "X-Input-Formatter-Called" in invocations_response.headers + assert ( + invocations_response.headers["X-Input-Formatter-Called"] == "true" + ) + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_middleware_with_ping_endpoint(self): + """Test that middlewares work with SageMaker ping endpoint.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a middleware script + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from model_hosting_container_standards.common.fastapi.middleware import ( + custom_middleware +) + +@custom_middleware("pre_post_process") +async def ping_tracking_middleware(request, call_next): + response = await call_next(request) + if request.url.path == "/ping": + response.headers["X-Ping-Tracked"] = "true" + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping endpoint with 
middleware + response = requests.get(server.url_for("ping")) + + assert response.status_code == 200 + assert "X-Ping-Tracked" in response.headers + assert response.headers["X-Ping-Tracked"] == "true" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_middleware_env_var_override(self): + """Test middleware environment variable overrides.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with middleware functions specified via env vars + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +# Global flag to track if pre_process was called +_pre_process_called = False + +async def env_throttle_middleware(request, call_next): + response = await call_next(request) + response.headers["X-Env-Throttle"] = "applied" + return response + +async def env_pre_process(request: Request) -> Request: + # Mark that pre_process was called + global _pre_process_called + _pre_process_called = True + return request + +async def env_post_process(response): + global _pre_process_called + if hasattr(response, 'headers'): + response.headers["X-Env-Post-Process"] = "applied" + # Since pre_process and post_process are combined into + # pre_post_process middleware, + # if post_process is called, pre_process should have been called too + if _pre_process_called: + response.headers["X-Pre-Process-Called"] = "true" + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables for middleware + # Use script_name with .py extension as per plugin example + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_MIDDLEWARE_THROTTLE: ( + f"{script_name}:env_throttle_middleware" + ), + FastAPIEnvVars.CUSTOM_PRE_PROCESS: f"{script_name}:env_pre_process", + FastAPIEnvVars.CUSTOM_POST_PROCESS: f"{script_name}:env_post_process", + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + response = requests.get(server.url_for("ping")) + assert response.status_code == 200 + + # Check if environment variable middleware was applied + headers = response.headers + + # Verify that env var middlewares were applied + assert "X-Env-Throttle" in headers, ( + "Throttle middleware should be applied via env var" + ) + assert headers["X-Env-Throttle"] == "applied" + + assert "X-Env-Post-Process" in headers, ( + "Post-process middleware should be applied via env var" + ) + assert headers["X-Env-Post-Process"] == "applied" + + # Verify that pre_process was called + assert "X-Pre-Process-Called" in headers, ( + "Pre-process should be called via env var" + ) + assert headers["X-Pre-Process-Called"] == "true" + + finally: + os.unlink(script_path) diff --git a/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py b/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py new file mode 100644 index 000000000000..6206000385bd --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py @@ -0,0 
+1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import openai # use the official client for correctness check +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + HEADER_SAGEMAKER_CLOSED_SESSION_ID, + HEADER_SAGEMAKER_NEW_SESSION_ID, + HEADER_SAGEMAKER_SESSION_ID, + MODEL_NAME_SMOLLM, +) + +CLOSE_BADREQUEST_CASES = [ + ( + "nonexistent_session_id", + {"session_id": "nonexistent-session-id"}, + {}, + "session not found", + ), + ("malformed_close_request", {}, {"extra-field": "extra-field-data"}, None), +] + + +@pytest.mark.asyncio +async def test_create_session_badrequest(basic_server_with_lora: RemoteOpenAIServer): + bad_response = requests.post( + basic_server_with_lora.url_for("invocations"), + json={"requestType": "NEW_SESSION", "extra-field": "extra-field-data"}, + ) + + assert bad_response.status_code == 400 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_name,session_id_change,request_body_change,expected_error", + CLOSE_BADREQUEST_CASES, +) +async def test_close_session_badrequest( + basic_server_with_lora: RemoteOpenAIServer, + test_name: str, + session_id_change: dict[str, str], + request_body_change: dict[str, str], + expected_error: str | None, +): + # first attempt to create a session + url = basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + close_request_json = {"requestType": "CLOSE"} + if request_body_change: + close_request_json.update(request_body_change) + bad_session_id = session_id_change.get("session_id") + bad_close_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: bad_session_id or valid_session_id}, + json=close_request_json, + ) + + # clean up created session, should succeed + clean_up_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + clean_up_response.raise_for_status() + + assert bad_close_response.status_code == 400 + if expected_error: + assert expected_error in bad_close_response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_close_session_invalidrequest( + basic_server_with_lora: RemoteOpenAIServer, async_client: openai.AsyncOpenAI +): + # first attempt to create a session + url = basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + close_request_json = {"requestType": "CLOSE"} + invalid_close_response = requests.post( + url, + # no headers to specify session_id + json=close_request_json, + ) + + # clean up created session, should succeed + clean_up_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + clean_up_response.raise_for_status() + + assert invalid_close_response.status_code == 424 + assert "invalid session_id" in invalid_close_response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_session(basic_server_with_lora: RemoteOpenAIServer): + # first attempt to create a session + url = 
basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + # test invocation with session id + + request_args = { + "model": MODEL_NAME_SMOLLM, + "prompt": "what is 1+1?", + "max_completion_tokens": 5, + "temperature": 0.0, + "logprobs": False, + } + + invocation_response = requests.post( + basic_server_with_lora.url_for("invocations"), + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json=request_args, + ) + invocation_response.raise_for_status() + + # close created session, should succeed + close_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + close_response.raise_for_status() + + assert ( + close_response.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID) + == valid_session_id + ) diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml index ea9c95158405..9297bf6ddf2d 100644 --- a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml +++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml @@ -3,6 +3,3 @@ accuracy_threshold: 0.45 num_questions: 1319 num_fewshot: 5 max_model_len: 4096 -# Duo stream incompatabilbe with this model: https://github.com/vllm-project/vllm/issues/28220 -env: - VLLM_DISABLE_SHARED_EXPERTS_STREAM: "1" diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index 18995545552e..6e5468969bf2 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -9,7 +9,6 @@ from vllm.vllm_flash_attn import ( fa_version_unsupported_reason, flash_attn_varlen_func, - flash_attn_with_kvcache, is_fa_version_supported, ) @@ -83,124 +82,6 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -@pytest.mark.parametrize("use_out", [True, False]) -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) -@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) -@pytest.mark.parametrize("fa_version", [2, 3]) -@pytest.mark.parametrize("q_dtype", QDTYPES) -@torch.inference_mode() -def test_flash_attn_with_paged_kv( - use_out: bool, - kv_lens: list[int], - num_heads: tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: float | None, - num_blocks: int, - sliding_window: int | None, - fa_version: int, - q_dtype: torch.dtype | None, -) -> None: - torch.set_default_device("cuda") - if not is_fa_version_supported(fa_version): - pytest.skip( - f"Flash attention version {fa_version} not supported due " - f'to: "{fa_version_unsupported_reason(fa_version)}"' - ) - if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip( - "Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type" - ) - - current_platform.seed_everything(0) - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = 
head_size**-0.5 - window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn( - num_blocks, block_size, num_kv_heads, head_size, dtype=dtype - ) - value_cache = torch.randn_like(key_cache) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint( - 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 - ) - - q = query.unsqueeze(1) - out = torch.empty_like(q) if use_out else None - - maybe_quantized_query = q - maybe_quantized_key_cache = key_cache - maybe_quantized_value_cache = value_cache - q_descale = None - k_descale = None - v_descale = None - if q_dtype is not None: - # QKV are drawn from N(0, 1): no need for a fp8 scaling factor - maybe_quantized_query = q.to(q_dtype) - maybe_quantized_key_cache = key_cache.to(q_dtype) - maybe_quantized_value_cache = value_cache.to(q_dtype) - - scale_shape = (num_seqs, num_kv_heads) - q_descale = torch.ones(scale_shape, dtype=torch.float32) - k_descale = torch.ones(scale_shape, dtype=torch.float32) - v_descale = torch.ones(scale_shape, dtype=torch.float32) - - output = flash_attn_with_kvcache( - q=maybe_quantized_query, - k_cache=maybe_quantized_key_cache, - v_cache=maybe_quantized_value_cache, - out=out, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - softcap=soft_cap if soft_cap is not None else 0, - window_size=window_size, - fa_version=fa_version, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - ) - output = output if not use_out else out - output = output.squeeze(1) - - atol, rtol = 1.5e-2, 1e-2 - if q_dtype is not None: - atol, rtol = 1.5e-1, 1.5e-1 - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window, - ) - ( - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), - f"{torch.max(torch.abs(output - ref_output))}", - ) - - @pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize( "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 014df1fa111f..c27cf2468ede 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -6,6 +6,8 @@ """ import functools +import importlib +import sys from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -20,6 +22,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.moe.utils import fused_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context @@ -412,14 +415,12 @@ def test_mixtral_moe( huggingface.""" # clear the cache before every test - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) + # Force reload aiter_ops to pick up the new environment variables. 
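# The reload in this hunk is needed because the AITER feature flags are read
# from the environment when the module is imported. A minimal, self-contained
# sketch of that pattern (a throwaway "featuremod", not vLLM code):
import importlib
import sys

import pytest


def test_env_flag_requires_reload(tmp_path, monkeypatch: pytest.MonkeyPatch):
    # Toy module that freezes the env flag at import time.
    (tmp_path / "featuremod.py").write_text(
        'import os\nENABLED = os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"\n'
    )
    monkeypatch.syspath_prepend(str(tmp_path))
    sys.modules.pop("featuremod", None)

    monkeypatch.delenv("VLLM_ROCM_USE_AITER", raising=False)
    featuremod = importlib.import_module("featuremod")
    assert featuremod.ENABLED is False

    # Patching the environment alone does not change the already-imported module...
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    assert featuremod.ENABLED is False

    # ...it must be reloaded, mirroring the importlib.reload() call below.
    featuremod = importlib.reload(featuremod)
    assert featuremod.ENABLED is True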
+ if "rocm_aiter_ops" in sys.modules: + importlib.reload(rocm_aiter_ops) - is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 97a55c37b9a3..420dbbffaac0 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -25,6 +25,7 @@ (8, 16, 128 * 2, fp8_dtype), (8, 16, 128 * 3, fp8_dtype), (8, 64, 7168, fp8_dtype), + (8, 128, 128 * 33, fp8_dtype), (8, 128, 7168, fp8_dtype), (8, 512, 7168, fp8_dtype), (8, 1024, 7168, fp8_dtype), @@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): ) # Run the SiLU V2 kernel + # TODO (varun): use_e8m0 is set to false as the reference impl does + # not handle that case. y_q, y_s = persistent_masked_m_silu_mul_quant( - y, tokens_per_expert, group_size=group_size + y, tokens_per_expert, group_size=group_size, use_ue8m0=False ) torch.cuda.synchronize() diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 41419553aa83..9121284de85b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -4,6 +4,7 @@ import pytest import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import ( @@ -15,9 +16,6 @@ dispatch_topk_func, vllm_topk_softmax, ) -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, -) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_rocm_rmsnorm_func, @@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - topk_func = dispatch_topk_func() - is_rocm_aiter_moe_enabled.cache_clear() - if current_platform.is_rocm() and int(use_rocm_aiter): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax, - ) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_topk_dispatch(use_rocm_aiter: bool): + topk_func = dispatch_topk_func(use_rocm_aiter) - assert topk_func == rocm_aiter_topk_softmax + if current_platform.is_rocm() and use_rocm_aiter: + assert topk_func == rocm_aiter_ops.topk_softmax else: assert topk_func == vllm_topk_softmax @pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) +@pytest.mark.parametrize("use_rocm_aiter", [True, False]) @pytest.mark.skipif( not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" ) def test_rms_norm_dispatch( - add_residual: bool, - dtype: torch.dtype, - use_rocm_aiter: str, - use_rocm_aiter_norm: str, - monkeypatch, + add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool ): - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", 
use_rocm_aiter_norm) - rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) + rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter) should_use_rocm_aiter = ( current_platform.is_rocm() - and int(use_rocm_aiter) - and int(use_rocm_aiter_norm) + and use_rocm_aiter and dtype in RMS_NORM_SUPPORTED_DTYPES ) if add_residual and should_use_rocm_aiter: - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add elif should_use_rocm_aiter: - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm + assert rms_norm_func == rocm_aiter_ops.rms_norm elif add_residual: assert rms_norm_func == fused_add_rms_norm else: diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 87f5d40ac1da..c9d227599cde 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -4,7 +4,7 @@ import pytest from vllm import SamplingParams -from vllm.logprobs import FlattenLogprobs +from vllm.logprobs import FlatLogprobs MODELS = ["distilbert/distilgpt2"] MAX_TOKENS = 5 @@ -16,17 +16,17 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("greedy", [True, False]) -@pytest.mark.parametrize("flatten_logprobs", [True, False]) +@pytest.mark.parametrize("flat_logprobs", [True, False]) def test_ranks( vllm_runner, model, dtype, greedy, - flatten_logprobs, + flat_logprobs, example_prompts, monkeypatch: pytest.MonkeyPatch, ): - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0") with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model: tokenizer = vllm_model.llm.get_tokenizer() example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts] @@ -44,12 +44,8 @@ def test_ranks( decode_tokens, _, decode_logprobs, prompt_logprobs = result # Ensure the return type of logprobs is accurate - assert isinstance( - prompt_logprobs, FlattenLogprobs if flatten_logprobs else list - ) - assert isinstance( - decode_logprobs, FlattenLogprobs if flatten_logprobs else list - ) + assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list) + assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list) ######################## # Check prompt logprobs diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py index 1799d3638178..d26a460d2bca 100644 --- a/tests/test_logprobs.py +++ b/tests/test_logprobs.py @@ -5,7 +5,7 @@ import pytest from vllm.logprobs import ( - FlattenLogprobs, + FlatLogprobs, Logprob, LogprobsOnePosition, append_logprobs_for_next_position, @@ -14,8 +14,8 @@ ) -def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0") +def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0") prompt_logprobs = create_prompt_logprobs() assert isinstance(prompt_logprobs, list) @@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert len(sample_logprobs) == 0 -def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") +def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1") prompt_logprobs = create_prompt_logprobs() - assert 
isinstance(prompt_logprobs, FlattenLogprobs) + assert isinstance(prompt_logprobs, FlatLogprobs) assert prompt_logprobs.start_indices == [0] assert prompt_logprobs.end_indices == [0] assert len(prompt_logprobs.token_ids) == 0 @@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert prompt_logprobs[0] == dict() sample_logprobs = create_sample_logprobs() - assert isinstance(sample_logprobs, FlattenLogprobs) + assert isinstance(sample_logprobs, FlatLogprobs) assert len(sample_logprobs.start_indices) == 0 assert len(sample_logprobs.end_indices) == 0 assert len(sample_logprobs.token_ids) == 0 @@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert len(sample_logprobs) == 0 -def test_append_logprobs_for_next_position_none_flatten( +def test_append_logprobs_for_next_position_none_flat( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0") logprobs = create_sample_logprobs() append_logprobs_for_next_position( logprobs, @@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten( ] -def test_append_logprobs_for_next_position_flatten( +def test_append_logprobs_for_next_position_flat( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1") logprobs = create_sample_logprobs() append_logprobs_for_next_position( logprobs, @@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten( rank=11, num_logprobs=-1, ) - assert isinstance(logprobs, FlattenLogprobs) + assert isinstance(logprobs, FlatLogprobs) assert logprobs.start_indices == [0, 1] assert logprobs.end_indices == [1, 3] assert logprobs.token_ids == [1, 2, 3] @@ -129,8 +129,8 @@ def test_append_logprobs_for_next_position_flatten( } -def test_flatten_logprobs_append() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_append() -> None: + logprobs = FlatLogprobs() logprobs.append(LOGPROBS_ONE_POSITION_0) logprobs.append(LOGPROBS_ONE_POSITION_1) assert logprobs.start_indices == [0, 1] @@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None: assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"] -def test_flatten_logprobs_extend() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_extend() -> None: + logprobs = FlatLogprobs() # Extend with list[LogprobsOnePosition] logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]) assert logprobs.start_indices == [0, 3] @@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None: assert logprobs.ranks == [40, 50, 60, 10] assert logprobs.decoded_tokens == ["40", "50", "60", "10"] - other_logprobs = FlattenLogprobs() + other_logprobs = FlatLogprobs() other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0]) - # Extend with another FlattenLogprobs + # Extend with another FlatLogprobs logprobs.extend(other_logprobs) assert logprobs.start_indices == [0, 3, 4, 6] assert logprobs.end_indices == [3, 4, 6, 7] @@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None: assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"] -def test_flatten_logprobs_access() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_access() -> None: + logprobs = FlatLogprobs() logprobs.extend( [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0] ) diff --git a/tests/v1/structured_output/test_backend_guidance.py 
b/tests/v1/structured_output/test_backend_guidance.py new file mode 100644 index 000000000000..771076186a3b --- /dev/null +++ b/tests/v1/structured_output/test_backend_guidance.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import AutoTokenizer + +from vllm.config import StructuredOutputsConfig, VllmConfig +from vllm.config.model import ModelConfig +from vllm.config.speculative import SpeculativeConfig +from vllm.sampling_params import SamplingParams, StructuredOutputsParams +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.structured_output.backend_guidance import GuidanceBackend +from vllm.v1.structured_output.backend_types import StructuredOutputOptions + +TOKENIZER = "gpt2" + + +def test_backend_guidance_rollback_terminated(): + # Test that the backend guidance successfully rollbacks from a + # terminated state. This can happen with speculative decoding, + # where the draft model proposes EOS and it is verified by the + # guidance backend. In that case we are in a stopped state, but + # it should be reverted in case EOS is not accepted by the target + # model. + vllm_config = VllmConfig( + decoding_config=StructuredOutputsConfig( + backend="guidance", + ) + ) + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) + + backend = GuidanceBackend( + vllm_config, + tokenizer=tokenizer, + vocab_size=50257, + ) + + grammar = backend.compile_grammar( + StructuredOutputOptions.JSON, '{"type": "object"}' + ) + + prompt = tokenizer.encode('{"a": "b"}') + assert len(prompt) > 1 + dummy_wrong = tokenizer.encode('{"a"}') + for token in prompt: + assert grammar.accept_tokens("", [token]) + assert not grammar.is_terminated() + assert grammar.accept_tokens("", [tokenizer.eos_token_id]) + assert grammar.is_terminated() + # Giving any other token should also be accepted + assert grammar.accept_tokens("", dummy_wrong) + # Rollback is done from where state was terminated, so from '}' not EOS + grammar.rollback(len(prompt) - 1) + assert not grammar.is_terminated() + assert grammar.validate_tokens([tokenizer.eos_token_id]) == [] + assert grammar.validate_tokens(dummy_wrong) != dummy_wrong + assert grammar.accept_tokens("", prompt[1:]) + assert not grammar.is_terminated() + assert grammar.accept_tokens("", [tokenizer.eos_token_id]) + assert grammar.is_terminated() + # Rollback of <= 0 should not change the terminated state + grammar.rollback(0) + assert grammar.is_terminated() + grammar.rollback(-1) + assert grammar.is_terminated() + + +def test_grammar_bitmask_with_specdec(): + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) + prompt = tokenizer.encode('{"a": "b"}') + vllm_config = VllmConfig( + model_config=ModelConfig(tokenizer=TOKENIZER), + structured_outputs_config=StructuredOutputsConfig(backend="guidance"), + speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3), + ) + structured_output_manager = StructuredOutputManager(vllm_config) + + for i in range(1, 2): + sampling_params = SamplingParams( + structured_outputs=StructuredOutputsParams( + json='{"type": "object"}', + ), + ) + sampling_params.structured_outputs._backend = "guidance" + + my_req_id = f"my_req_id_{i}" + request = Request( + my_req_id, + prompt_token_ids=prompt[:i], + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=tokenizer.eos_token_id, + ) + + structured_output_manager.grammar_init(request) + + def grammar_bitmask(req: 
Request, tokens: list[int]) -> None: + structured_output_manager.grammar_bitmask( + requests={req.request_id: req}, + structured_output_request_ids={req.request_id: 0}, + scheduled_spec_decode_tokens={req.request_id: tokens}, + ) + # At this point, we rolled-back, so should not be terminated + assert not req.structured_output_request.grammar.is_terminated() + + # The grammar might not yet be compiled, so we wait for it + while not request.structured_output_request._check_grammar_completion(): + continue + + assert request.structured_output_request.grammar.accept_tokens( + request.request_id, prompt[:i] + ) + + grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) + grammar_bitmask( + request, prompt[i:] + [tokenizer.eos_token_id] + prompt + ) # EOS not the final token + grammar_bitmask(request, prompt[i:]) # EOS not present + grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py new file mode 100644 index 000000000000..8d35aa65738b --- /dev/null +++ b/vllm/_aiter_ops.py @@ -0,0 +1,942 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from collections.abc import Callable + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer + + +def is_aiter_found() -> bool: + from importlib.util import find_spec + + return find_spec("aiter") is not None + + +# `find_spec` is not torch.compile compatible. +# In cases where aiter availability might have +# been checked in forward passes that are torch compiled. +# we keep this global outside to not cause torch compile breaks. +IS_AITER_FOUND = is_aiter_found() + + +def if_aiter_supported(func: Callable) -> Callable: + """Decorator that only executes the function if + ROCm AITER package is supported on gfx9 archs. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # checks the platform, device arch and aiter library existance. 
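# The decorator being defined in this file can be sketched in isolation as a
# generic gate (illustrative names only, not the vLLM implementation): the
# wrapped call runs only when a platform/library predicate passes and is a
# no-op returning None otherwise.
import functools
from collections.abc import Callable
from importlib.util import find_spec


def run_only_if(predicate: Callable[[], bool]) -> Callable:
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if predicate():
                return func(*args, **kwargs)
            return None  # unsupported platform or missing package: do nothing

        return wrapper

    return decorator


@run_only_if(lambda: find_spec("aiter") is not None)
def report_aiter_version() -> str | None:
    import aiter  # only imported when the predicate passed

    return getattr(aiter, "__version__", "unknown")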
+ + from vllm.platforms.rocm import on_gfx9 + + if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND: + return func(*args, **kwargs) + else: + # Return None or do nothing if not supported + return None + + return wrapper + + +def _rocm_aiter_fused_moe_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, +) -> torch.Tensor: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + activation = ActivationType(activation_method) + quant_type = QuantType(quant_method) + + return fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation, + quant_type, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + +def _rocm_aiter_fused_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _rocm_aiter_asm_moe_tkw1_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, +) -> torch.Tensor: + from aiter import ActivationType + from aiter.fused_moe_bf16_asm import asm_moe_tkw1 + + activation = ActivationType(activation_method) + + return asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation, + ) + + +def _rocm_aiter_asm_moe_tkw1_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _rocm_aiter_topk_softmax_impl( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: + from aiter import topk_softmax + + topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) + + +def _rocm_aiter_topk_softmax_fake( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + 
token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: + pass + + +def _rocm_aiter_biased_grouped_topk_impl( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + from aiter import biased_grouped_topk + + biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) + + +def _rocm_aiter_biased_grouped_topk_fake( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + pass + + +def _rocm_aiter_grouped_topk_impl( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + is_softmax = scoring_func == "softmax" + from aiter import grouped_topk + + grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + is_softmax, + routed_scaling_factor, + ) + + +def _rocm_aiter_grouped_topk_fake( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + pass + + +def _rocm_aiter_mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + from aiter.mla import mla_decode_fwd + + mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) + + +def _rocm_aiter_mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +def _rocm_aiter_gemm_w8a8_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter import gemm_a8w8_CK + + # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects + # a to be [M, K] + # b to be [N, K] + # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) + + +def _rocm_aiter_gemm_w8a8_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_gemm_w8a8_blockscale_impl( + A: 
torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + +def _rocm_aiter_gemm_w8a8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_rms_norm_impl( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + from aiter import rms_norm + + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + + return rms_norm(x, weight, variance_epsilon) + + +def _rocm_aiter_rms_norm_fake( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + return torch.empty_like(x) + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter import rmsnorm2d_fwd_with_add + + residual_out = torch.empty_like(residual) + output = torch.empty_like(x) + rmsnorm2d_fwd_with_add( + output, # output + x, # input + residual, # residual input + residual_out, # residual output + weight, + variance_epsilon, + ) + return output, residual_out + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(residual) + + +# Global flag to ensure ops are registered only once +_OPS_REGISTERED = False + + +class rocm_aiter_ops: + _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER + _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM + _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE + _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA + _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA + _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE + _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + + @classmethod + @if_aiter_supported + def is_enabled(cls) -> bool: + """Verifies device specs and availability of aiter main env variable.""" + return cls._AITER_ENABLED + + @classmethod + @if_aiter_supported + def is_linear_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._LINEAR_ENABLED + + @classmethod + @if_aiter_supported + def is_linear_fp8_enaled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() + + @classmethod + @if_aiter_supported + def is_rmsnorm_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._RMSNORM_ENABLED + + @classmethod + @if_aiter_supported + def is_fused_moe_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and 
cls._FMOE_ENABLED + + @classmethod + @if_aiter_supported + def is_fusion_moe_shared_experts_enabled(cls) -> bool: + return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED + + @classmethod + @if_aiter_supported + def is_mla_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._MLA_ENABLED + + @classmethod + @if_aiter_supported + def is_mha_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._MHA_ENABLED + + @classmethod + @if_aiter_supported + def is_pa_attn_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED + + @classmethod + @if_aiter_supported + def is_triton_unified_attn_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED + + @classmethod + @if_aiter_supported + def is_fp8bmm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + + @classmethod + @if_aiter_supported + def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM + + @classmethod + @if_aiter_supported + def is_triton_rotary_embed_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED + + @staticmethod + @if_aiter_supported + def register_ops_once() -> None: + global _OPS_REGISTERED + if not _OPS_REGISTERED: + tags = ( + tuple() + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ) + + # register all the custom ops here + direct_register_custom_op( + op_name="rocm_aiter_asm_moe_tkw1", + op_func=_rocm_aiter_asm_moe_tkw1_impl, + mutates_args=[], + fake_impl=_rocm_aiter_asm_moe_tkw1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fused_moe", + op_func=_rocm_aiter_fused_moe_impl, + mutates_args=[], + fake_impl=_rocm_aiter_fused_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_topk_softmax", + op_func=_rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=_rocm_aiter_topk_softmax_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_biased_grouped_topk", + op_func=_rocm_aiter_biased_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=_rocm_aiter_biased_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_grouped_topk", + op_func=_rocm_aiter_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=_rocm_aiter_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_mla_decode_fwd", + op_func=_rocm_aiter_mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=_rocm_aiter_mla_decode_fwd_fake, + tags=tags, + ) + + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8", + op_func=_rocm_aiter_gemm_w8a8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_w8a8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8_blockscale", + op_func=_rocm_aiter_gemm_w8a8_blockscale_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake, + 
dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=_rocm_aiter_rms_norm_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rms_norm_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add", + op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, + dispatch_key=current_platform.dispatch_key, + ) + + _OPS_REGISTERED = True + + @staticmethod + def rms_norm2d_with_add( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add( + x, residual, weight, variance_epsilon + ) + + @staticmethod + def rms_norm( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) + + @staticmethod + def gemm_w8a8( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype) + + @staticmethod + def gemm_w8a8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( + A, B, As, Bs, output_dtype + ) + + @staticmethod + def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation_method, + quant_method, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + @staticmethod + def asm_moe_tkw1( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale, + fc2_scale, + fc1_smooth_scale, + fc2_smooth_scale, + a16, + per_tensor_quant_scale, + expert_mask, + activation_method, + ) + + @staticmethod + def topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, + ) -> tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) + return topk_weights, topk_indices + + @staticmethod + def biased_grouped_topk( + gating_output: torch.Tensor, + correction_bias: 
torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, + ) -> None: + torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) + + @staticmethod + def grouped_topk( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + ) -> None: + torch.ops.vllm.rocm_aiter_grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + scoring_func, + routed_scaling_factor, + ) + + @staticmethod + def mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + logit_cap: float = 0.0, + ): + torch.ops.vllm.rocm_aiter_mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) + + @staticmethod + def triton_fp4_gemm_dynamic_qaunt( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype | None = torch.bfloat16, + x_scales: torch.Tensor | None = None, + ) -> torch.Tensor: + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + if x_scales is None: + x_q, x_s = dynamic_mxfp4_quant(x) + else: + x_q = x + x_s = x_scales + + y = torch.empty( + x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) + return y + + @staticmethod + def triton_rotary_embed( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_size: int, + rotary_dim: int, + is_neox_style: bool, + ): + from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace + + num_tokens = positions.numel() + cos, sin = cos_sin_cache.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + rotate_style = 0 if is_neox_style else 1 + + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + query_ = query[..., :rotary_dim] + key_ = key[..., :rotary_dim] + positions = positions.view(*query.shape[:1]) + rope_cached_thd_positions_2c_fwd_inplace( + positions, + sin, + cos, + query_, + key_, + rotate_style, + reuse_freqs_front_part=True, + is_nope_first=False, + ) + query = query.view(query_shape) + key = key.view(key_shape) + + @staticmethod + def triton_fp8_bmm( + X: torch.Tensor, + WQ: torch.Tensor, + w_scale: torch.Tensor, + group_size: int = 128, + bias: torch.Tensor | None = None, + dtype: torch.dtype | None = torch.bfloat16, + splitK: int | None = None, + YQ: torch.Tensor | None = None, + transpose_bm: bool | None = False, + config: dict | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, + ) + + return aiter_triton_fp8_bmm( + X, + WQ, 
+ w_scale, + group_size=group_size, + bias=bias, + dtype=dtype, + splitK=splitK, + YQ=YQ, + transpose_bm=transpose_bm, + config=config, + ) + + @staticmethod + def triton_gemm_a8w8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + @staticmethod + def per_1x128_fp8_quant( + input_2d: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + """Only applies quantization method for fp8 data type only.""" + from aiter import QuantType, dtypes, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) + return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8) + + @staticmethod + def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool: + return (n, k) in [ + (1024, 8192), + (2112, 7168), + (3072, 1536), + (32768, 8192), + (4096, 7168), + (4608, 7168), + (512, 7168), + (7168, 2048), + (7168, 256), + (8192, 1024), + (8192, 32768), + ] + + @staticmethod + def shuffle_weight( + self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) + ) -> torch.Tensor: + from aiter.ops.shuffle import shuffle_weight + + return shuffle_weight(tensor, layout=layout) + + @staticmethod + def shuffle_weights( + *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) + ) -> tuple[torch.Tensor, ...]: + """ + Applies shuffle_weight function from AITER to each + input tensor and returns them. + + Rearranges (shuffles) the input tensor/s + into a specified block layout for optimized computation. + + Args: + *tensors: Variable number of torch.Tensor objects. + layout: A pair of integers specifying the block sizes used to divide + the tensors during shuffling. Default is (16, 16). + + Returns: + A Tuple of shuffled tensors. 
+ """ + from aiter.ops.shuffle import shuffle_weight + + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) + + +if IS_AITER_FOUND: + rocm_aiter_ops.register_ops_once() diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17e025155a43..acab0529f352 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -745,6 +745,9 @@ def forward( k_pe: torch.Tensor, output_shape: torch.Size | None = None, ) -> torch.Tensor: + if self.calculate_kv_scales: + torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) + if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -752,12 +755,6 @@ def forward( attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # Mirror Attention.forward scale calculation path - if self.calculate_kv_scales and getattr( - attn_metadata, "enable_kv_scales_calculation", False - ): - self.calc_kv_scales(q, kv_c_normed, k_pe) - if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) self.impl.forward( @@ -786,14 +783,6 @@ def forward( ) return output else: - # We can still access forward context to check calculation flag - if self.calculate_kv_scales: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - if getattr(attn_metadata, "enable_kv_scales_calculation", False): - self.calc_kv_scales(q, kv_c_normed, k_pe) return torch.ops.vllm.unified_mla_attention( q, kv_c_normed, @@ -848,6 +837,8 @@ def wait_for_kv_layer_from_connector(layer_name: str): return connector = get_kv_transfer_group() + if not connector.has_connector_metadata(): + return forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -865,6 +856,8 @@ def maybe_save_kv_layer_to_connector( return connector = get_kv_transfer_group() + if not connector.has_connector_metadata(): + return forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -881,17 +874,13 @@ def maybe_calc_kv_scales( layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] - if attn_metadata is None or not getattr( - attn_metadata, "enable_kv_scales_calculation", False - ): + # Only calculate if the layer's calculate_kv_scales flag is True + # This flag gets set to False after the first forward pass + if not self.calculate_kv_scales: return - self = forward_context.no_compile_layers[layer_name] self.calc_kv_scales(query, key, value) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 75fdcb8f48b2..2cbb5c91cc3b 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -195,7 +195,6 @@ def cp_lse_ag_out_rs( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) - assert out.is_contiguous() out = cp_group.reduce_scatter(out, dim=1) if return_lse: diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py deleted file mode 100644 index 6308f63cc4e7..000000000000 --- 
a/vllm/attention/ops/rocm_aiter_mla.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - - -import torch - -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer - - -def get_aiter_mla_metadata( - max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device -) -> tuple[torch.Tensor, ...]: - paged_kv_indices = torch.zeros( - max_batch_size * max_block_per_batch, dtype=torch.int32, device=device - ) - paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device) - paged_kv_last_page_lens = torch.full( - (max_batch_size,), block_size, dtype=torch.int32 - ) - qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) - return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr - - -def aiter_mla_decode_fwd( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - sm_scale: float, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - logit_cap: float = 0.0, -): - torch.ops.vllm.rocm_aiter_mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - max_seqlen_qo, - kv_indptr, - kv_indices, - kv_last_page_lens, - sm_scale=sm_scale, - logit_cap=logit_cap, - ) - - -def mla_decode_fwd_impl( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - sm_scale: float = 1.0, - logit_cap: float = 0.0, -) -> None: - from aiter.mla import mla_decode_fwd - - mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - max_seqlen_qo, - sm_scale=sm_scale, - logit_cap=logit_cap, - ) - - -def mla_decode_fwd_fake( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - sm_scale: float = 1.0, - logit_cap: float = 0.0, -) -> None: - pass - - -if current_platform.is_rocm(): - if is_torch_equal_or_newer("2.7.0"): - tags = () - else: - tags = ((torch.Tag.needs_fixed_stride_order,),) - direct_register_custom_op( - op_name="rocm_aiter_mla_decode_fwd", - op_func=mla_decode_fwd_impl, - mutates_args=["o"], - fake_impl=mla_decode_fwd_fake, - tags=tags, - ) diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index e58cf5911282..0e9b0fbe2c02 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -19,7 +19,6 @@ import argparse import asyncio import contextlib -import gc import importlib.util import json import os @@ -49,6 +48,7 @@ from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils.gc_utils import freeze_gc_heap MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -1414,8 +1414,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: percentile_metrics: str = args.percentile_metrics or default_percentile_metrics # Avoid GC processing "static" data - reduce pause times. 
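# The freeze_gc_heap() helper used below replaces the stdlib "collect, then
# freeze" idiom that this hunk removes; a minimal sketch of that idiom (not
# necessarily the body of vllm.utils.gc_utils.freeze_gc_heap):
import gc


def freeze_static_heap() -> None:
    # Collect once so that whatever is still alive is long-lived "static"
    # data, then move it to the permanent generation so later GC passes
    # skip it and pause times shrink.
    gc.collect()
    gc.freeze()


freeze_static_heap()
print("objects moved to the permanent generation:", gc.get_freeze_count())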
- gc.collect() - gc.freeze() + freeze_gc_heap() benchmark_result = await benchmark( task_type=task_type, diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index be69075f94f0..efd68a71c7e4 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -22,6 +22,7 @@ should_split, ) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig +from vllm.config.utils import Range from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -83,7 +84,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[int | None, int, str], Any] = dict() + self.cache: dict[tuple[Range | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -92,11 +93,11 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, runtime_shape: int | None = None): + def compile_context(self, compile_range: Range | None = None): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" - with pass_context(runtime_shape): + with pass_context(compile_range): if self.compilation_config.use_inductor_graph_partition: with inductor_partition_rule_context( self.compilation_config.splitting_ops @@ -152,26 +153,28 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range | None = None, ) -> Callable | None: - if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + handle = self.cache[(compile_range, graph_index, self.compiler.name)] compiled_graph = self.compiler.load( - handle, graph, example_inputs, graph_index, runtime_shape + handle, graph, example_inputs, graph_index, compile_range ) - if runtime_shape is None: + if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via handle %s", + "Directly load the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via handle %s", + "Directly load the %s-th graph for compile range %s" + "from %s via handle %s", graph_index, - str(runtime_shape), + str(compile_range), self.compiler.name, handle, ) @@ -185,7 +188,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - runtime_shape: int | None = None, + compile_range: Range | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -197,7 +200,7 @@ def compile( compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. 
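# Illustrative sketch (not from this patch): the compiler cache above is now
# keyed by (compile_range, graph_index, compiler_name) instead of a runtime
# shape. This works because the Range dataclass added in vllm/config/utils.py
# defines value-based __eq__ and __hash__. The compiler name string below is a
# placeholder; CompilerManager uses self.compiler.name at runtime.
from typing import Any

from vllm.config.utils import Range

cache: dict[tuple[Range | None, int, str], Any] = {}
cache[(Range(start=8, end=8), 0, "inductor")] = "handle-for-size-8"

# A value-equal Range hits the same entry.
assert (Range(start=8, end=8), 0, "inductor") in cache
# Dynamic-shape compilation keeps using None as the range key, as before.
cache[(None, 0, "inductor")] = "handle-for-dynamic-shape"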
@@ -205,7 +208,7 @@ def compile( now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " "from the cache, took %.3f s", @@ -213,9 +216,9 @@ def compile( ) else: logger.info( - "Directly load the compiled graph(s) for shape %s " + "Directly load the compiled graph(s) for compile range %s " "from the cache, took %.3f s", - str(runtime_shape), + str(compile_range), elapsed, ) return compiled_graph @@ -226,14 +229,18 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - - with self.compile_context(runtime_shape): + maybe_key = "artifact_compile_range_" + if compile_range is None: + maybe_key += "dynamic_shape" + else: + maybe_key += f"{compile_range.start}_{compile_range.end}" + maybe_key += f"_subgraph_{graph_index}" + with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( graph, example_inputs, additional_inductor_config, - runtime_shape, + compile_range, maybe_key, ) @@ -241,33 +248,33 @@ def compile( # store the artifact in the cache if is_compile_cache_enabled(additional_inductor_config) and handle is not None: - self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if runtime_shape is None: + if compile_range is None: logger.info_once( "Cache the graph for dynamic shape for later use", scope="local" ) else: logger.info_once( - "Cache the graph of shape %s for later use", - str(runtime_shape), - scope="local", + "Cache the graph of compile range %s for later use", + str(compile_range), ) - if runtime_shape is None: + if compile_range is None: logger.debug( - "Store the %s-th graph for dynamic shape from %s via handle %s", + "Store the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Store the %s-th graph for shape %s from %s via handle %s", + "Store the %s-th graph for compile range%s from %s via handle %s", graph_index, - str(runtime_shape), + str(compile_range), self.compiler.name, handle, ) @@ -277,16 +284,16 @@ def compile( now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info_once( - "Compiling a graph for dynamic shape takes %.2f s", + "Compiling a graph for dynamic compile range takes %.2f s", elapsed, scope="local", ) else: logger.info_once( - "Compiling a graph for shape %s takes %.2f s", - runtime_shape, + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), elapsed, scope="local", ) @@ -405,19 +412,7 @@ def call_module( sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - global compilation_start_time - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None, - ) - ) # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -427,7 +422,6 @@ def 
call_module( index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_dynamic_shape, self.vllm_backend, ) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 69d4606d73eb..32d1f1531f4c 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch.distributed._symmetric_memory import enable_symm_mem_for_group from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -431,7 +432,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -441,7 +442,9 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range.is_single_size() and compile_range.end % tp_size == 0 + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -505,91 +508,60 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size + max_tensor_size = max_token_num * hidden_size * element_size + assert current_tensor_size <= max_tensor_size, ( + f"Current tensor size {current_tensor_size} is larger than " + f"max token num {max_token_num} * hidden size {hidden_size} * " + f"element size {element_size}" + ) + device_capability = current_platform.get_device_capability().to_int() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( + device_capability, {} + ).get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB + ) - if num_tokens <= max_token_num: - device_capability = current_platform.get_device_capability().to_int() - # Get one shot input size limit for the current world size - # for the current device capability - max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( - device_capability, {} - ).get(world_size, None) - # Use one shot if no max size for one shot is specified - use_oneshot = ( - max_one_shot_size_mb is None - or current_tensor_size <= max_one_shot_size_mb * MiB - ) - - assert _FI_WORKSPACE_TENSOR is not None, ( - "Flashinfer must be enabled when using flashinfer" - ) - if norm_out is None: - norm_out = allreduce_in - residual_out = residual - else: - # return residual_out as allreduce_out with zeroed residual_in - # as flashinfer does not support rms_norm - # and allreduce_out together - residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - 
rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, - quant_out=quant_out, - scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, - scale_factor=scale_factor, - ) + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) + if norm_out is None: + norm_out = allreduce_in + residual_out = residual else: - allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None: - # Do fused rms norm static fp8 quant fused op - if norm_out is None: - torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, - allreduce_out, - residual, - rms_gamma, - scale_factor, - rms_eps, - ) - else: - torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps - ) - else: - if norm_out is None: - torch.ops._C.fused_add_rms_norm( - allreduce_out, residual, rms_gamma, rms_eps - ) - norm_out = allreduce_out - else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None and scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - if scale_factor is None or norm_out is not None: - # we need to return allreduce output - # in cases of non quant fused AR + RMS norm - # and fused AR + RMS norm + quant without fused add - allreduce_in.copy_(allreduce_out) + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) def call_trtllm_fused_allreduce_norm_fake( allreduce_in: torch.Tensor, @@ -1128,7 +1100,8 @@ def __init__(self, config: VllmConfig): if max_size is None: # Flashinfer doesn't support current world size logger.warning( - "Flashinfer allreduce fusion is not supported for world size %s", + "Flashinfer allreduce fusion is not supported for world size %s" + " or max size is not provided", self.tp_size, ) return @@ -1216,6 +1189,11 @@ def register_patterns(self): self.disabled = False + def is_applicable_for_range(self, compile_range: Range | None) -> bool: + if compile_range is None: + return False + return compile_range.end - 1 <= self.max_token_num + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: diff --git a/vllm/compilation/compiler_interface.py 
b/vllm/compilation/compiler_interface.py index b0cdb08884a3..b95067aba191 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,6 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -63,16 +64,17 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a runtime shape. If the `runtime_shape` is None, it means + with a range. If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `runtime_shape` specifies the shape of the inputs. Right now we only - support one variable shape for all inputs, which is the batchsize - (number of tokens) during inference. + `compile_range` specifies the range of the inputs, + it could be concrete size, e.g. (4, 4). + Right now we only support one variable range of shapes for all inputs, + which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -98,7 +100,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range | None = None, ) -> Callable: """ Load the compiled function from the handle. @@ -212,18 +214,21 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(runtime_shape, int): - dynamic_shapes = "from_example_inputs" + if compile_range is not None: + if compile_range.is_single_size(): + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_graph" else: dynamic_shapes = "from_tracing_context" @@ -235,7 +240,6 @@ def compile( dynamic_shapes=dynamic_shapes, options={"config_patches": current_config}, ) - # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) @@ -251,7 +255,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -315,7 +319,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -329,7 +333,7 @@ def compile( current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() # inductor can inplace modify the graph, so we need to copy it @@ -512,7 +516,7 @@ def load( graph: 
fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -608,9 +612,9 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() -def set_inductor_config(config, runtime_shape): - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters +def set_inductor_config(config, compile_range): + if compile_range is not None and compile_range.is_single_size(): + # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( @@ -630,7 +634,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9af635a929b4..008eba4629a3 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,6 +14,7 @@ from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): @@ -28,8 +29,8 @@ class PassContext: - def __init__(self, runtime_shape: int | None): - self.runtime_shape = runtime_shape + def __init__(self, compile_range: Range | None): + self.compile_range: Range | None = compile_range def get_pass_context() -> PassContext: @@ -39,13 +40,13 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: int | None): +def pass_context(compile_range: Range | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. 
""" global _pass_context prev_context = _pass_context - _pass_context = PassContext(runtime_shape) + _pass_context = PassContext(compile_range) try: yield finally: @@ -96,7 +97,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable(self, shape: int | None): + def is_applicable_for_range(self, compile_range: Range | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index dfda2adf1d3b..820fa9b007e3 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -69,13 +69,13 @@ def __init__(self): def __call__(self, graph: fx.Graph): VllmInductorPass.dump_prefix = 0 # reset dump index - shape = get_pass_context().runtime_shape + compile_range = get_pass_context().compile_range for pass_ in self.passes: - if pass_.is_applicable(shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) VllmInductorPass.dump_prefix += 1 else: - logger.debug("Skipping %s with shape %s", pass_, shape) + logger.debug("Skipping %s with compile range %s", pass_, compile_range) # post-cleanup goes before fix_functionalization # because it requires a functional graph @@ -127,5 +127,10 @@ def uuid(self): for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) + compile_range = get_pass_context().compile_range + if compile_range is not None: + # Include the compile range in the uuid to ensure that inductor + # recompiles the graph for the new dynamic compile range. + state["compile_range"] = str(compile_range) return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 2931580afbbb..8f34aa818a80 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -7,18 +7,18 @@ import torch.fx as fx -import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.logger import init_logger logger = init_logger(__name__) @dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int +class RangeEntry: + compile_range: Range compiled: bool = False runnable: Callable = None # type: ignore @@ -31,7 +31,6 @@ def __init__( piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend, ): """ @@ -55,67 +54,124 @@ def __init__( self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) + self.compile_ranges = self.compilation_config.get_compile_ranges() + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) - self.first_run_finished = False + self.compile_sizes = self.compilation_config.compile_sizes + log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + logger.debug_once(log_string) - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + self.first_run_finished = False self.sym_shape_indices = sym_shape_indices - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to compile - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + # self.concrete_size_entries: dict[int, 
RangeEntry] = {} - # to_be_compiled_sizes tracks the remaining sizes to compile, + # the entries for the ranges that we need to compile + self.range_entries: dict[Range, RangeEntry] = {} + + # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) # We only keep compilation management inside this class directly. - for shape in self.compile_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - runnable=self.compiled_graph_for_general_shape, + for size in self.compile_sizes: + range = Range(start=size, end=size) + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + self.to_be_compiled_ranges.add(range) + + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, ) def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: + if self.is_last_graph and not self.to_be_compiled_ranges: # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] + def fakify_args(self, args: list[Any]) -> list[Any]: + # We need to pass fake example_inputs; otherwise torch.compile + # will fakify the example_inputs, potentially causing some non-dynamic + # dimension to be duck shaped to other existing shapes that have hints + # matching their values. + # This is a problem because it can lead to unintended specializations: + # if the new, wrongly dynamic dim is specialized, + # it will force specializing the whole shape. + # torch.compile probably should not accept + # non fake tensors as example inputs! 
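# Illustrative sketch (not from this patch): fakify_args below relies on two
# properties of fx graphs: placeholder nodes come first, in argument order, and
# Dynamo-produced graphs carry a fake tensor in node.meta["example_value"] for
# each placeholder. The ordering property can be seen with symbolic_trace
# (which does not populate example_value; it only shows node order).
import torch
import torch.fx as fx


def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + y


gm = fx.symbolic_trace(f)
ops = [node.op for node in gm.graph.nodes]
# ['placeholder', 'placeholder', 'call_function', 'call_function', 'output']
assert ops[:2] == ["placeholder", "placeholder"]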
+ fake_example_inputs = [] + for node in self.graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(args) + return fake_example_inputs + + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: + if not range_entry.compiled: + range_entry.compiled = True + self.to_be_compiled_ranges.remove(range_entry.compile_range) - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - - entry = self.concrete_size_entries[runtime_shape] - - if not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( + # fakify for range, real args for concrete size + args = ( + self.fakify_args(args) + if not range_entry.compile_range.is_single_size() + else args + ) + range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, + compile_range=range_entry.compile_range, ) # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: - self.check_for_ending_compilation() + self.check_for_ending_compilation() + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + + # Role of the general graph is taken by the last range graph + range_entry = self.range_entries[self.compile_ranges[-1]] + self._maybe_compile_for_range_entry(range_entry, args) + return range_entry.runnable(*args) + runtime_shape = args[self.sym_shape_indices[0]] - return entry.runnable(*args) + # First we try to find the range entry for the concrete compile size + # If not found, we search for the range entry + # that contains the runtime shape. 
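# Illustrative sketch (not from this patch): the dispatch order described in
# the comment above, using the Range semantics added in vllm/config/utils.py
# (inclusive start, exclusive end, single-size when start == end). The concrete
# sizes, ranges, and the pick_entry helper are made up for illustration.
from vllm.config.utils import Range

compile_sizes = [8]
compile_ranges = [Range(start=1, end=256), Range(start=256, end=8193)]
range_entries = {
    Range(start=8, end=8): "exact-8",
    compile_ranges[0]: "range-[1,256)",
    compile_ranges[1]: "range-[256,8193)",
}


def pick_entry(runtime_shape: int) -> str:
    # Exact compile size wins; otherwise fall back to the containing range.
    if runtime_shape in compile_sizes:
        return range_entries[Range(start=runtime_shape, end=runtime_shape)]
    for r in compile_ranges:
        if runtime_shape in r:  # Range.__contains__
            return range_entries[r]
    raise AssertionError(f"Shape out of considered range: {runtime_shape}")


assert pick_entry(8) == "exact-8"
assert pick_entry(100) == "range-[1,256)"
assert pick_entry(4096) == "range-[256,8193)"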
+ range_found = False + if runtime_shape in self.compile_sizes: + range_entry = self.range_entries[ + Range(start=runtime_shape, end=runtime_shape) + ] + range_found = True + else: + for range in self.compile_ranges: + if runtime_shape in range: + range_entry = self.range_entries[range] + range_found = True + break + assert range_found, ( + f"Shape out of considered range: {runtime_shape} " + "[1, max_num_batched_tokens]" + ) + + self._maybe_compile_for_range_entry(range_entry, args) + + return range_entry.runnable(*args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 31624a8fdcc0..6a5ee5a0efb7 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -482,7 +483,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -502,7 +503,11 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return ( + compile_range is not None + and (compile_range.is_single_size()) + and (compile_range.end % tp_size == 0) + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 92cf16f259fe..36bbd2b9abff 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -14,7 +14,7 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config.utils import config +from vllm.config.utils import Range, config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -140,6 +140,9 @@ def flashinfer_max_size(self, world_size: int) -> int | None: """ MiB = 1024 * 1024 + FI_SUPPORTED_WORLD_SIZES = [2, 4, 8] + if world_size not in FI_SUPPORTED_WORLD_SIZES: + return None max_size_mb = self.fi_allreduce_fusion_max_size_mb if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) @@ -212,6 +215,8 @@ class CompilationConfig: - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`compile_ranges_split_points`] + [vllm.config.CompilationConfig.compile_ranges_split_points] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -341,6 +346,16 @@ class CompilationConfig: """Sizes to compile for inductor. 
In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" + compile_ranges_split_points: list[int] | None = None + """Split points that represent compile ranges for inductor. + The compile ranges are + [1, split_points[0]), + [split_points[0], split_points[1]), ..., + [split_points[-1], max_num_batched_tokens + 1). + Compile sizes are also used single element ranges: + [compile_sizes[i], compile_sizes[i] + 1). + """ + inductor_compile_config: dict = field(default_factory=dict) """Additional configurations for inductor. - None: use default configurations.""" @@ -938,3 +953,16 @@ def custom_op_log_check(self): enable_str, op, ) + + def get_compile_ranges(self) -> list[Range]: + """Get the compile ranges for the compilation config.""" + if self.compile_ranges_split_points is None: + return [] + split_points = sorted(set(self.compile_ranges_split_points)) + compile_ranges = [] + for i, s in enumerate(split_points): + if i == 0: + compile_ranges.append(Range(start=1, end=s)) + else: + compile_ranges.append(Range(start=split_points[i - 1], end=s)) + return compile_ranges diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 7e0878d96bbd..ea97ddf125f7 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -6,7 +6,7 @@ import inspect import textwrap from collections.abc import Iterable -from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -176,3 +176,37 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: ) processed_overrides[field_name] = value return replace(config, **processed_overrides) + + +@dataclass +class Range: + """ + A range of numbers. + Inclusive of start, exclusive of end. + """ + + start: int + end: int + + def is_single_size(self) -> bool: + return self.start == self.end + + def __contains__(self, size: int) -> bool: + # Inclusive of start, exclusive of end + if self.is_single_size(): + return size == self.start + return self.start <= size < self.end + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Range): + return False + return self.start == other.start and self.end == other.end + + def __hash__(self) -> int: + return hash((self.start, self.end)) + + def __str__(self) -> str: + return f"(start={self.start}, end={self.end})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d4ee6f980e6e..4557e59a5cf8 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -526,6 +526,8 @@ def __post_init__(self): "correctness and to realize prefill savings. " ) + self._set_compile_ranges() + disable_chunked_prefill_reasons: list[str] = [] if self.model_config: @@ -608,17 +610,19 @@ def __post_init__(self): ) current_platform.check_and_update_config(self) - assert ( - self.parallel_config.dcp_kv_cache_interleave_size - <= self.cache_config.block_size - and self.cache_config.block_size - % self.parallel_config.dcp_kv_cache_interleave_size - == 0 - ), ( - f"Block_size({self.cache_config.block_size}) should be " - "greater than or equal to and divisible by dcp_kv_cache_interleave_size " - f"({self.parallel_config.dcp_kv_cache_interleave_size})." - ) + # If DCP, ensure the block size is right. 
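# Illustrative sketch (not from this patch): how compile_ranges_split_points
# become half-open Ranges via CompilationConfig.get_compile_ranges() above.
# The numbers and the helper name are illustrative; _set_compile_ranges()
# appends max_num_batched_tokens + 1 as the final split point.
from vllm.config.utils import Range


def ranges_from_split_points(split_points: list[int]) -> list[Range]:
    # Mirrors get_compile_ranges(): [1, sp0), [sp0, sp1), ..., [sp[n-2], sp[n-1])
    pts = sorted(set(split_points))
    return [Range(start=1 if i == 0 else pts[i - 1], end=s) for i, s in enumerate(pts)]


# e.g. user split points [128, 512] with max_num_batched_tokens = 8192
assert ranges_from_split_points([128, 512, 8193]) == [
    Range(start=1, end=128),
    Range(start=128, end=512),
    Range(start=512, end=8193),
]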
+ if self.parallel_config.decode_context_parallel_size > 1: + assert ( + self.parallel_config.dcp_kv_cache_interleave_size + <= self.cache_config.block_size + and self.cache_config.block_size + % self.parallel_config.dcp_kv_cache_interleave_size + == 0 + ), ( + f"Block_size({self.cache_config.block_size}) should be greater " + "than or equal to and divisible by dcp_kv_cache_interleave_size " + f"({self.parallel_config.dcp_kv_cache_interleave_size})." + ) assert ( self.parallel_config.dcp_kv_cache_interleave_size == 1 @@ -924,6 +928,53 @@ def _set_cudagraph_sizes(self): # complete the remaining process. self.compilation_config.post_init_cudagraph_sizes() + def _set_compile_ranges(self): + """ + Set the compile ranges for the compilation config. + """ + compilation_config = self.compilation_config + computed_compile_ranges_split_points = [] + + # The upper bound of the compile ranges is the max_num_batched_tokens + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if max_num_batched_tokens is not None: + # We add 1 because the bounds checks in the compiler are exclusive + # and we want to include the max_num_batched_tokens + # in the compile range + computed_compile_ranges_split_points.append(max_num_batched_tokens + 1) + + # Add the compile ranges for flashinfer + if compilation_config.pass_config.enable_fi_allreduce_fusion: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() + * self.model_config.dtype.itemsize + ) + # We add 1 because the bounds checks in the compiler are + # exclusive and we want to include the max_token_num in the + # compile range + if ( + max_num_batched_tokens is not None + and max_token_num < max_num_batched_tokens + ): + computed_compile_ranges_split_points.append(max_token_num + 1) + + if compilation_config.compile_ranges_split_points is not None: + for x in compilation_config.compile_ranges_split_points: + assert isinstance(x, int) + assert x > 0, f"Invalid compile range split point: {x}" + if ( + max_num_batched_tokens is not None + and x < max_num_batched_tokens + and x > 1 + ): + computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( + computed_compile_ranges_split_points + ) # type: ignore + def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config model_config = self.model_config diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index c9d30d6481ab..e6645e524cc3 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -12,6 +12,7 @@ on how the EPLB algorithm works. 
""" +import numpy as np import torch @@ -34,29 +35,44 @@ def balanced_packing( assert num_groups % num_packs == 0 groups_per_pack = num_groups // num_packs + device = weight.device + if groups_per_pack == 1: pack_index = torch.arange( - weight.size(-1), dtype=torch.int64, device=weight.device + weight.size(-1), dtype=torch.int64, device=device ).expand(weight.shape) - rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device) return pack_index, rank_in_pack - indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") - rank_in_pack = torch.full_like(pack_index, fill_value=-1) + weight_np = weight.cpu().numpy() + + # Sort and get indices in decending order + indices_np = np.argsort(-weight_np, axis=-1) + + pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + + # Run the packing algorithm for i in range(num_layers): - pack_weights = [0] * num_packs + pack_weights = [0.0] * num_packs pack_items = [0] * num_packs - for group in indices[i]: + + for group in indices_np[i]: + # Find a pack with capacity that has the lowest weight pack = min( - (i for i in range(num_packs) if pack_items[i] < groups_per_pack), + (j for j in range(num_packs) if pack_items[j] < groups_per_pack), key=pack_weights.__getitem__, ) + assert pack_items[pack] < groups_per_pack - pack_index[i, group] = pack - rank_in_pack[i, group] = pack_items[pack] - pack_weights[pack] += weight[i, group] + pack_index_np[i, group] = pack + rank_in_pack_np[i, group] = pack_items[pack] + pack_weights[pack] += weight_np[i, group] pack_items[pack] += 1 + + pack_index = torch.from_numpy(pack_index_np).to(device) + rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device) + return pack_index, rank_in_pack @@ -212,7 +228,7 @@ def rebalance_experts( replicas for each logical expert """ num_layers, num_logical_experts = weight.shape - weight = weight.float().cpu() + weight = weight.float() if num_groups % num_nodes == 0: # use hierarchical load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index f8ec3e956401..5c1efbaf03ba 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -321,15 +321,19 @@ def rearrange_expert_weights_inplace( ) return + old_global_expert_indices_cpu = old_global_expert_indices.cpu() + new_global_expert_indices_cpu = new_global_expert_indices.cpu() + + # NOTE(bowen): We need this synchronize to run, but I don't know why. + # If you figure out the reason, please let me know -- thank you! + torch.cuda.synchronize() + for layer in range(num_moe_layers): - # NOTE(bowen): We need this synchronize to run, but I don't know why. - # If you figure out the reason, please let me know -- thank you! 
- torch.cuda.synchronize() shuffle_layer( num_local_physical_experts, ep_rank, - old_global_expert_indices[layer].tolist(), - new_global_expert_indices[layer].tolist(), + old_global_expert_indices_cpu[layer].tolist(), + new_global_expert_indices_cpu[layer].tolist(), expert_weights[layer], expert_weights_buffer, ep_group, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 354aa9a87183..f85eb414b222 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -204,11 +204,18 @@ def _get_connector_metadata(self) -> KVConnectorMetadata: Returns: ConnectorMetadata: the connector metadata. """ - # Should only be called while set to valid metadata. assert self._connector_metadata is not None return self._connector_metadata + def has_connector_metadata(self) -> bool: + """Check whether the connector metadata is currently set. + + Returns: + bool: True if connector metadata exists, False otherwise. + """ + return self._connector_metadata is not None + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ Initialize with the KV caches. Useful for pre-registering the diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d7bbf02c8367..c9d08e9b78ed 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -171,16 +171,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. + # + # Note: Call the base class method to ensure metadata is also set on the + # MultiConnector instance itself; otherwise, `has_connector_metadata()` will + # always return False. 
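# Illustrative sketch (not from this patch): why MultiConnector must also call
# the base-class bind/clear methods. The guards added in vllm/attention/layer.py
# return early when connector.has_connector_metadata() is False, so a connector
# that only forwards metadata to its children without setting its own would
# always be skipped. Class names and metadata values here are simplified.
class BaseConnector:
    def __init__(self) -> None:
        self._connector_metadata = None

    def bind_connector_metadata(self, metadata) -> None:
        self._connector_metadata = metadata

    def clear_connector_metadata(self) -> None:
        self._connector_metadata = None

    def has_connector_metadata(self) -> bool:
        return self._connector_metadata is not None


class MultiConnector(BaseConnector):
    def __init__(self, children: list[BaseConnector]) -> None:
        super().__init__()
        self._connectors = children

    def bind_connector_metadata(self, metadata: list) -> None:
        for c, m in zip(self._connectors, metadata):
            c.bind_connector_metadata(m)
        # Without this call, has_connector_metadata() on the MultiConnector
        # itself would always return False and the KV save/load paths would
        # be skipped.
        super().bind_connector_metadata(metadata)


multi = MultiConnector([BaseConnector(), BaseConnector()])
multi.bind_connector_metadata(["meta-a", "meta-b"])
assert multi.has_connector_metadata()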
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: assert isinstance(connector_metadata, MultiKVConnectorMetadata) if connector_metadata.extra_async_saves: self._extra_async_saves.update(connector_metadata.extra_async_saves) for c, cm in zip(self._connectors, connector_metadata.metadata): c.bind_connector_metadata(cm) + super().bind_connector_metadata(connector_metadata) def clear_connector_metadata(self) -> None: for c in self._connectors: c.clear_connector_metadata() + super().clear_connector_metadata() def shutdown(self): exception: Exception | None = None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a9b01e82562b..c78e6a32733c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1483,6 +1483,9 @@ def destroy_distributed_environment(): def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + # Ensure all objects are not freezed before cleanup + gc.unfreeze() + destroy_model_parallel() destroy_distributed_environment() if shutdown_ray: diff --git a/vllm/entrypoints/dynamic_lora.py b/vllm/entrypoints/dynamic_lora.py new file mode 100644 index 000000000000..cc0f437e5c77 --- /dev/null +++ b/vllm/entrypoints/dynamic_lora.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import models, validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + LoadLoRAAdapterRequest, + UnloadLoRAAdapterRequest, +) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def register_dynamic_lora_routes(router: APIRouter): + @sagemaker_standards.register_load_adapter_handler( + request_shape={ + "lora_name": "body.name", + "lora_path": "body.src", + }, + ) + @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) + async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): + handler: OpenAIServingModels = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + + return Response(status_code=200, content=response) + + @sagemaker_standards.register_unload_adapter_handler( + request_shape={ + "lora_name": "path_params.adapter_name", + } + ) + @router.post( + "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] + ) + async def unload_lora_adapter( + request: UnloadLoRAAdapterRequest, raw_request: Request + ): + handler: OpenAIServingModels = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + + return Response(status_code=200, content=response) + + return router diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c8c8d5c034d5..fbb2d32a229d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - 
import asyncio -import gc import hashlib import importlib import inspect @@ -21,6 +19,7 @@ from http import HTTPStatus from typing import Annotated, Any, Literal +import model_hosting_container_standards.sagemaker as sagemaker_standards import prometheus_client import pydantic import regex as re @@ -67,7 +66,6 @@ ErrorInfo, ErrorResponse, IOProcessorResponse, - LoadLoRAAdapterRequest, PoolingBytesResponse, PoolingRequest, PoolingResponse, @@ -84,7 +82,6 @@ TranscriptionResponse, TranslationRequest, TranslationResponse, - UnloadLoRAAdapterRequest, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_classification import ServingClassification @@ -118,6 +115,7 @@ from vllm.tasks import POOLING_TASKS from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.system_utils import decorate_logs, set_ulimit from vllm.v1.engine.exceptions import EngineDeadError @@ -153,8 +151,7 @@ async def _force_log(): # Mark the startup heap as static so that it's ignored by GC. # Reduces pause times of oldest generation collections. - gc.collect() - gc.freeze() + freeze_gc_heap() try: yield finally: @@ -389,13 +386,6 @@ async def get_server_load_metrics(request: Request): return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) -@router.get("/ping", response_class=Response) -@router.post("/ping", response_class=Response) -async def ping(raw_request: Request) -> Response: - """Ping check. Endpoint required for SageMaker""" - return await health(raw_request) - - @router.post( "/tokenize", dependencies=[Depends(validate_json_request)], @@ -1238,47 +1228,6 @@ async def is_scaling_elastic_ep(raw_request: Request): ] -@router.post( - "/invocations", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def invocations(raw_request: Request): - """For SageMaker, routes requests based on the request type.""" - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}" - ) from e - - valid_endpoints = [ - (validator, endpoint) - for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS - if get_handler(raw_request) is not None - ] - - for request_validator, endpoint in valid_endpoints: - try: - request = request_validator.validate_python(body) - except pydantic.ValidationError: - continue - - return await endpoint(request, raw_request) - - type_names = [ - t.__name__ if isinstance(t := validator._type, type) else str(t) - for validator, _ in valid_endpoints - ] - msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" - res = base(raw_request).create_error_response(message=msg) - return JSONResponse(content=res.model_dump(), status_code=res.error.code) - - if envs.VLLM_TORCH_PROFILER_DIR: logger.warning_once( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -1306,39 +1255,6 @@ async def stop_profile(raw_request: Request): return Response(status_code=200) -if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - logger.warning( - "LoRA dynamic loading & unloading is enabled in the API server. 
" - "This should ONLY be used for local development!" - ) - - @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) - async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): - handler = models(raw_request) - response = await handler.load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse( - content=response.model_dump(), status_code=response.error.code - ) - - return Response(status_code=200, content=response) - - @router.post( - "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] - ) - async def unload_lora_adapter( - request: UnloadLoRAAdapterRequest, raw_request: Request - ): - handler = models(raw_request) - response = await handler.unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse( - content=response.model_dump(), status_code=response.error.code - ) - - return Response(status_code=200, content=response) - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None @@ -1608,6 +1524,20 @@ def build_app(args: Namespace) -> FastAPI: ) else: app = FastAPI(lifespan=lifespan) + + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!" + ) + from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes + + register_dynamic_lora_routes(router) + + from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes + + register_sagemaker_routes(router) + app.include_router(router) app.root_path = args.root_path @@ -1698,6 +1628,8 @@ async def log_response(request: Request, call_next): f"Invalid middleware {middleware}. Must be a function or a class." ) + app = sagemaker_standards.bootstrap(app) + return app diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8ce4ff574699..30b8499b08d5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1375,6 +1375,8 @@ def _parse_tool_calls_from_content( for tool_call in tool_call_info.tool_calls ) content = tool_call_info.content + if content and content.strip() == "": + content = None else: # No tool calls. 
return None, content diff --git a/vllm/entrypoints/sagemaker/__init__.py b/vllm/entrypoints/sagemaker/__init__.py new file mode 100644 index 000000000000..c1767137e4ea --- /dev/null +++ b/vllm/entrypoints/sagemaker/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""SageMaker-specific integration for vLLM.""" diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py new file mode 100644 index 000000000000..498b7294f0d8 --- /dev/null +++ b/vllm/entrypoints/sagemaker/routes.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from http import HTTPStatus + +import model_hosting_container_standards.sagemaker as sagemaker_standards +import pydantic +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import ( + INVOCATION_VALIDATORS, + base, + health, + validate_json_request, +) +from vllm.entrypoints.openai.protocol import ErrorResponse + + +def register_sagemaker_routes(router: APIRouter): + @router.post("/ping", response_class=Response) + @router.get("/ping", response_class=Response) + @sagemaker_standards.register_ping_handler + async def ping(raw_request: Request) -> Response: + """Ping check. Endpoint required for SageMaker""" + return await health(raw_request) + + @router.post( + "/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, + ) + @sagemaker_standards.register_invocation_handler + @sagemaker_standards.stateful_session_manager() + @sagemaker_standards.inject_adapter_id(adapter_path="model") + async def invocations(raw_request: Request): + """For SageMaker, routes requests based on the request type.""" + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + + valid_endpoints = [ + (validator, endpoint) + for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None + ] + + for request_validator, endpoint in valid_endpoints: + try: + request = request_validator.validate_python(body) + except pydantic.ValidationError: + continue + + return await endpoint(request, raw_request) + + type_names = [ + t.__name__ if isinstance(t := validator._type, type) else str(t) + for validator, _ in valid_endpoints + ] + msg = f"Cannot find suitable handler for request. 
Expected one of: {type_names}" + res = base(raw_request).create_error_response(message=msg) + return JSONResponse(content=res.model_dump(), status_code=res.error.code) + + return router diff --git a/vllm/envs.py b/vllm/envs.py index 078e5c38f0f4..52178e5f5250 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -109,7 +109,7 @@ VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False - VLLM_ROCM_USE_TRITON_ROPE: bool = False + VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True @@ -147,6 +147,7 @@ VLLM_TPU_MOST_MODEL_LEN: int | None = None VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = True + VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", @@ -222,7 +223,7 @@ VLLM_GC_DEBUG: str = "" VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" - VLLM_FLATTEN_LOGPROBS: bool = False + VLLM_FLAT_LOGPROBS: bool = False def get_default_cache_root(): @@ -926,8 +927,8 @@ def get_vllm_port() -> int | None: ), # Whether to use aiter rope. # By default is disabled. - "VLLM_ROCM_USE_TRITON_ROPE": lambda: ( - os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1") + "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1") ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. @@ -1116,6 +1117,10 @@ def get_vllm_port() -> int | None: ), # Allow use of DeepGemm kernels for fused moe ops. "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))), + # Allow use of DeepGemm specifically for MoE fused ops (overrides only MoE). + "VLLM_MOE_USE_DEEP_GEMM": lambda: bool( + int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1")) + ), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1")) @@ -1476,11 +1481,11 @@ def get_vllm_port() -> int | None: "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices( "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"] ), - # Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than + # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than # the original list[dict[int, Logprob]] approach. # After enabled, PromptLogprobs and SampleLogprobs would populated as - # FlattenLogprobs. - "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))), + # FlatLogprobs. 
+ "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))), } # --8<-- [end:env-vars-definition] @@ -1569,6 +1574,7 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", + "VLLM_MOE_USE_DEEP_GEMM", "VLLM_USE_DEEP_GEMM_E8M0", "VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FLASHINFER_MOE_FP16", @@ -1589,7 +1595,7 @@ def compute_hash() -> str: "VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MHA", "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", - "VLLM_ROCM_USE_TRITON_ROPE", + "VLLM_ROCM_USE_AITER_TRITON_ROPE", "VLLM_ROCM_USE_AITER_FP8BMM", "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ROCM_USE_AITER_TRITON_GEMM", diff --git a/vllm/logprobs.py b/vllm/logprobs.py index bf66e5f75c79..a34398db2c96 100644 --- a/vllm/logprobs.py +++ b/vllm/logprobs.py @@ -30,16 +30,16 @@ class Logprob: @dataclass -class FlattenLogprobs(MutableSequence[LogprobsOnePosition]): +class FlatLogprobs(MutableSequence[LogprobsOnePosition]): """ - Flatten logprobs of a request into multiple primitive type lists. + Flat logprobs of a request into multiple primitive type lists. Compared to list[dict[int, Logprob]], this data structure reduced GC overhead significantly. As it flattened logprob information for all positions and ranks in to multiple primitive type lists (i.e. logprobs, token_ids, ranks per token_ids, decoded_tokens). So regardless of the sequence length and top_logprobs setup, - FlattenLogprobs would only introduce a constant amount of objects. + FlatLogprobs would only introduce a constant amount of objects. As each position might contains different amount of ranks, start_indices_per_position would be used to access the logprob ranges @@ -107,7 +107,7 @@ def __len__(self) -> int: def __getitem__(self, position: int) -> LogprobsOnePosition: ... @overload - def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ... + def __getitem__(self, s: slice, /) -> "FlatLogprobs": ... def __getitem__(self, index: int | slice): """Extracts logprobs of a given position or slice""" @@ -123,7 +123,7 @@ def __getitem__(self, index: int | slice): elif isinstance(index, slice): min_index = self.start_indices[index][0] max_index = self.end_indices[index][-1] - return FlattenLogprobs( + return FlatLogprobs( # Shift updated start_indices and end_indices to # be 0-indexed start_indices=[i - min_index for i in self.start_indices[index]], @@ -137,13 +137,13 @@ def __getitem__(self, index: int | slice): raise TypeError(f"Invalid index type: {type(index)}") def __setitem__(self, item, value) -> None: - raise TypeError("Cannot set logprobs in FlattenLogprobs") + raise TypeError("Cannot set logprobs in FlatLogprobs") def __delitem__(self, item) -> None: - raise TypeError("Cannot delete logprobs from FlattenLogprobs") + raise TypeError("Cannot delete logprobs from FlatLogprobs") def insert(self, item) -> None: - raise TypeError("Cannot insert logprobs to FlattenLogprobs") + raise TypeError("Cannot insert logprobs to FlatLogprobs") def __iter__(self) -> Iterator[LogprobsOnePosition]: """ @@ -156,14 +156,14 @@ def __iter__(self) -> Iterator[LogprobsOnePosition]: # {token_id -> logprob} per each sequence group. None if the corresponding # sequence group doesn't require prompt logprob. -PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None] +PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None] # {token_id -> logprob} for each sequence group. 
-SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition] +SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition] def create_prompt_logprobs() -> PromptLogprobs: """Creates a container to store prompt logprobs for a request""" - logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] + logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else [] # NOTE: logprob of first prompt token is None. logprobs.append(None) return logprobs @@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs: def create_sample_logprobs() -> SampleLogprobs: """Creates a container to store decode logprobs for a request""" - return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] + return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else [] def append_logprobs_for_next_position( @@ -191,7 +191,7 @@ def append_logprobs_for_next_position( topk_ranks = range(1, num_logprobs + 1) ranks = itertools.chain((rank,), topk_ranks) - if isinstance(request_logprobs, FlattenLogprobs): + if isinstance(request_logprobs, FlatLogprobs): request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens) else: request_logprobs.append( diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 095ec966ea7e..b8a97e92ab79 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant( tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert num_parallel_tokens=16, group_size: int = 128, + use_ue8m0: bool | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). 
The first half of the last dimension is @@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant( device=y.device, ) - use_ue8m0 = is_deep_gemm_e8m0_used() + use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used() cuda_arch = current_platform.get_device_capability( device_id=y.device.index diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index cbc3caafcf2f..a7bd64b1c65e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from enum import IntEnum from typing import Optional, Union import torch @@ -91,6 +92,26 @@ def _quant_flags_to_group_shape( return a_shape, w_shape +# The type of method in top-K routing +# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h +class RoutingMethodType(IntEnum): + # Default: Softmax -> TopK + Default = (0,) + # Renormalize: TopK -> Softmax + Renormalize = (1,) + # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups + # -> Top8 experts from the Top4 groups + DeepSeekV3 = (2,) + # Llama4: Top1 -> Sigmoid + Llama4 = (3,) + # RenormalizeNaive: Softmax -> TopK -> Renormalize + RenormalizeNaive = (4,) + # TopK: TopK (no softmax) + TopK = (5,) + # Unspecified + Unspecified = 6.0 + + @dataclass class FusedMoEQuantDesc: """ diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 484b8aa9d107..86cdd25f2c87 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -215,7 +215,7 @@ def workspace_shapes( ) assert M_sum % block_m == 0 - workspace1 = (M_sum, max(N, K)) + workspace1 = (M_sum, N) workspace2 = (M_sum, max(N // 2, K)) output = (M, K) return (workspace1, workspace2, output) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index f21fe16c5108..51e06ac54f49 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -3,6 +3,7 @@ import torch +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8( w2_weight_scale_inv: torch.Tensor, global_num_experts: int, top_k: int, - num_expert_group: int, - topk_group: int, + num_expert_group: int | None, + topk_group: int | None, intermediate_size: int, expert_offset: int, local_num_experts: int, block_shape: list[int], - routed_scaling: float = 1.0, + routing_method_type: int = RoutingMethodType.DeepSeekV3, + routed_scaling: float | None = 1.0, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + topk_group = topk_group if topk_group is not None else 0 assert top_k <= global_num_experts - assert top_k <= 8 - assert topk_group <= 4 - assert global_num_experts > num_expert_group - assert global_num_experts % num_expert_group == 0 + assert top_k <= 10 assert global_num_experts % 4 == 0 - assert 
top_k < (topk_group * global_num_experts / num_expert_group) assert block_shape == [128, 128] - # Routing kernel expects #experts <= #threads 256 - assert global_num_experts <= 256 + # Routing kernel expects #experts <= #threads 512 + assert global_num_experts <= 512 a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) # NOTE: scales of hidden states have to be transposed! @@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8( local_expert_offset=expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling, - tile_tokens_dim=calculate_tile_tokens_dim( - x.shape[0], top_k, global_num_experts - ), - routing_method_type=2, # DeepSeek-styled routing method + tile_tokens_dim=None, + routing_method_type=routing_method_type, use_shuffled_weight=False, ) @@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake( expert_offset: int, local_num_experts: int, block_shape: list[int], + routing_method_type: int, routed_scaling: float = 1.0, ) -> torch.Tensor: return torch.empty_like(x) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7ad3ce1397b3..2e042d85fcfc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -14,6 +14,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -55,8 +56,6 @@ from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer -from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled - logger = init_logger(__name__) @@ -1089,11 +1088,11 @@ def vllm_topk_softmax( return topk_weights, topk_indices -def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: - if is_rocm_aiter_moe_enabled(): - from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax - - return rocm_aiter_topk_softmax +def dispatch_topk_func( + use_rocm_aiter: bool = False, +) -> Callable[..., tuple[torch.Tensor, ...]]: + if use_rocm_aiter: + return rocm_aiter_ops.topk_softmax return vllm_topk_softmax @@ -1121,7 +1120,7 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - topk_func = dispatch_topk_func() + topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) topk_weights, topk_ids = topk_func( topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2b8280f941e3..39547cc83c7b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,6 +13,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import ( @@ -30,6 +31,7 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, biased_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton @@ -41,8 +43,6 @@ ) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, - 
is_rocm_aiter_fusion_shared_expert_enabled, - is_rocm_aiter_moe_enabled, ) from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.quantization.base_config import ( @@ -92,13 +92,11 @@ def _eplb_map_to_physical_and_record( return topk_ids eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record +from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + rocm_aiter_grouped_topk, +) -if is_rocm_aiter_moe_enabled(): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk_aiter, - ) -else: - from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas else: @@ -463,7 +461,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -620,13 +619,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Padding the weight for better performance on ROCm layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - # Lazy import to avoid importing triton. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - shuffle_weights, - ) if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -1002,6 +997,7 @@ def determine_expert_map( global_num_experts: int, expert_placement_strategy: ExpertPlacementStrategy = "linear", num_fused_shared_experts: int = 0, + return_expert_mask: bool = False, ) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: """ Calculates how many experts should be assigned to each rank for EP and @@ -1064,7 +1060,7 @@ def determine_expert_map( ) expert_mask = None - if is_rocm_aiter_moe_enabled(): + if return_expert_mask: expert_mask = torch.ones( (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 ) @@ -1218,6 +1214,7 @@ def __init__( zero_expert_type: str | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, + routing_method_type: int | None = None, ): super().__init__() @@ -1292,14 +1289,18 @@ def __init__( self.logical_replica_count: torch.Tensor | None = None # ROCm aiter shared experts fusion + self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + self.aiter_fmoe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + self.num_fused_shared_experts = ( n_shared_experts - if n_shared_experts is not None - and is_rocm_aiter_fusion_shared_expert_enabled() + if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled else 0 ) if ( - not is_rocm_aiter_fusion_shared_expert_enabled() + not self.aiter_fmoe_shared_expert_enabled and self.num_fused_shared_experts != 0 ): raise ValueError( @@ -1346,6 +1347,7 @@ def __init__( global_num_experts=self.global_num_experts, expert_placement_strategy=expert_placement_strategy, 
num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) @@ -1397,6 +1399,24 @@ def __init__( "Only softmax scoring function is supported for non-grouped topk." ) + # ToDo: Better logic to determine the routing method type + if routing_method_type is not None: + self.routing_method_type = routing_method_type + else: + if scoring_func == "sigmoid": + if self.use_grouped_topk: + self.routing_method_type = RoutingMethodType.DeepSeekV3 + elif self.top_k == 1: + self.routing_method_type = RoutingMethodType.Llama4 + elif self.scoring_func == "softmax": + self.routing_method_type = ( + RoutingMethodType.Renormalize + if not self.renormalize + else RoutingMethodType.RenormalizeNaive + ) + else: + self.routing_method_type = RoutingMethodType.TopK + self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, @@ -1570,13 +1590,16 @@ def update_expert_map(self): ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) self.register_buffer("expert_mask", expert_mask) - self._init_aiter_shared_experts_topK_buffer( - vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size - ) + if self.aiter_fmoe_shared_expert_enabled: + self._init_aiter_shared_experts_topK_buffer( + vllm_config=get_current_vllm_config(), + dp_size=get_dp_group().world_size, + ) def _load_per_tensor_weight_scale( self, @@ -1753,20 +1776,19 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def _init_aiter_shared_experts_topK_buffer( self, vllm_config: VllmConfig, dp_size: int ): - if is_rocm_aiter_fusion_shared_expert_enabled(): - if self.num_fused_shared_experts > 0: - init_aiter_topK_meta_data( - n_routed_experts=self.global_num_experts, - n_shared_experts=self.num_fused_shared_experts, - top_k=self.top_k, - tp_rank=self.ep_rank if self.use_ep else self.tp_rank, - tp_size=self.ep_size if self.use_ep else self.tp_size, - shared_experts_score=1.0, - max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens - * dp_size, - is_EP=self.use_ep, - ) - self.local_num_experts += self.num_fused_shared_experts + if self.num_fused_shared_experts > 0: + init_aiter_topK_meta_data( + n_routed_experts=self.global_num_experts, + n_shared_experts=self.num_fused_shared_experts, + top_k=self.top_k, + tp_rank=self.ep_rank if self.use_ep else self.tp_rank, + tp_size=self.ep_size if self.use_ep else self.tp_size, + shared_experts_score=1.0, + max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens + * dp_size, + is_EP=self.use_ep, + ) + self.local_num_experts += self.num_fused_shared_experts @overload def weight_loader( @@ -2118,14 +2140,6 @@ def set_eplb_state( self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.logical_replica_count = logical_replica_count[moe_layer_idx] - def get_sp_ctx(self): - ctx = get_forward_context() - return ( - ctx.dp_metadata.sp_local_sizes(self.sp_size) - if ctx.dp_metadata - else nullcontext() - ) - def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: self.quant_method.moe_quant_config = ( @@ -2216,15 +2230,16 @@ def select_experts( elif use_grouped_topk: assert topk_group is not None assert num_expert_group is not 
None - if is_rocm_aiter_moe_enabled(): - if not is_rocm_aiter_fusion_shared_expert_enabled(): + if rocm_aiter_ops.is_fused_moe_enabled(): + if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): assert num_fused_shared_experts == 0 grouped_topk_impl = partial( - grouped_topk_aiter, + rocm_aiter_grouped_topk, num_fused_shared_experts=num_fused_shared_experts, ) else: grouped_topk_impl = grouped_topk + topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states, gating_output=router_logits, @@ -2340,35 +2355,16 @@ def forward_native( mode="constant", value=0.0, ) - do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( - self.quant_method, FusedMoEModularMethod - ) - sp_ctx = self.get_sp_ctx() - with sp_ctx: - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits, self.is_sequence_parallel - ) - - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: - # Reinitialize the context manager - # as it was reset in the forward_impl. - sp_ctx = self.get_sp_ctx() - with sp_ctx: - if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states, self.is_sequence_parallel) - - if ( - not self.is_sequence_parallel - and not self.use_dp_chunking - and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1) - ): - states = self.maybe_all_reduce_tensor_model_parallel(states) - return states + def reduce_output(states: torch.Tensor) -> torch.Tensor: + if ( + not self.is_sequence_parallel + and not self.use_dp_chunking + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + return states if self.shared_experts is None: if current_platform.is_tpu(): @@ -2400,7 +2396,7 @@ def reduce_output( hidden_states, router_logits, self.layer_name ) return ( - reduce_output(shared_output, do_combine=False)[..., :og_hidden_states], + reduce_output(shared_output)[..., :og_hidden_states], reduce_output(fused_output)[..., :og_hidden_states], ) @@ -2460,28 +2456,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) - # If there are shared experts but we are not using a modular kernel, - # the shared experts must be called here - if has_separate_shared_experts: - assert self.shared_experts is not None - - if self.shared_experts_stream is not None: - # For chunked, we start the shared experts stream here - # (Note that no concurrency with the router/gate) - self.shared_experts_stream.wait_stream(current_stream()) - - with torch.cuda.stream(self.shared_experts_stream): - # Note that staged_hidden_states clone() is necessary - # here to avoid conflict with the main stream - shared_output = self.shared_experts( - staged_hidden_states.clone() - ) - else: - shared_output = self.shared_experts(staged_hidden_states) - - else: - shared_output = None - # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, @@ -2492,7 +2466,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map - if not is_rocm_aiter_moe_enabled() + if not self.rocm_aiter_fmoe_enabled else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, @@ -2510,11 +2484,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): if has_separate_shared_experts: assert not isinstance(final_hidden_states, tuple) assert self.shared_experts is not None - - # Here we finish the shared experts stream - if self.shared_experts_stream is not None: - current_stream().wait_stream(self.shared_experts_stream) - + shared_output = self.shared_experts(staged_hidden_states) final_hidden_states = ( shared_output, final_hidden_states, @@ -2613,25 +2583,50 @@ def forward_impl( hidden_states, router_logits, has_separate_shared_experts ) + do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( + self.quant_method, FusedMoEModularMethod + ) + # If there are shared experts but we are not using a modular kernel, the # shared experts must be called here if has_separate_shared_experts: assert self.shared_experts is not None if self.shared_experts_stream is not None: + # Clone BEFORE switching streams to avoid race condition + # where routed_expert kernel may mutate hidden_states. + hidden_states_clone = hidden_states.clone() + self.shared_experts_stream.wait_stream(current_stream()) + # Run shared experts in parallel on a separate stream with torch.cuda.stream(self.shared_experts_stream): - # Note that hidden_states clone() is necessary here to avoid - # conflict with the main stream - shared_output = self.shared_experts(hidden_states.clone()) + shared_output = self.shared_experts(hidden_states_clone) + + # Record that the clone will be used by shared_experts_stream + # to avoid gc issue from deallocation of hidden_states_clone + # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501 + # NOTE: we dont need shared_output.record_stream(current_stream()) + # because we synch the streams before using shared_output. + hidden_states_clone.record_stream(self.shared_experts_stream) + else: shared_output = self.shared_experts(hidden_states) else: shared_output = None - sp_ctx = self.get_sp_ctx() + ctx = get_forward_context() + sp_ctx = ( + ctx.dp_metadata.sp_local_sizes(self.sp_size) + if ctx.dp_metadata + else nullcontext() + ) with sp_ctx: + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits, self.is_sequence_parallel + ) + # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, @@ -2642,7 +2637,7 @@ def forward_impl( use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map - if not is_rocm_aiter_moe_enabled() + if not self.rocm_aiter_fmoe_enabled else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, @@ -2670,7 +2665,25 @@ def forward_impl( shared_output, final_hidden_states, ) - return final_hidden_states + elif self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, tuple) + final_hidden_states, zero_expert_result = final_hidden_states + + def combine_output(states: torch.Tensor) -> torch.Tensor: + if do_naive_dispatch_combine: + states = get_ep_group().combine(states, self.is_sequence_parallel) + return states + + if self.shared_experts is not None: + return ( + final_hidden_states[0], + combine_output(final_hidden_states[1]), + ) + elif self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, torch.Tensor) + return (combine_output(final_hidden_states), zero_expert_result) + else: + return combine_output(final_hidden_states) @classmethod def make_expert_params_mapping( diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index e18514ad43f6..8f05828d74f5 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,17 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import IntEnum -from functools import cache, lru_cache +from functools import lru_cache import torch -from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, ) -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op class QuantMethod(IntEnum): @@ -37,27 +35,6 @@ class ActivationMethod(IntEnum): GELU = 1 -@cache -def is_rocm_aiter_moe_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_MOE - and envs.VLLM_ROCM_USE_AITER - ) - - -@cache -def use_mxfp4_aiter_moe() -> bool: - return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER - - -@cache -def is_rocm_aiter_fusion_shared_expert_enabled() -> bool: - return ( - envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled() - ) - - aiter_topK_meta_data = None @@ -114,250 +91,6 @@ def init_aiter_topK_meta_data( aiter_topK_meta_data = (total_topk_weights, total_topk_ids) -def rocm_aiter_asm_moe_tkw1_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: torch.Tensor | None = None, - fc2_scale: torch.Tensor | None = None, - fc1_smooth_scale: torch.Tensor | None = None, - fc2_smooth_scale: torch.Tensor | None = None, - a16: bool = False, - per_tensor_quant_scale: torch.Tensor | None = None, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, -) -> torch.Tensor: - from aiter import ActivationType - from aiter.fused_moe_bf16_asm import asm_moe_tkw1 - - activation = ActivationType(activation_method) - - return asm_moe_tkw1( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - fc1_scale=fc1_scale, - 
fc2_scale=fc2_scale, - fc1_smooth_scale=fc1_smooth_scale, - fc2_smooth_scale=fc2_smooth_scale, - a16=a16, - per_tensor_quant_scale=per_tensor_quant_scale, - expert_mask=expert_mask, - activation=activation, - ) - - -def rocm_aiter_asm_moe_tkw1_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: torch.Tensor | None = None, - fc2_scale: torch.Tensor | None = None, - fc1_smooth_scale: torch.Tensor | None = None, - fc2_smooth_scale: torch.Tensor | None = None, - a16: bool = False, - per_tensor_quant_scale: torch.Tensor | None = None, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def rocm_aiter_topk_softmax_impl( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> None: - from aiter import topk_softmax - - topk_softmax( - topk_weights, topk_indices, token_expert_indices, gating_output, renormalize - ) - - -def rocm_aiter_topk_softmax_fake( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> None: - pass - - -def rocm_aiter_biased_grouped_topk_impl( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - from aiter import biased_grouped_topk - - biased_grouped_topk( - gating_output, - correction_bias, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - need_renorm, - routed_scaling_factor, - ) - - -def rocm_aiter_biased_grouped_topk_fake( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - pass - - -def rocm_aiter_grouped_topk_impl( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - from aiter import grouped_topk - - grouped_topk( - gating_output, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - need_renorm, - scoring_func, - routed_scaling_factor, - ) - - -def rocm_aiter_grouped_topk_fake( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - pass - - -def rocm_aiter_fused_moe_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, - quant_method: int = QuantMethod.NO.value, - doweight_stage1: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, -) -> torch.Tensor: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - - 
activation = ActivationType(activation_method) - quant_type = QuantType(quant_method) - - return fused_moe( - hidden_states, - w1, - w2, - topk_weight, - topk_ids, - expert_mask, - activation, - quant_type, - doweight_stage1, - w1_scale, - w2_scale, - a1_scale, - a2_scale, - ) - - -def rocm_aiter_fused_moe_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, - quant_method: int = QuantMethod.NO.value, - doweight_stage1: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_asm_moe_tkw1", - op_func=rocm_aiter_asm_moe_tkw1_impl, - fake_impl=rocm_aiter_asm_moe_tkw1_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_fused_moe", - op_func=rocm_aiter_fused_moe_impl, - fake_impl=rocm_aiter_fused_moe_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_topk_softmax", - op_func=rocm_aiter_topk_softmax_impl, - mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], - fake_impl=rocm_aiter_topk_softmax_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_biased_grouped_topk", - op_func=rocm_aiter_biased_grouped_topk_impl, - mutates_args=["topk_weights", "topk_ids"], - fake_impl=rocm_aiter_biased_grouped_topk_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_grouped_topk", - op_func=rocm_aiter_grouped_topk_impl, - mutates_args=["topk_weights", "topk_ids"], - fake_impl=rocm_aiter_grouped_topk_fake, - ) - - def rocm_aiter_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk( ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] device = hidden_states.device - if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + if ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + and num_fused_shared_experts > 0 + ): assert aiter_topK_meta_data is not None, ( "AITER topK meta data is not initialized. 
" "Please ensure that init_aiter_topK_meta_data " @@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk( topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) if e_score_correction_bias is not None: - torch.ops.vllm.rocm_aiter_biased_grouped_topk( + rocm_aiter_ops.biased_grouped_topk( gating_output, e_score_correction_bias.to(gating_output.dtype), topk_weights, @@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk( ) else: assert scoring_func == "softmax" or scoring_func == "sigmoid" - torch.ops.vllm.rocm_aiter_grouped_topk( + rocm_aiter_ops.grouped_topk( gating_output, topk_weights, topk_ids, @@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk( routed_scaling_factor=routed_scaling_factor, ) - if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + if ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + and num_fused_shared_experts > 0 + ): return total_topk_weights, total_topk_ids return topk_weights, topk_ids @@ -464,7 +203,7 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) - return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + return rocm_aiter_ops.asm_moe_tkw1( hidden_states, w1, w2, @@ -482,7 +221,9 @@ def rocm_aiter_fused_experts( else: quant_method = QuantMethod.NO.value - + # quark moe for mxfp4 w_dtype + if quant_config.use_mxfp4_w4a16: + quant_method = QuantMethod.BLOCK_1X32.value # w8a8 block-scaled if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: assert not apply_router_weight_on_input, ( @@ -507,7 +248,7 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) - return torch.ops.vllm.rocm_aiter_fused_moe( + return rocm_aiter_ops.fused_moe( hidden_states, w1, w2, @@ -522,39 +263,3 @@ def rocm_aiter_fused_experts( a2_scale=quant_config.a2_scale, doweight_stage1=apply_router_weight_on_input, ) - - -def rocm_aiter_topk_softmax( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> tuple[torch.Tensor, ...]: - torch.ops.vllm.rocm_aiter_topk_softmax( - topk_weights, topk_indices, token_expert_indices, gating_output, renormalize - ) - return topk_weights, topk_indices - - -def shuffle_weights( - *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) -) -> tuple[torch.Tensor, ...]: - """ - Applies shuffle_weight function from AITER to each - input tensor and returns them. - - Rearranges (shuffles) the input tensor/s - into a specified block layout for optimized computation. - - Args: - *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the block sizes used to divide - the tensors during shuffling. Default is (16, 16). - - Returns: - A Tuple of shuffled tensors. - """ - from aiter.ops.shuffle import shuffle_weight - - return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 6b4a0b8cf073..3d0c5636d6c0 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -28,13 +28,18 @@ def __init__( super().__init__(**kwargs) self._shared_experts = shared_experts - # Disable shared expert overlap if we are not using - # flashinfer + DP since there is nothing to be gained in this case. 
- # Disabling the overlap optimization also prevents the shared experts - # from being hidden from torch.compile. + # Disable shared expert overlap if we are using eplb, because of + # correctness issues, or if using flashinfer with DP, since there + # is nothing to be gained in this case. Disabling the overlap + # optimization also prevents the shared experts from being hidden + # from torch.compile. self.use_overlapped = ( use_overlapped - and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + and not ( + # TODO(wentao): find the root cause and remove this condition + self.enable_eplb + or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + ) and self._shared_experts is not None ) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a883ac81f41e..8cc374ac9155 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -6,18 +6,13 @@ import torch.nn as nn import torch.nn.functional as F -import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.batch_invariant import ( rms_norm_batch_invariant, vllm_is_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - - -def is_rocm_aiter_rmsnorm_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER def rms_norm( @@ -58,80 +53,34 @@ def fused_add_rms_norm( return x, residual -def rocm_aiter_rms_norm_impl( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +def poly_norm( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: - import aiter as rocm_aiter - - if x.dim() > 2: - x_original_shape = x.shape - x = x.reshape(-1, x_original_shape[-1]) - x = rocm_aiter.rms_norm(x, weight, variance_epsilon) - return x.reshape(x_original_shape) - - return rocm_aiter.rms_norm(x, weight, variance_epsilon) - + from vllm import _custom_ops as ops -def rocm_aiter_rmsnorm2d_fwd_with_add_impl( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - import aiter as rocm_aiter - - residual_out = torch.empty_like(residual) - output = torch.empty_like(x) - rocm_aiter.rmsnorm2d_fwd_with_add( - output, # output - x, # input - residual, # residual input - residual_out, # residual output + out = torch.empty_like(x) + ops.poly_norm( + out, + x, weight, + bias, variance_epsilon, ) - return output, residual_out - - -def rocm_aiter_rms_norm_fake( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float -) -> torch.Tensor: - return torch.empty_like(x) - - -def rocm_aiter_rmsnorm2d_fwd_with_add_fake( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(x), torch.empty_like(residual) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_rms_norm", - op_func=rocm_aiter_rms_norm_impl, - fake_impl=rocm_aiter_rms_norm_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_rmsnorm2d_fwd_with_add", - op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, - fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, - ) + return out -def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): - use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ +def dispatch_rocm_rmsnorm_func( 
+ with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False +): + use_aiter = use_aiter and dtype in [ torch.float16, torch.bfloat16, ] if use_aiter and with_fused_add: - return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + return rocm_aiter_ops.rms_norm2d_with_add if use_aiter: - return torch.ops.vllm.rocm_aiter_rms_norm + return rocm_aiter_ops.rms_norm # fall back to CUDA implementation if with_fused_add: @@ -169,11 +118,14 @@ def __init__( self.weight = nn.Parameter(self.weight) if current_platform.is_rocm(): + aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled() self.rocm_norm_func = dispatch_rocm_rmsnorm_func( - with_fused_add=False, dtype=weight_dtype + with_fused_add=False, + dtype=weight_dtype, + use_aiter=aiter_rmsnorm_enabled, ) self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( - with_fused_add=True, dtype=weight_dtype + with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled ) @staticmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d95d49eddfe3..59567f2ca13c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -12,6 +12,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -582,11 +583,8 @@ def __init__( # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # cutlass path self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( @@ -829,12 +827,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - shuffle_weights, - ) - # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -972,10 +966,18 @@ def select_gemm_impl( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, + allow_deep_gemm=( + envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + ), ) else: logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) - return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True) + return TritonOrDeepGemmExperts( + self.moe_quant_config, + allow_deep_gemm=( + envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + ), + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ee431c9148b8..6da136cbc8f6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -7,12 +7,12 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, - check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, @@ -61,7 +61,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) ) self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() if self.weight_block_size is not None: assert not self.is_static_input_scheme diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ce40645782e5..83d136600b77 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -12,6 +12,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( @@ -27,6 +28,7 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, + RoutingMethodType, fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe @@ -41,7 +43,6 @@ QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -56,7 +57,6 @@ ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, - check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, @@ -94,11 +94,9 @@ from vllm.platforms import current_platform from 
vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import ( - fp8_gemm_nt, get_col_major_tma_aligned_tensor, is_deep_gemm_e8m0_used, is_deep_gemm_supported, - should_use_deepgemm_for_fp8_linear, ) from vllm.utils.flashinfer import has_flashinfer_moe from vllm.utils.import_utils import has_deep_gemm @@ -160,7 +158,7 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: return Fp8MoeBackend.MARLIN # deepGEMM on supported platforms with block-quantized weights - if envs.VLLM_USE_DEEP_GEMM and block_quant: + if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant: if not has_deep_gemm(): logger.warning_once("DeepGEMM backend requested but not available.") elif is_deep_gemm_supported(): @@ -369,7 +367,7 @@ def __init__(self, quant_config: Fp8Config): if vllm_is_batch_invariant(): self.use_marlin = False - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() self.use_deep_gemm = is_deep_gemm_supported() self.weight_block_size = self.quant_config.weight_block_size @@ -553,83 +551,19 @@ def apply( # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. if vllm_is_batch_invariant(): - # Call is_deep_gemm_supported() ahead of time for torch.compile - # dynamo has trouble tracing through - if self.block_quant and should_use_deepgemm_for_fp8_linear( - torch.bfloat16, layer.weight, self.use_deep_gemm - ): - # use group quant consistent with block size across K - assert self.act_q_group_shape is not None - q_input, input_scale = QuantFP8( - False, - self.act_q_group_shape, - column_major_scales=True, - )(x) - - output_2d = torch.empty( - (q_input.shape[0], layer.weight.shape[0]), - dtype=torch.bfloat16, - device=q_input.device, - ) - fp8_gemm_nt( - (q_input, input_scale), - (layer.weight, layer.weight_scale), - output_2d, - ) - if bias is not None: - output_2d = output_2d + bias - return output_2d - - # Dequantize FP8 weights to BF16 - weight_fp8 = layer.weight.to(torch.bfloat16) - weight_scale = layer.weight_scale.to(torch.bfloat16) - - # Handle different quantization granularities if self.block_quant: - # Block-wise quantization: - # - Weight is NOT transposed, shape is [N, K] (output_size, input_size) - # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!) 
assert self.weight_block_size is not None - block_n, block_k = self.weight_block_size # Note: order is [N, K] - - N, K = weight_fp8.shape - - # determine expected number of blocks along N and K - num_blocks_n = (N + block_n - 1) // block_n - num_blocks_k = (K + block_k - 1) // block_k - - # scale layout may be [num_blocks_n, num_blocks_k] - # or [num_blocks_k, num_blocks_n] depending on backend - if weight_scale.dim() != 2: - raise RuntimeError( - f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}" - ) - - scale_rows, scale_cols = weight_scale.shape - if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n): - if num_blocks_n == num_blocks_k: - # ambiguous square case, warn and skip transpose - logger.warning( - "Batch-invariant FP8: square block-scale %dx%d; " - "skipping transpose to avoid misorientation.", - scale_rows, - scale_cols, - ) - else: - # clear KN -> transpose to NK - weight_scale = weight_scale.t() - - # Expand scale to match weight dimensions - # scale_expanded should have shape [N, K] - scale_expanded = weight_scale.repeat_interleave( - block_n, dim=0 - ).repeat_interleave(block_k, dim=1) - # Trim to exact weight size (in case of padding) - scale_expanded = scale_expanded[:N, :K] - weight_bf16 = weight_fp8 * scale_expanded + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) else: - # Per-tensor quantization: weight IS transposed to [K, N] - # scale should be scalar or [1] or per-output-channel [N] + # per-tensor/channel: dequant to BF16 and run GEMM + weight_fp8 = layer.weight.to(torch.bfloat16) + weight_scale = layer.weight_scale.to(torch.bfloat16) if weight_scale.numel() == 1: # Per-tensor: simple scalar multiplication weight_bf16 = weight_fp8 * weight_scale @@ -648,16 +582,7 @@ def apply( else: # Fallback weight_bf16 = weight_fp8 * weight_scale - - # For block quant, weight is [N, K], for per-tensor it's [K, N] - # F.linear expects weight to be [N, K], so: - if self.block_quant: - # Already in correct shape [N, K] - output = torch.nn.functional.linear(x, weight_bf16, bias) - else: - # Need to transpose back: [K, N] -> [N, K] - output = torch.nn.functional.linear(x, weight_bf16.t(), bias) - return output + return torch.nn.functional.linear(x, weight_bf16.t(), bias) if self.use_marlin: return apply_fp8_marlin_linear( @@ -869,12 +794,8 @@ def create_weights( def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - shuffle_weights, - ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # TODO (rob): refactor block quant into separate class. if self.block_quant: @@ -916,7 +837,7 @@ def process_weights_after_loading(self, layer: Module) -> None: ) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -962,7 +883,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) @@ -1042,7 +963,7 @@ def process_weights_after_loading(self, layer: Module) -> None: start += shard_size if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) @@ -1226,22 +1147,20 @@ def apply( assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) - assert scoring_func == "sigmoid", ( - f"Expected 'sigmoid' scoring func but got {scoring_func}" - ) + if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - assert ( - renormalize and use_grouped_topk and custom_routing_function is None - ) e_score_correction_bias = ( e_score_correction_bias.to(x.dtype) if e_score_correction_bias is not None else None ) + routing_method_type = layer.routing_method_type return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32), + routing_logits=router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits, routing_bias=e_score_correction_bias, x=x, w13_weight=layer.w13_weight, @@ -1256,6 +1175,7 @@ def apply( expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, + routing_method_type=routing_method_type, routed_scaling=routed_scaling_factor, ) else: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a19396a162bc..f5cd91469b78 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -4,54 +4,14 @@ import torch -import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig -def rocm_aiter_gemm_w8a8_impl( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - from aiter import gemm_a8w8_CK - - # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects - # a to be [M, K] - # b to be [N, K] - # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) - - -def rocm_aiter_gemm_w8a8_fake( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - m = A.shape[0] - n = B.shape[0] - Y = torch.empty(m, n, dtype=output_dtype, device=A.device) - return Y - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8", - op_func=rocm_aiter_gemm_w8a8_impl, - fake_impl=rocm_aiter_gemm_w8a8_fake, - ) - - class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -75,7 +35,7 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + "installed on ROCm.", ) # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled - if not 
(envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): + if not (rocm_aiter_ops.is_linear_enabled()): return ( False, "AiterScaledMMLinearKernel is disabled. " @@ -157,6 +117,4 @@ def apply_weights( # a to be [M, K] # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return torch.ops.vllm.rocm_aiter_gemm_w8a8( - x_q, w_q.t(), x_s, w_s, bias, out_dtype - ) + return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index eca6b0cb1d8e..30772c3665b0 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, @@ -21,10 +22,6 @@ ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - use_mxfp4_aiter_moe, -) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin, ) @@ -122,7 +119,7 @@ def __init__( if current_platform.is_rocm(): self.use_marlin = False - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() def create_weights( self, @@ -309,12 +306,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - shuffle_weights, - ) - # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -470,13 +463,15 @@ def __init__( "not implemented. Please open an issue." ) + self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() + self.emulate = not current_platform.supports_mx() or not ( - use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" + self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" ) if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " - f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, " + f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, " f"ocp_mx_scheme={self.ocp_mx_scheme}) " "does not support native MXFP4/MXFP6 " "computation. 
Simulated weight dequantization and activation " @@ -656,28 +651,18 @@ def apply( ) if not self.emulate: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - - aiter_acts = { - ActivationType.No.name.lower(): ActivationType.No, - ActivationType.Silu.name.lower(): ActivationType.Silu, - ActivationType.Gelu.name.lower(): ActivationType.Gelu, - } - assert activation in aiter_acts, ( - f"Aiter CK fp4 MoE doesn't support activation {activation}" - ) - out = fused_moe( + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, + ) + + out = rocm_aiter_fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights, - topk_ids, - quant_type=QuantType.per_1x32, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - activation=aiter_acts[activation], - doweight_stage1=False, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + quant_config=self.moe_quant_config, ) else: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index c25c522dea55..007e78e68d5c 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -31,6 +31,13 @@ logger = init_logger(__name__) +# TODO: move registration of custom op to aiter_ops.py +# `from vllm._aiter_ops import rocm_aiter_ops` +# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()` +# for envs checks which does not require @cache anymore. +# triton kernel is torch compile compatible. +# does not require direct registeration. +# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`. @cache def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: return ( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 50ea049c3d5a..e49d374f154d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum): def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): + from flashinfer import next_positive_power_of_2 + # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now. # TODO: Revert this to dynamic calculation once a new version of FlashInfer # with the necessary kernels is released. tile_tokens_dim = 8 - # from flashinfer import next_positive_power_of_2 - - # # Guess tokens per expert assuming perfect expert distribution first. - # num_tokens_per_expert = (num_tokens * top_k) // num_experts - # # And pad the number to the next power of 2. - # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # # Cap to 8-64 tokens per CTA tile as it's the range supported by the - # # kernel. - # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + # A factor considering tokens are not perfectly balanced among experts. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. 
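The restored tile sizing is plain integer arithmetic and can be checked in isolation. The sketch below mirrors it; FlashInfer supplies the real `next_positive_power_of_2`, so the bit-twiddling stand-in here is only for illustration:

```python
def next_positive_power_of_2(x: int) -> int:
    # Stand-in for the FlashInfer helper: smallest power of two >= max(x, 1).
    return 1 << (max(x, 1) - 1).bit_length()


def calc_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Tokens per expert under perfectly balanced routing, then padded by an
    # imbalance factor because real routing never balances perfectly.
    imbalance_factor = 1.3
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # Round up to a power of two and clamp to the kernel-supported 8..64 range.
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)


assert calc_tile_tokens_dim(4096, top_k=8, num_experts=256) == 64
assert calc_tile_tokens_dim(16, top_k=1, num_experts=128) == 8
```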
+ tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-max_tile_tokens_dim tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) return tile_tokens_dim diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7fecda2166ef..c63196b89357 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -68,78 +69,6 @@ def cutlass_scaled_mm( ) -def rocm_aiter_gemm_w8a8_blockscale_impl( - input_2d: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - def is_aiter_triton_kernel_tuned(n, k): - return (n, k) in [ - (1024, 8192), - (2112, 7168), - (3072, 1536), - (32768, 8192), - (4096, 7168), - (4608, 7168), - (512, 7168), - (7168, 2048), - (7168, 256), - (8192, 1024), - (8192, 32768), - ] - - n, k = weight.shape - if input_scale is not None: - q_input = input_2d - elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k): - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale - - # MI350 case uses triton kernel - q_input, input_scale = per_token_group_quant_fp8( - input_2d, - group_size, - column_major_scales=False, - use_ue8m0=False, - ) - else: - # MI300 uses tuned AITER ASM/C++ kernel - import aiter as rocm_aiter - from aiter import gemm_a8w8_blockscale, get_hip_quant - - aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) - q_input, input_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8 - ) - - return gemm_a8w8_blockscale( - q_input, weight, input_scale, weight_scale, dtype=output_dtype - ) - - -def rocm_aiter_gemm_w8a8_blockscale_fake( - input_2d: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - m = input_2d.shape[0] - n = weight.shape[0] - return torch.empty(m, n, dtype=output_dtype, device=input_2d.device) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8_blockscale", - op_func=rocm_aiter_gemm_w8a8_blockscale_impl, - fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, - ) - - # TODO we should be able to change the type of block_size to GroupShape # after we resolve GroupShape compilation issue # https://github.com/vllm-project/vllm/issues/25270 @@ -385,13 +314,40 @@ def _run_aiter( input_scale: torch.Tensor | None = None, ) -> torch.Tensor: assert self.act_quant_group_shape == GroupShape(1, 128) - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( - input_2d, + + n, k = weight.shape + + use_triton = ( + not current_platform.is_fp8_fnuz() + and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) + ) + + if use_triton: + gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale + else: + gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale + + if input_scale is not None: + q_input = input_2d + # MI350 case uses triton kernel + elif 
use_triton: + q_input, input_scale = per_token_group_quant_fp8( + input_2d, + self.act_quant_group_shape.col, + column_major_scales=False, + use_ue8m0=False, + ) + # MI300 uses tuned AITER ASM/C++ kernel + else: + q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d) + + return gemm_a8w8_blockscale_op( + q_input, weight, input_scale, weight_scale, - self.act_quant_group_shape.col, - input_2d.dtype, + list(self.weight_group_shape), + output_dtype=input_2d.dtype, ) def _run_triton( @@ -971,15 +927,6 @@ def requant_weight_ue8m0_inplace( s_old.copy_(s_requant) -def check_aiter_fp8_linear_support() -> bool: - """AITER is only supported on ROCm for MI3XX""" - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - ) - - def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: """Pad the weight tensor. This is an optimization on ROCm platform, which can benefit from tensors located far enough from one another in memory""" diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 380431e86435..7fe902807a74 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -472,7 +472,7 @@ def apply( # Example: # When the number of token is 1, per-token scale is [[1]] # When per-tensor scale is [1] or (). - per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 + per_tensor_weights = weight_scale.numel() == 1 per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 # TODO(luka) do this dispatch during init (after ScaledMM refactor) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 91276320df4d..2ef54e75df44 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -4,13 +4,10 @@ import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_torch -from .rocm_aiter_rope_ops import ( - is_rocm_triton_rotary_embedding_enabled, - rocm_aiter_rotary_emb, -) @CustomOp.register("rotary_embedding") @@ -48,8 +45,8 @@ def __init__( cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.is_rocm_triton_rotary_embedding_enabled = ( - is_rocm_triton_rotary_embedding_enabled() + self.is_rocm_triton_rotary_embed_enabled = ( + rocm_aiter_ops.is_triton_rotary_embed_enabled() ) def _compute_inv_freq(self, base: float) -> torch.Tensor: @@ -169,9 +166,9 @@ def forward_hip( query: torch.Tensor, key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - if self.is_rocm_triton_rotary_embedding_enabled: + if self.is_rocm_triton_rotary_embed_enabled: self._match_cos_sin_cache_dtype(query) - rocm_aiter_rotary_emb( + rocm_aiter_ops.triton_rotary_embed( positions, query, key, diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index d9134f05fddf..e72834e473c1 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -146,6 +146,15 @@ def forward_native( key = key_rot return query, key + def forward_hip( + self, + positions: torch.Tensor, + query: 
torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return self.forward_native(positions, query, key, offsets) + def forward_cuda( self, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py deleted file mode 100644 index a01d14f7b3a1..000000000000 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.envs as envs -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - - -def is_rocm_triton_rotary_embedding_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_TRITON_ROPE - ) - - -def rocm_aiter_rotary_emb_with_key_forward_triton_impl( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - import aiter.ops.triton.rope as ops - - ops.rope_cached_thd_positions_2c_fwd_inplace( - query, - key, - cos, - sin, - positions, - rotate_style, - reuse_freqs_front_part=True, - nope_first=is_nope_first, - ) - - -def rocm_aiter_rotary_emb_with_key_forward_triton_fake( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - pass - - -if is_rocm_triton_rotary_embedding_enabled(): - direct_register_custom_op( - op_name="rocm_aiter_rotary_emb_with_key_forward_triton", - op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl, - mutates_args=["key", "query"], - fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake, - dispatch_key=current_platform.dispatch_key, - ) - - -def rocm_aiter_rotary_emb( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - cos_sin_cache: torch.Tensor, - head_size: int, - rotary_dim: int, - is_neox_style: bool, -): - num_tokens = positions.numel() - cos, sin = cos_sin_cache.chunk(2, dim=-1) - query_shape = query.shape - key_shape = key.shape - rotate_style = 0 if is_neox_style else 1 - - query = query.view(num_tokens, -1, head_size) - key = key.view(num_tokens, -1, head_size) - query_ = query[..., :rotary_dim] - key_ = key[..., :rotary_dim] - positions = positions.view(*query.shape[:1]) - torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton( - positions, - sin, - cos, - query_, - key_, - rotate_style, - False, - ) - query = query.view(query_shape) - key = key.view(key_shape) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 63eaf63cc3c4..38189e17f7d8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -33,6 +33,7 @@ from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton @@ -50,10 +51,6 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from 
vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_fusion_shared_expert_enabled, - is_rocm_aiter_moe_enabled, -) from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -294,10 +291,8 @@ def __init__( self.physical_expert_start + self.n_local_physical_experts ) - if ( - config.n_shared_experts is None - or is_rocm_aiter_fusion_shared_expert_enabled() - ): + self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled: self.shared_experts = None else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -330,14 +325,14 @@ def __init__( # we do scaling outside, set factor to 1.0 to avoid double mul # aiter applies routed_scaling_factor internally routed_scaling_factor=1.0 - if not is_rocm_aiter_moe_enabled() + if not self.is_rocm_aiter_moe_enabled else self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, n_shared_experts=config.n_shared_experts - if is_rocm_aiter_fusion_shared_expert_enabled() + if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() else None, ) @@ -371,7 +366,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. if hidden_states.dtype != torch.float16: - if not is_rocm_aiter_moe_enabled(): + if not self.is_rocm_aiter_moe_enabled: final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None @@ -1428,6 +1423,9 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rocm_aiter_moe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -1456,7 +1454,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: num_experts=self.config.n_routed_experts + ( self.config.n_shared_experts - if is_rocm_aiter_fusion_shared_expert_enabled() + if rocm_aiter_moe_shared_expert_enabled else 0 ), num_redundant_experts=self.num_redundant_experts, @@ -1472,9 +1470,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if spec_layer is not None: continue # skip spec decode layers for main model - is_fuse_shared_experts_layer = ( - is_rocm_aiter_fusion_shared_expert_enabled() - and ("mlp.shared_experts" in name) + is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and ( + "mlp.shared_experts" in name ) for param_name, weight_name, shard_id in stacked_params_mapping: diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 7c1eba103ae7..f287cff12086 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -1435,8 +1435,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, 
use_audio_in_video: bool = False, @@ -1569,7 +1567,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 121e84469c52..b9cd3545ec45 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -1622,8 +1622,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1754,7 +1752,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 2de1e4810952..ebf6934dddea 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -625,8 +625,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -758,7 +756,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index b634c7ec7d67..d6a8f86d998b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -995,8 +995,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1012,8 +1010,6 @@ def get_mrope_input_positions( image_grid_thw: Image grid dimensions (t, h, w) video_grid_thw: Video grid dimensions (t, h, w) second_per_grid_ts: Seconds per grid timestep for videos - context_len: Context length - seq_len: Sequence length audio_feature_lengths: Audio feature lengths for multimodal models use_audio_in_video: Whether to use audio in video for interleaving diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 5f8659a3064e..42f16ad9f3b3 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1630,8 +1630,6 @@ def 
get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -1759,6 +1757,5 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 13e5b2d5f157..6f95a59d36d2 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -600,8 +600,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -729,6 +727,5 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index b54f53931d71..b79bdf8595ca 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -456,7 +456,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - if not config.use_mla: + use_mha = ( + config.model_type == "deepseek" + or config.qk_nope_head_dim + config.qk_rope_head_dim == 0 + ) + if use_mha: stacked_params_mapping += [ (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 377b41a35578..12ae15699e7d 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -198,23 +198,18 @@ def get_num_image_tokens( if image_processor is None: image_processor = self.get_image_processor() - do_resize = True hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size - - if do_resize: - resized_height, resized_width = smart_resize( - height=image_height, - width=image_width, - factor=patch_size * merge_size, - min_pixels=image_processor.min_pixels, - max_pixels=image_processor.max_pixels, - ) - preprocessed_size = ImageSize(width=resized_width, height=resized_height) - else: - preprocessed_size = ImageSize(width=image_width, height=image_height) + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) grid_t = 1 grid_h = preprocessed_size.height // patch_size @@ -227,8 +222,19 @@ def get_num_image_tokens( def 
get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() - image_size = hf_config.vision_config.image_size - return ImageSize(height=image_size, width=image_size) + + # See `smart_resize` for the calculation of the image size. + merge_size = hf_config.vision_config.spatial_merge_size + patch_size = hf_config.vision_config.patch_size + factor = merge_size * patch_size + max_num_tokens = self.get_image_processor().max_pixels // (factor**2) + # Find factors of max_num_tokens close to its square root + # to create a dummy image with a reasonable aspect ratio. + h_patches = int(math.sqrt(max_num_tokens)) + while max_num_tokens % h_patches != 0: + h_patches -= 1 + w_patches = max_num_tokens // h_patches + return ImageSize(height=h_patches * factor, width=w_patches * factor) class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]): @@ -1179,8 +1185,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1293,7 +1297,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 7e970ebbe2bb..fac281d2caf4 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -927,8 +927,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1125,7 +1123,6 @@ def get_mrope_input_positions( mrope_position_delta = ( torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) ) - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index d337f1606943..48834ba699e4 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1118,8 +1118,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1232,7 +1230,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9206ac8f9d03..b3999e6c934e 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1240,8 +1240,6 @@ def 
get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1360,7 +1358,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index a7e6772bb708..d57b82cb0227 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -43,6 +43,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -171,6 +172,7 @@ def __init__( enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, ) self.gate = ReplicatedLinear( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 55bbad7a8b27..ddb8693c16e2 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -34,6 +34,7 @@ fused_recurrent_gated_delta_rule, ) from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm, ) @@ -173,6 +174,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -585,7 +587,7 @@ def _forward_core( self.conv1d.bias, self.activation, conv_state_indices=non_spec_state_indices_tensor[ - : attn_metadata.num_decodes + : attn_metadata.num_actual_tokens ], validate_data=True, ) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index f20e67902721..da489a812f55 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -1417,8 +1417,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 2d8f431bb8fa..fe0124ef3258 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1419,8 +1419,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | 
torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -1519,7 +1517,7 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions, mrope_position_delta def get_language_model(self) -> torch.nn.Module: diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 10abd8659536..476074542e6a 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -371,8 +371,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -390,7 +388,7 @@ def get_mrope_input_positions( video_grid_thw=video_grid_thw, ) - mrope_positions = mrope_positions[:, 0, context_len:seq_len] + mrope_positions = mrope_positions[:, 0] mrope_position_delta = mrope_position_delta[0].item() return mrope_positions, mrope_position_delta diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index bdcebd498ef0..e0c584df8760 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -148,6 +148,9 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: + if not (envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM): + return False + if not isinstance(module, FusedMoE): return False diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1abd6300036d..e6536a02a73d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention( alibi_slopes: torch.Tensor | None = None, sinks: torch.Tensor | None = None, ) -> bool: + from vllm._aiter_ops import rocm_aiter_ops + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER) + and not (rocm_aiter_ops.is_pa_attn_enabled()) and sinks is None ) @@ -202,12 +204,15 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: from importlib.util import find_spec + from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import _Backend - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): + if rocm_aiter_ops.is_mha_enabled(): + # Note: AITER FA is only supported for Qwen-VL models. 
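Backend selection here now routes every AITER capability check through the `rocm_aiter_ops` facade instead of reading `envs.VLLM_ROCM_USE_AITER_*` at each call site. The facade itself is not part of this excerpt, so the sketch below is only a guess at its shape: environment-gated predicates collected in one place, with arch checks such as `on_gfx9()` elided.

```python
import vllm.envs as envs
from vllm.platforms import current_platform


class rocm_aiter_ops:
    """Sketch of a centralized gate for ROCm AITER kernels (not the real module)."""

    @staticmethod
    def is_enabled() -> bool:
        # Base switch: ROCm platform plus the global AITER opt-in.
        return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER

    @classmethod
    def is_mha_enabled(cls) -> bool:
        return cls.is_enabled() and envs.VLLM_ROCM_USE_AITER_MHA

    @classmethod
    def is_linear_enabled(cls) -> bool:
        return cls.is_enabled() and envs.VLLM_ROCM_USE_AITER_LINEAR
```

Funnelling the checks through one import also lines up with the TODO left in quark_ocp_mx.py above, which plans to move the remaining custom-op registration into the same place.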
+ # TODO: Add support for other VL models in their model class. return _Backend.ROCM_AITER_FA if on_gfx9() and find_spec("flash_attn") is not None: @@ -228,19 +233,23 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: + from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import _Backend if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") - if use_mla: - from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( - is_aiter_mla_enabled, + + if not use_v1: + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." ) + if use_mla: if selected_backend is None: selected_backend = ( _Backend.ROCM_AITER_MLA - if is_aiter_mla_enabled() or block_size == 1 + if rocm_aiter_ops.is_mla_enabled() or block_size == 1 else _Backend.TRITON_MLA ) @@ -265,12 +274,12 @@ def get_attn_backend_cls( logger.info("Using FlexAttention backend.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() + rocm_aiter_ops.is_mha_enabled() ) or selected_backend == _Backend.ROCM_AITER_FA: logger.info("Using Aiter Flash Attention backend.") return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + rocm_aiter_ops.is_triton_unified_attn_enabled() ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: logger.info("Using Aiter Unified Attention backend.") return ( diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py index 4dd85ef26f34..160ac9ac263a 100644 --- a/vllm/utils/gc_utils.py +++ b/vllm/utils/gc_utils.py @@ -89,6 +89,21 @@ def handle(self, phase: str, info: dict[str, int]) -> None: ) +def freeze_gc_heap() -> None: + """ + Freeze all objects tracked by the garbage collector. It should be invoked + after server init / warmup, to reduce GC overhead from static objects + during serving time. + """ + # Ensure all static objects are pushed down to the oldest generation for + # freeze + gc.collect(0) + gc.collect(1) + gc.collect(2) + # Freeze all GC tracked objects + gc.freeze() + + def maybe_attach_gc_debug_callback() -> None: """ Attached a callback for GC debug when VLLM_GC_DEBUG is enabled. 
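The new `freeze_gc_heap()` helper relies only on standard-library GC controls, which are easy to exercise on their own. A minimal sketch of the same idea, under an illustrative name:

```python
import gc


def freeze_startup_heap() -> None:
    # Collect every generation so objects that survive startup settle into
    # the oldest generation...
    gc.collect(0)
    gc.collect(1)
    gc.collect(2)
    # ...then move all tracked objects into the permanent (frozen) set, which
    # later collections skip entirely.
    gc.freeze()


if __name__ == "__main__":
    static_tables = [object() for _ in range(10_000)]  # stand-in for engine state
    freeze_startup_heap()
    print(gc.get_freeze_count())  # objects now exempt from future collections
```

Per the docstring above, vLLM invokes this once after server init and warmup so that long-lived engine state stops contributing to GC pause times during serving.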
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 40ce12c4bd75..e38f7bcfa44e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -198,6 +198,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import ( AttentionBackend, AttentionLayer, @@ -270,28 +271,15 @@ class QueryLenSupport(Enum): flashinfer_available = False -def is_rocm_aiter_fp8bmm_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_FP8BMM - and envs.VLLM_ROCM_USE_AITER - ) - - -if is_rocm_aiter_fp8bmm_enabled(): - from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501 - ) - - def dynamic_per_batched_tensor_quant( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn - ): - DTYPE_MAX = torch.finfo(dtype).max - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) - scale = DTYPE_MAX / amax - x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) - return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() +def dynamic_per_batched_tensor_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +): + DTYPE_MAX = torch.finfo(dtype).max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) + scale = DTYPE_MAX / amax + x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() logger = init_logger(__name__) @@ -1109,6 +1097,7 @@ def __init__( self.kv_b_proj = kv_b_proj self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads + self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): @@ -1158,7 +1147,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1187,7 +1176,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_K.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True ) @@ -1196,7 +1185,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_V.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) else: @@ -1208,10 +1197,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm( + x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) # Convert from (B, N, V) to (B, N * V) @@ -1571,7 +1559,7 @@ def get_and_maybe_dequant_weights(layer: 
LinearBase): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1600,7 +1588,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_K.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True ) @@ -1609,7 +1597,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_V.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) else: @@ -1958,7 +1946,6 @@ def forward( # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) - # Pads the head_dim if necessary (for the underlying kernel) if self.q_pad_num_heads is not None: B, N, L = decode_q_pe.shape decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) @@ -1966,9 +1953,9 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) - decode_ql_nope = aiter_triton_fp8_bmm( + decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( decode_q_nope, self.W_K, self.W_K_scale, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 4ad7236eb1be..5757aeadba05 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,9 +6,8 @@ import torch -import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionLayer -from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import ( @@ -22,10 +21,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec -def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA - - class AiterMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: @@ -284,7 +279,7 @@ def _forward_decode( # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP max_seqlen_qo = 1 - aiter_mla_decode_fwd( + rocm_aiter_ops.mla_decode_fwd( q, kv_buffer, o, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c17b19b58c97..46dc1071b839 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -38,6 +38,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext logger = init_logger(__name__) @@ -259,49 +260,52 @@ def schedule(self) -> SchedulerOutput: continue # Schedule newly needed KV blocks for the request. - while True: - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens, - ) - - if new_blocks is not None: - # The request can be scheduled. - break - - # The request cannot be scheduled. - # Preempt the lowest-priority request. 
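Returning briefly to the MLA hunk above: `dynamic_per_batched_tensor_quant` is now defined unconditionally at module scope, and its math is simple enough to sanity-check standalone. The snippet below copies the helper and round-trips a tensor shaped like the `W_K` example in the comments (shapes and the printed error bound are illustrative only):

```python
import torch


def dynamic_per_batched_tensor_quant(x, dtype=torch.float8_e4m3fn):
    # Same arithmetic as the helper in mla/common.py: one scale per tensor,
    # derived from the absolute maximum, with FP8 saturation.
    dtype_max = torch.finfo(dtype).max
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
    scale = dtype_max / amax
    x_scl_sat = (x * scale).clamp(min=-dtype_max, max=dtype_max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


w = torch.randn(16, 512, 128, dtype=torch.bfloat16)   # W_K-like shape
w_q, w_scale = dynamic_per_batched_tensor_quant(w)
w_deq = w_q.to(torch.bfloat16) * w_scale               # dequant is one multiply
print((w_deq.float() - w.float()).abs().max())         # bounded by FP8 rounding
```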
- if self.policy == SchedulingPolicy.PRIORITY: - preempted_req = max( - self.running, - key=lambda r: (r.priority, r.arrival_time), + with record_function_or_nullcontext("schedule: allocate_slots"): + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, ) - self.running.remove(preempted_req) - if preempted_req in scheduled_running_reqs: - scheduled_running_reqs.remove(preempted_req) - token_budget += num_scheduled_tokens[preempted_req.request_id] - req_to_new_blocks.pop(preempted_req.request_id) - num_scheduled_tokens.pop(preempted_req.request_id) - req_index -= 1 - else: - preempted_req = self.running.pop() - self.kv_cache_manager.free(preempted_req) - self.encoder_cache_manager.free(preempted_req) - preempted_req.status = RequestStatus.PREEMPTED - preempted_req.num_computed_tokens = 0 - preempted_req.num_preemptions += 1 - if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, scheduled_timestamp - ) + if new_blocks is not None: + # The request can be scheduled. + break - self.waiting.prepend_request(preempted_req) - preempted_reqs.append(preempted_req) - if preempted_req == request: - # No more request to preempt. Cannot schedule this request. - break + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[ + preempted_req.request_id + ] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop(preempted_req.request_id) + req_index -= 1 + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. Cannot schedule this request. + break if new_blocks is None: # Cannot schedule this request. @@ -599,13 +603,14 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) - if self.running: - any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id + ) ) - ) # Construct the scheduler output. 
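The scheduler and engine hunks in this area wrap each phase in `record_function_or_nullcontext` from `vllm.v1.utils`. Its implementation is not shown in this diff; a plausible sketch of the pattern the name suggests, assuming it switches between a torch profiler range and a no-op based on some profiling flag, is:

```python
import contextlib

import torch

PROFILING_ENABLED = False  # stand-in for whichever switch vllm.v1.utils consults


def record_function_or_nullcontext(name: str):
    # Emit a named range that shows up in torch.profiler traces when profiling
    # is on; otherwise return a no-op context manager so the scheduler's hot
    # path only pays for a single branch.
    if PROFILING_ENABLED:
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()


# Usage mirrors the wrapped scheduler code:
with record_function_or_nullcontext("schedule: allocate_slots"):
    pass  # allocate_slots(...) and the preemption loop would run here
```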
new_reqs_data = [ @@ -614,13 +619,14 @@ def schedule(self) -> SchedulerOutput: ) for req in scheduled_new_reqs ] - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, - scheduled_resumed_reqs, - num_scheduled_tokens, - scheduled_spec_decode_tokens, - req_to_new_blocks, - ) + with record_function_or_nullcontext("schedule: make_cached_request_data"): + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) # Record the request ids that were scheduled in this step. self.prev_step_scheduled_req_ids.clear() @@ -649,8 +655,8 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta - - self._update_after_schedule(scheduler_output) + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) return scheduler_output def _update_after_schedule( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index fba018432e0a..ffb5232e770d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import gc import os import queue import signal @@ -27,7 +26,10 @@ from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.utils.gc_utils import maybe_attach_gc_debug_callback +from vllm.utils.gc_utils import ( + freeze_gc_heap, + maybe_attach_gc_debug_callback, +) from vllm.utils.hashing import get_hash_fn_by_name from vllm.utils.network_utils import make_zmq_socket from vllm.utils.system_utils import decorate_logs, set_process_title @@ -61,6 +63,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -196,6 +199,10 @@ def __init__( self.step if self.batch_queue is None else self.step_with_batch_queue ) + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. + freeze_gc_heap() + def _initialize_kv_caches( self, vllm_config: VllmConfig ) -> tuple[int, int, KVCacheConfig]: @@ -315,17 +322,21 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: # or finished and not yet removed from the batch. 
if not self.scheduler.has_requests(): return {}, False - scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output, non_block=True) - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) - with self.log_error_detail(scheduler_output): - model_output = future.result() - if model_output is None: - model_output = self.model_executor.sample_tokens(grammar_output) - - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("core step: schedule"): + scheduler_output = self.scheduler.schedule() + + with record_function_or_nullcontext("core step: execute_model"): + future = self.model_executor.execute_model(scheduler_output, non_block=True) + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with self.log_error_detail(scheduler_output): + model_output = future.result() + if model_output is None: + model_output = self.model_executor.sample_tokens(grammar_output) + + with record_function_or_nullcontext("core step: update_from_output"): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 @@ -363,32 +374,49 @@ def step_with_batch_queue( model_executed = False deferred_scheduler_output = None if self.scheduler.has_requests(): - scheduler_output = self.scheduler.schedule() - exec_future = self.model_executor.execute_model( - scheduler_output, non_block=True - ) + with record_function_or_nullcontext("core step_with_batch_queue: schedule"): + scheduler_output = self.scheduler.schedule() + with record_function_or_nullcontext( + "core step_with_batch_queue: execute_model" + ): + exec_future = self.model_executor.execute_model( + scheduler_output, non_block=True + ) model_executed = scheduler_output.total_num_scheduled_tokens > 0 if scheduler_output.pending_structured_output_tokens: - # We need to defer sampling until we have processed the model output - # from the prior step. - deferred_scheduler_output = scheduler_output - # Block-wait for execute to return (continues running async on the GPU). - with self.log_error_detail(scheduler_output): - exec_result = exec_future.result() - assert exec_result is None + with record_function_or_nullcontext( + "core step_with_batch_queue: pending_structured_output_tokens" + ): + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + # Block-wait for execute to return + # (continues running async on the GPU). + with self.log_error_detail(scheduler_output): + exec_result = exec_future.result() + assert exec_result is None else: - # We aren't waiting for any tokens, get any grammar output immediately. - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with record_function_or_nullcontext( + "core step_with_batch_queue: get_grammar_bitmask" + ): + # We aren't waiting for any tokens, get any grammar + # output immediately. + grammar_output = self.scheduler.get_grammar_bitmask( + scheduler_output + ) # Block-wait for execute to return (continues running async on the GPU). with self.log_error_detail(scheduler_output): exec_result = exec_future.result() if exec_result is None: - # Call sample tokens. 
- future = self.model_executor.sample_tokens( - grammar_output, non_block=True - ) + with record_function_or_nullcontext( + "core step_with_batch_queue: sample_tokens" + ): + # Call sample tokens. + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) else: # No sampling required (e.g. all requests finished). future = cast(Future[ModelRunnerOutput], exec_future) @@ -408,27 +436,34 @@ def step_with_batch_queue( # only be called when the scheduler contains requests or the queue # is non-empty. return None, False - - # Block until the next result is available. - future, scheduler_output = batch_queue.pop() - with self.log_error_detail(scheduler_output): - model_output = future.result() - - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("core step_with_batch_queue: model_output"): + # Block until the next result is available. + future, scheduler_output = batch_queue.pop() + with self.log_error_detail(scheduler_output): + model_output = future.result() + with record_function_or_nullcontext( + "core step_with_batch_queue: update_from_output" + ): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) # NOTE(nick): We can either handle the deferred tasks here or save # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. if deferred_scheduler_output: - # We now have the tokens needed to compute the bitmask for the - # deferred request. Get the bitmask and call sample tokens. - grammar_output = self.scheduler.get_grammar_bitmask( - deferred_scheduler_output - ) - future = self.model_executor.sample_tokens(grammar_output, non_block=True) - batch_queue.appendleft((future, deferred_scheduler_output)) + with record_function_or_nullcontext( + "core step_with_batch_queue: deferred_scheduler_output" + ): + # We now have the tokens needed to compute the bitmask for the + # deferred request. Get the bitmask and call sample tokens. + grammar_output = self.scheduler.get_grammar_bitmask( + deferred_scheduler_output + ) + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) + batch_queue.appendleft((future, deferred_scheduler_output)) return engine_core_outputs, model_executed @@ -622,11 +657,6 @@ def __init__( assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") - # Mark the startup heap as static so that it's ignored by GC. - # Reduces pause times of oldest generation collections. - gc.collect() - gc.freeze() - # If enable, attach GC debugger after static variable freeze. maybe_attach_gc_debug_callback() diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index e32c74aff313..d27d13840989 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -36,6 +36,7 @@ from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.stats import IterationStats +from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -280,28 +281,32 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]: return [] # 1) Get EngineCoreOutput from the EngineCore. 
- outputs = self.engine_core.get_output() + with record_function_or_nullcontext("llm_engine step: get_output"): + outputs = self.engine_core.get_output() # 2) Process EngineCoreOutputs. - iteration_stats = IterationStats() if self.log_stats else None - processed_outputs = self.output_processor.process_outputs( - outputs.outputs, - engine_core_timestamp=outputs.timestamp, - iteration_stats=iteration_stats, - ) - self.output_processor.update_scheduler_stats(outputs.scheduler_stats) + with record_function_or_nullcontext("llm_engine step: process_outputs"): + iteration_stats = IterationStats() if self.log_stats else None + processed_outputs = self.output_processor.process_outputs( + outputs.outputs, + engine_core_timestamp=outputs.timestamp, + iteration_stats=iteration_stats, + ) + self.output_processor.update_scheduler_stats(outputs.scheduler_stats) # 3) Abort any reqs that finished due to stop strings. - self.engine_core.abort_requests(processed_outputs.reqs_to_abort) + with record_function_or_nullcontext("llm_engine step: abort_requests"): + self.engine_core.abort_requests(processed_outputs.reqs_to_abort) # 4) Record stats - if self.logger_manager is not None and outputs.scheduler_stats is not None: - self.logger_manager.record( - scheduler_stats=outputs.scheduler_stats, - iteration_stats=iteration_stats, - mm_cache_stats=self.processor.stat_mm_cache(), - ) - self.do_log_stats_with_interval() + with record_function_or_nullcontext("llm_engine step: record_stats"): + if self.logger_manager is not None and outputs.scheduler_stats is not None: + self.logger_manager.record( + scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats, + mm_cache_stats=self.processor.stat_mm_cache(), + ) + self.do_log_stats_with_interval() return processed_outputs.request_outputs diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 00a625e103bd..2962a439dcb3 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar): vocab_size: int printed_error: bool = False terminated: bool = False + rollback_lag: int = 0 def check_error(self): if not self.printed_error: @@ -127,6 +128,8 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: """ if self.ll_tokenizer.eos_token in tokens: + if self.ll_matcher.is_stopped() and not self.terminated: + self.rollback_lag = 1 self.terminated = True if self.ll_matcher.is_stopped(): @@ -163,8 +166,11 @@ def validate_tokens(self, tokens: list[int]) -> list[int]: return tokens[:num_tokens] def rollback(self, num_tokens: int) -> None: - self.ll_matcher.rollback(num_tokens) - self.check_error() + if num_tokens > 0: + self.ll_matcher.rollback(num_tokens - self.rollback_lag) + self.terminated = False + self.rollback_lag = 0 + self.check_error() def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: # this will automatically return [EOS] mask if the matcher is stopped diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 26007d29d61b..6fccf2ea2f47 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -279,6 +279,9 @@ def __init__( # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len + + # Always set to False after the first forward pass + self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.dcp_world_size =
self.parallel_config.decode_context_parallel_size self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.max_num_tokens = scheduler_config.max_num_batched_tokens @@ -2525,7 +2528,7 @@ def execute_model( "after execute_model() returns None." ) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - with record_function_or_nullcontext("Preprocess"): + with record_function_or_nullcontext("gpu_model_runner: preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. self._update_states(scheduler_output) @@ -2625,16 +2628,12 @@ def execute_model( ) # Set cudagraph mode to none if calc_kv_scales is true. - if attn_metadata is not None: - metadata_list = ( - attn_metadata.values() - if isinstance(attn_metadata, dict) - else [attn_metadata] - ) - if any( - getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list - ): - cudagraph_runtime_mode = CUDAGraphMode.NONE + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: + cudagraph_runtime_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False # Run the model. # Use persistent buffers for CUDA graphs. @@ -2648,7 +2647,7 @@ def execute_model( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ), - record_function_or_nullcontext("Forward"), + record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): model_output = self._model_forward( @@ -2659,7 +2658,7 @@ def execute_model( **model_kwargs, ) - with record_function_or_nullcontext("Postprocess"): + with record_function_or_nullcontext("gpu_model_runner: postprocess"): if self.use_aux_hidden_state_outputs: # True when EAGLE 3 is used. hidden_states, aux_hidden_states = model_output @@ -2756,12 +2755,12 @@ def sample_tokens( scheduler_output, grammar_output, self.input_batch, logits ) - with record_function_or_nullcontext("Sample"): + with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): + with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, sampled_token_ids, @@ -2799,7 +2798,7 @@ def propose_draft_token_ids(sampled_token_ids): # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - with record_function_or_nullcontext("Bookkeep"): + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, logprobs_lists, @@ -2826,37 +2825,41 @@ def propose_draft_token_ids(sampled_token_ids): # tokens on the CPU, so they are run after bookkeeping. 
propose_draft_token_ids(valid_sampled_token_ids) - with record_function_or_nullcontext("EPLB"): + with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() - - output = ModelRunnerOutput( - req_ids=req_ids_output_copy, - req_id_to_index=req_id_to_index_output_copy, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - num_nans_in_logits=num_nans_in_logits, - ) + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits, + ) if not self.use_async_scheduling: return output - - async_output = AsyncGPUModelRunnerOutput( - model_runner_output=output, - sampled_token_ids=sampler_output.sampled_token_ids, - logprobs_tensors=sampler_output.logprobs_tensors, - invalid_req_indices=invalid_req_indices, - async_output_copy_stream=self.async_output_copy_stream, - ) - - # Save ref of sampled_token_ids CPU tensor if the batch contains - # any requests with sampling params that that require output ids. - self.input_batch.set_async_sampled_token_ids( - async_output.sampled_token_ids_cpu, - async_output.async_copy_ready_event, - ) + with record_function_or_nullcontext( + "gpu_model_runner: AsyncGPUModelRunnerOutput" + ): + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + with record_function_or_nullcontext( + "gpu_model_runner: set_async_sampled_token_ids" + ): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) return async_output diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 19061fcffdf1..0bc9aa5ee863 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,7 +14,7 @@ import torch.nn as nn import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -401,12 +401,27 @@ def compile_or_warm_up_model(self) -> None: # but users still want to compile for better performance, # e.g. for the max-num-batched token size in chunked prefill. warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() - if not self.model_config.enforce_eager: + + if ( + not self.model_config.enforce_eager + or self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + ): warmup_sizes = [ x for x in warmup_sizes if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the largest size in the range (compile_range.end - 1) + # to ensure compilation/warmup.
+ all_sizes = set(self.vllm_config.compilation_config.cudagraph_capture_sizes) + all_sizes.update(warmup_sizes) + for compile_range in compile_ranges: + if not any(x in compile_range for x in all_sizes): + warmup_sizes.append(compile_range.end - 1) + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 0ca7e81a5c7b..072558a5751c 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -340,7 +340,7 @@ def is_residual_scattered_for_sp( The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled. - This follows the same logic as SequenceParallelismPass.is_applicable(): + This follows the same logic as SequenceParallelismPass.is_applicable_for_range(): - In full-graph compilation mode (no splitting ops or using inductor graph partition), SP is always applied - Otherwise, SP is only applied for specific shapes in compile_sizes
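Note: the profiling labels added throughout this diff are passed to record_function_or_nullcontext, which is imported from vllm.v1.utils but whose body is not shown here. The following is only a minimal sketch of the pattern, assuming the helper wraps torch.profiler.record_function when profiling is active and degrades to a no-op context otherwise; the profiling_enabled flag below is a hypothetical stand-in for whatever gate the real implementation uses.

from contextlib import AbstractContextManager, nullcontext

import torch

# Hypothetical gate; the actual vLLM check may be an env var or config flag.
profiling_enabled = False


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a named profiler range if profiling is on, else a no-op context."""
    if profiling_enabled:
        # Appears as a labeled range in torch profiler / chrome traces.
        return torch.profiler.record_function(name)
    return nullcontext()


# Usage mirrors the call sites in this diff, e.g.:
# with record_function_or_nullcontext("core step: schedule"):
#     scheduler_output = self.scheduler.schedule()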