From c3af2af0b8be65ecd1a8538bcfb9622e873e6b3c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 07:25:57 -0700 Subject: [PATCH 01/49] Split PR. Second part. Compile ranges Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 86 ++++++++++++++ vllm/compilation/backends.py | 104 +++++++--------- vllm/compilation/collective_fusion.py | 144 +++++++++-------------- vllm/compilation/compiler_interface.py | 40 ++++--- vllm/compilation/inductor_pass.py | 11 +- vllm/compilation/pass_manager.py | 4 +- vllm/compilation/piecewise_backend.py | 57 +++++---- vllm/compilation/sequence_parallelism.py | 6 +- vllm/config/compilation.py | 33 ++++++ 9 files changed, 288 insertions(+), 197 deletions(-) create mode 100644 tests/compile/test_compile_ranges.py diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py new file mode 100644 index 000000000000..6759da199f4b --- /dev/null +++ b/tests/compile/test_compile_ranges.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 64 +MLP_SIZE = 128 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@support_torch_compile +class TestModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, + batch_sizes: list[int]): + with set_forward_context({}, vllm_config=vllm_config): + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + for batch_size in batch_sizes: + model(torch.randn(batch_size, MLP_SIZE).cuda()) + + +def test_compile_ranges(): + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + compile_ranges_split_points=[8, 32], + )) + + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda() + batch_sizes = [1, 16, 48] + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 53fd5e74dc0a..686c415f7ac3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -80,7 +80,8 @@ class CompilerManager: """ def 
__init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[int | None, int, str], Any] = dict() + self.cache: dict[tuple[tuple[int, int] | None, int, str], + Any] = (dict()) self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -89,11 +90,11 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, runtime_shape: int | None = None): + def compile_context(self, compile_range: tuple[int, int] | None = None): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" - with pass_context(runtime_shape): + with pass_context(compile_range): if self.compilation_config.use_inductor_graph_partition: inductor_partition_ops = resolve_defined_ops( self.compilation_config.splitting_ops @@ -150,29 +151,25 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Callable | None: - if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] - compiled_graph = self.compiler.load( - handle, graph, example_inputs, graph_index, runtime_shape - ) - if runtime_shape is None: + handle = self.cache[(compile_range, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, compile_range) + if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via handle %s", + "Directly load the %s-th graph for dynamic compile range from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via handle %s", - graph_index, - str(runtime_shape), - self.compiler.name, - handle, - ) + "Directly load the %s-th graph for compile range %s from %s via " + "handle %s", graph_index, str(compile_range), + self.compiler.name, handle) return compiled_graph def compile( @@ -183,7 +180,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -195,15 +192,15 @@ def compile( compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, + compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. # there can be multiple graphs due to piecewise compilation. 
now = time.time() elapsed = now - compilation_start_time - compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " "from the cache, took %.3f s", @@ -211,11 +208,9 @@ def compile( ) else: logger.info( - "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", - str(runtime_shape), - elapsed, - ) + "Directly load the compiled graph(s) for compile range %s " + "from the cache, took %.3f s", str(compile_range), + elapsed) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -224,48 +219,40 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - - with self.compile_context(runtime_shape): - compiled_graph, handle = self.compiler.compile( - graph, - example_inputs, - additional_inductor_config, - runtime_shape, - maybe_key, - ) + maybe_key = \ + f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, compile_range, + maybe_key) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, + self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if runtime_shape is None: + if compile_range is None: logger.info_once( - "Cache the graph for dynamic shape for later use", scope="local" - ) + "Cache the graph for dynamic shape for later use") else: - logger.info_once( - "Cache the graph of shape %s for later use", - str(runtime_shape), - scope="local", - ) - if runtime_shape is None: + logger.info_once("Cache the graph of compile range %s for later use", + str(compile_range)) + if compile_range is None: logger.debug( - "Store the %s-th graph for dynamic shape from %s via handle %s", + "Store the %s-th graph for dynamic compile range from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Store the %s-th graph for shape %s from %s via handle %s", + "Store the %s-th graph for compile range %s from %s via handle %s", graph_index, - str(runtime_shape), + str(compile_range), self.compiler.name, handle, ) @@ -275,19 +262,16 @@ def compile( now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info_once( - "Compiling a graph for dynamic shape takes %.2f s", + "Compiling a graph for dynamic compile range takes %.2f s", + elapsed, scope="local", ) else: - logger.info_once( - "Compiling a graph for shape %s takes %.2f s", - runtime_shape, - elapsed, - scope="local", - ) + logger.info_once("Compiling a graph for compile range %s takes %.2f s", + str(compile_range), elapsed, scope="local") return compiled_graph @@ -408,7 +392,6 @@ def call_module( i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_dynamic_shape = ( self.vllm_backend.compiler_manager.compile( submod, @@ -417,9 +400,8 @@ def call_module( self.compilation_config, graph_index=index, num_graphs=len(self.compile_submod_names), - 
runtime_shape=None, - ) - ) + compile_range=None, + )) # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index cf89182357f2..a4758c971611 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -504,93 +504,59 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - - if num_tokens <= max_token_num: - device_capability = ( - current_platform.get_device_capability().as_version_str() - ) - # Get one shot input size limit for the current world size - # for the current device capability - max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( - device_capability, {} - ).get(world_size, None) - # Use one shot if no max size for one shot is specified - use_oneshot = ( - max_one_shot_size_mb is None - or current_tensor_size <= max_one_shot_size_mb * MiB - ) - - assert _FI_WORKSPACE_TENSOR is not None, ( - "Flashinfer must be enabled when using flashinfer" - ) - if norm_out is None: - norm_out = allreduce_in - residual_out = residual - else: - # return residual_out as allreduce_out with zeroed residual_in - # as flashinfer does not support rms_norm - # and allreduce_out together - residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, - quant_out=quant_out, - scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, - scale_factor=scale_factor, - ) + max_tensor_size = max_token_num * hidden_size * element_size + assert current_tensor_size <= max_tensor_size, \ + f"Current tensor size {current_tensor_size} is larger than " \ + f"max token num {max_token_num} * hidden size {hidden_size} * " \ + f"element size {element_size}" + device_capability = current_platform.get_device_capability( + ).as_version_str() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ + get(device_capability, {}). 
\ + get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = max_one_shot_size is None or \ + current_tensor_size <= max_one_shot_size * MiB + + assert ( + _FI_WORKSPACE_TENSOR + is not None), "Flashinfer must be enabled when using flashinfer" + if norm_out is None: + norm_out = allreduce_in + residual_out = residual else: - allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None: - # Do fused rms norm static fp8 quant fused op - if norm_out is None: - torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, - allreduce_out, - residual, - rms_gamma, - scale_factor, - rms_eps, - ) - else: - torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps - ) - else: - if norm_out is None: - torch.ops._C.fused_add_rms_norm( - allreduce_out, residual, rms_gamma, rms_eps - ) - norm_out = allreduce_out - else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None and scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - if scale_factor is None or norm_out is not None: - # we need to return allreduce output - # in cases of non quant fused AR + RMS norm - # and fused AR + RMS norm + quant without fused add - allreduce_in.copy_(allreduce_out) + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) def call_trtllm_fused_allreduce_norm_fake( allreduce_in: torch.Tensor, @@ -1212,6 +1178,12 @@ def register_patterns(self): self.disabled = False @VllmInductorPass.time_and_log + def is_applicable_for_range( + self, compile_range: tuple[int, int] | None) -> bool: + if compile_range is None: + return False + return compile_range[1] - 1 <= self.max_token_num + def __call__(self, graph: fx.Graph): if self.disabled: logger.debug("AllReduceFusionPass disabled") diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 0a3f0769db94..3861bfed11d5 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -63,16 +63,17 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a runtime shape. If the `runtime_shape` is None, it means + with a range. 
If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `runtime_shape` specifies the shape of the inputs. Right now we only - support one variable shape for all inputs, which is the batchsize - (number of tokens) during inference. + `compile_range` specifies the range of the inputs, + it could be concrete size, e.g. (4, 4). + Right now we only support one variable range of shapes for all inputs, + which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -98,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: """ Load the compiled function from the handle. @@ -192,18 +193,21 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(runtime_shape, int): - dynamic_shapes = "from_example_inputs" + if isinstance(compile_range, tuple): + if compile_range[0] == compile_range[1]: + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_graph" else: dynamic_shapes = "from_tracing_context" @@ -230,7 +234,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -294,7 +298,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -308,7 +312,7 @@ def compile( current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() # inductor can inplace modify the graph, so we need to copy it @@ -493,7 +497,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -589,9 +593,9 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() -def set_inductor_config(config, runtime_shape): - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters +def set_inductor_config(config, compile_range): + if isinstance(compile_range, tuple): + # for a specific range of batchsizes, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( @@ -611,7 +615,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) 
-> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9af635a929b4..1b4430c82b2d 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -28,8 +28,8 @@ class PassContext: - def __init__(self, runtime_shape: int | None): - self.runtime_shape = runtime_shape + def __init__(self, compile_range: tuple[int, int] | None): + self.compile_range = compile_range def get_pass_context() -> PassContext: @@ -39,13 +39,13 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: int | None): +def pass_context(compile_range: tuple[int, int] | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ global _pass_context prev_context = _pass_context - _pass_context = PassContext(runtime_shape) + _pass_context = PassContext(compile_range) try: yield finally: @@ -96,7 +96,8 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable(self, shape: int | None): + def is_applicable_for_range(self, compile_range: tuple[int, + int] | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 3bc35a8f7198..82bca8f1fe1b 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -69,9 +69,9 @@ def __init__(self): def __call__(self, graph: fx.Graph): VllmInductorPass.dump_prefix = 0 # reset dump index - shape = get_pass_context().runtime_shape + compile_range = get_pass_context().compile_range for pass_ in self.passes: - if pass_.is_applicable(shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) VllmInductorPass.dump_prefix += 1 else: diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 2931580afbbb..87b0121f43cb 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -7,7 +7,6 @@ import torch.fx as fx -import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig @@ -17,8 +16,8 @@ @dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int +class RangeEntry: + compile_range: tuple[int, int] compiled: bool = False runnable: Callable = None # type: ignore @@ -55,7 +54,12 @@ def __init__( self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) + self.compile_ranges = self.compilation_config.get_compile_ranges() + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) + + self.is_in_range = lambda x, range: range[0] <= x < range[1] if range[ + 0] < range[1] else x == range[0] self.first_run_finished = False @@ -63,24 +67,27 @@ def __init__( self.sym_shape_indices = sym_shape_indices - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to compile - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + # self.concrete_size_entries: dict[int, RangeEntry] = {} + + # the entries for ranges that we need to either + # TODO: we should merge with concrete_size_entries + self.range_entries: dict[tuple[int, int], RangeEntry] = {} - # to_be_compiled_sizes tracks the remaining sizes to compile, + # 
to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + self.to_be_compiled_ranges: set[tuple[int, + int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. - for shape in self.compile_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, runnable=self.compiled_graph_for_general_shape, ) def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: + if (self.is_last_graph and not self.to_be_compiled_ranges): # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() @@ -94,28 +101,32 @@ def __call__(self, *args) -> Any: runtime_shape = args[self.sym_shape_indices[0]] - if runtime_shape not in self.concrete_size_entries: + range_entry = None + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + break + + if (range_entry is None): # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) - entry = self.concrete_size_entries[runtime_shape] + if not range_entry.compiled: + range_entry.compiled = True + self.to_be_compiled_ranges.remove(range_entry.compile_range) - if not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( + range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, - ) + compile_range=range_entry.compile_range) # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: + if (self.is_last_graph and not self.to_be_compiled_ranges): self.check_for_ending_compilation() - return entry.runnable(*args) + return range_entry.runnable(*args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 31624a8fdcc0..78fd8386f56e 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -482,7 +482,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. 
However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -502,7 +502,9 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range[0] + == compile_range[1]) and (compile_range[1] % tp_size == 0) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 72418762773c..374e1c99fea0 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -214,6 +214,8 @@ class CompilationConfig: - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`compile_ranges_split_points`] + [vllm.config.CompilationConfig.compile_ranges_split_points] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -331,6 +333,16 @@ class CompilationConfig: """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" + compile_ranges_split_points: Optional[list[int]] = None + """Split points that represent compile ranges for inductor. + The compile ranges are + [1, split_points[0]), + [split_points[0], split_points[1]), ..., + [split_points[-1], max_num_batched_tokens + 1). + Compile sizes are also used single element ranges: + [compile_sizes[i], compile_sizes[i] + 1). + """ + inductor_compile_config: dict = field(default_factory=dict) """Additional configurations for inductor. - None: use default configurations.""" @@ -914,3 +926,24 @@ def custom_op_log_check(self): enable_str, op, ) + + def get_compile_ranges(self) -> list[tuple[int, int]]: + """Get the compile ranges for the compilation config.""" + compile_ranges_split_points = self.compile_ranges_split_points + compile_ranges = [] + # max_num_batched_tokens + 1 + max_split_point = max(compile_ranges_split_points) + compile_sizes = set(self.compile_sizes) + split_points = sorted( + compile_sizes.union(set(self.compile_ranges_split_points))) + # filter out split points that are greater + # than max_num_batched_tokens + 1 + split_points = [x for x in split_points if x <= max_split_point] + for i, s in enumerate(split_points): + if i == 0: + compile_ranges.append((1, s)) + else: + compile_ranges.append((split_points[i - 1], s)) + if s in compile_sizes and s != 1: + compile_ranges.append((s, s)) + return sorted(compile_ranges) From 0cbb0656ac01d60fb3286e63550d215e95caed81 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 10:00:52 -0700 Subject: [PATCH 02/49] Remove general shape graph Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 14 +------ vllm/compilation/piecewise_backend.py | 53 +++++++++++++-------------- vllm/config/compilation.py | 2 + 3 files changed, 30 insertions(+), 39 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 686c415f7ac3..45a1a8c2f267 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -391,17 +391,7 @@ def call_module( sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - global compilation_start_time - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - 
self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - compile_range=None, - )) + # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -411,7 +401,7 @@ def call_module( index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_dynamic_shape, + # compiled_graph_for_dynamic_shape, self.vllm_backend, ) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 87b0121f43cb..d280b85fc82a 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -63,15 +63,12 @@ def __init__( self.first_run_finished = False - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - self.sym_shape_indices = sym_shape_indices # the entries for different shapes that we need to compile # self.concrete_size_entries: dict[int, RangeEntry] = {} # the entries for ranges that we need to either - # TODO: we should merge with concrete_size_entries self.range_entries: dict[tuple[int, int], RangeEntry] = {} # to_be_compiled_ranges tracks the remaining ranges to compile, @@ -81,10 +78,7 @@ def __init__( # We only keep compilation management inside this class directly. for range in self.compile_ranges: - self.range_entries[range] = RangeEntry( - compile_range=range, - runnable=self.compiled_graph_for_general_shape, - ) + self.range_entries[range] = RangeEntry(compile_range=range, ) def check_for_ending_compilation(self): if (self.is_last_graph and not self.to_be_compiled_ranges): @@ -93,24 +87,8 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] - - range_entry = None - for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): - range_entry = self.range_entries[range] - break - - if (range_entry is None): - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, + args) -> Any: if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -126,7 +104,28 @@ def __call__(self, *args) -> Any: compile_range=range_entry.compile_range) # finished compilations for all required shapes - if (self.is_last_graph and not self.to_be_compiled_ranges): - self.check_for_ending_compilation() + self.check_for_ending_compilation() + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + + # Role of the general is taken by the last range + range_entry = self.range_entries[self.compile_ranges[-1]] + self._maybe_compile_for_range_entry(range_entry, args) + return range_entry.runnable(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + + range_entry = None + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + break + assert range_entry is not None, \ + f"Shape out of considered range: {runtime_shape} " \ + "[1, max_num_batched_tokens]" + + self._maybe_compile_for_range_entry(range_entry, args) return range_entry.runnable(*args) diff --git 
a/vllm/config/compilation.py b/vllm/config/compilation.py index 374e1c99fea0..2aab5cb5f295 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -946,4 +946,6 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges.append((split_points[i - 1], s)) if s in compile_sizes and s != 1: compile_ranges.append((s, s)) + assert compile_ranges[-1][1] == max_split_point, \ + "Last compile range end should be max_split_point" return sorted(compile_ranges) From d5392f54cb6e8f15926f1d89642ad08cda44a99c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 5 Sep 2025 06:00:15 -0700 Subject: [PATCH 03/49] Add test to test pipeline Signed-off-by: ilmarkov --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6cbc25b4b3bf..105eca371ff3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -412,6 +412,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_compile_ranges.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 From 027c9eb348808e1a37c9dbc86fbfcd020e2166a8 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 9 Sep 2025 05:32:05 -0700 Subject: [PATCH 04/49] Fix pre-commit Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index d280b85fc82a..cec8aca63d80 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -117,12 +117,13 @@ def __call__(self, *args) -> Any: runtime_shape = args[self.sym_shape_indices[0]] - range_entry = None + range_found = False for range in self.compile_ranges: if self.is_in_range(runtime_shape, range): range_entry = self.range_entries[range] + range_found = True break - assert range_entry is not None, \ + assert range_found, \ f"Shape out of considered range: {runtime_shape} " \ "[1, max_num_batched_tokens]" From b2992d3b9afa19156df1453fa504df87ecbc30d9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 20:12:17 +0000 Subject: [PATCH 05/49] Upd Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 48 ++++++++-------- vllm/compilation/backends.py | 73 ++++++++++++++---------- vllm/compilation/collective_fusion.py | 19 +++--- vllm/compilation/compiler_interface.py | 16 +++--- vllm/compilation/inductor_pass.py | 3 +- vllm/compilation/pass_manager.py | 2 +- vllm/compilation/piecewise_backend.py | 30 +++++----- vllm/compilation/sequence_parallelism.py | 8 ++- vllm/config/compilation.py | 8 ++- 9 files changed, 114 insertions(+), 93 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 6759da199f4b..68389ccfbe14 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -6,8 +6,12 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import set_forward_context from vllm.utils import direct_register_custom_op @@ -18,15 +22,17 @@ MLP_SIZE = 128 -def silly_attention(q: 
torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: out.copy_(q) out += k out += v -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: return @@ -41,12 +47,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @support_torch_compile class TestModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -59,8 +60,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.inference_mode -def run_model(vllm_config: VllmConfig, model: nn.Module, - batch_sizes: list[int]): +def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): with set_forward_context({}, vllm_config=vllm_config): model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) for batch_size in batch_sizes: @@ -68,19 +68,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, def test_compile_ranges(): - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - compile_ranges_split_points=[8, 32], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + compile_ranges_split_points=[8, 32], + ) + ) with set_current_vllm_config(vllm_config): - model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda() + model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() batch_sizes = [1, 16, 48] # A has support_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=1, - num_backend_compilations=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 45a1a8c2f267..beda9b36f686 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -80,8 +80,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[tuple[int, int] | None, int, str], - Any] = (dict()) + self.cache: dict[tuple[tuple[int, int] | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -156,20 +155,26 @@ def load( if (compile_range, graph_index, self.compiler.name) not in self.cache: return None handle = self.cache[(compile_range, graph_index, self.compiler.name)] - compiled_graph = self.compiler.load(handle, graph, example_inputs, - graph_index, compile_range) + compiled_graph = self.compiler.load( + handle, graph, example_inputs, graph_index, compile_range + ) if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic compile range from %s via handle %s", + "Directly load the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for compile range %s from %s via " - "handle %s", 
graph_index, str(compile_range), - self.compiler.name, handle) + "Directly load the %s-th graph for compile range %s" + "from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) return compiled_graph def compile( @@ -192,8 +197,7 @@ def compile( compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, - compile_range) + compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. @@ -209,8 +213,10 @@ def compile( else: logger.info( "Directly load the compiled graph(s) for compile range %s " - "from the cache, took %.3f s", str(compile_range), - elapsed) + "from the cache, took %.3f s", + str(compile_range), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -219,38 +225,43 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = \ - f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" - compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, compile_range, - maybe_key) + maybe_key = f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + with self.compile_context(compile_range): + compiled_graph, handle = self.compiler.compile( + graph, + example_inputs, + additional_inductor_config, + compile_range, + maybe_key, + ) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(compile_range, graph_index, - self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph if compile_range is None: - logger.info_once( - "Cache the graph for dynamic shape for later use") + logger.info_once("Cache the graph for dynamic shape for later use", scope="local") else: - logger.info_once("Cache the graph of compile range %s for later use", - str(compile_range)) + logger.info_once( + "Cache the graph of compile range %s for later use", + str(compile_range), + ) if compile_range is None: logger.debug( - "Store the %s-th graph for dynamic compile range from %s via handle %s", + "Store the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Store the %s-th graph for compile range %s from %s via handle %s", + "Store the %s-th graph for compile range%s from %s via handle %s", graph_index, str(compile_range), self.compiler.name, @@ -264,14 +275,17 @@ def compile( compilation_config.compilation_time += elapsed if compile_range is None: logger.info_once( - "Compiling a graph for dynamic compile range takes %.2f s", - + "Compiling a graph for dynamic compile range takes %.2f s", elapsed, scope="local", ) else: - logger.info_once("Compiling a graph for compile range %s takes %.2f s", - str(compile_range), elapsed, scope="local") + logger.info_once( + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), + elapsed, + scope="local", + ) return compiled_graph @@ -401,7 +415,6 @@ def call_module( index, len(self.compile_submod_names), sym_shape_indices, - # compiled_graph_for_dynamic_shape, self.vllm_backend, ) diff --git 
a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index a4758c971611..3d970ac2964b 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -505,12 +505,12 @@ def call_trtllm_fused_allreduce_norm( element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size max_tensor_size = max_token_num * hidden_size * element_size - assert current_tensor_size <= max_tensor_size, \ - f"Current tensor size {current_tensor_size} is larger than " \ - f"max token num {max_token_num} * hidden size {hidden_size} * " \ + assert current_tensor_size <= max_tensor_size, ( + f"Current tensor size {current_tensor_size} is larger than " + f"max token num {max_token_num} * hidden size {hidden_size} * " f"element size {element_size}" - device_capability = current_platform.get_device_capability( - ).as_version_str() + ) + device_capability = current_platform.get_device_capability().as_version_str() # Get one shot input size limit for the current world size # for the current device capability max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ @@ -520,9 +520,9 @@ def call_trtllm_fused_allreduce_norm( use_oneshot = max_one_shot_size is None or \ current_tensor_size <= max_one_shot_size * MiB - assert ( - _FI_WORKSPACE_TENSOR - is not None), "Flashinfer must be enabled when using flashinfer" + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -1178,8 +1178,7 @@ def register_patterns(self): self.disabled = False @VllmInductorPass.time_and_log - def is_applicable_for_range( - self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: if compile_range is None: return False return compile_range[1] - 1 <= self.max_token_num diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3861bfed11d5..4e5aa077ddae 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -63,14 +63,14 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, with a range. If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `compile_range` specifies the range of the inputs, + `compile_range` specifies the range of the inputs, it could be concrete size, e.g. (4, 4). Right now we only support one variable range of shapes for all inputs, which is the batchsize (number of tokens) during inference. @@ -99,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: """ Load the compiled function from the handle. 
@@ -193,7 +193,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -234,7 +234,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -298,7 +298,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -497,7 +497,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -615,7 +615,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 1b4430c82b2d..599fa776b6c0 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -96,8 +96,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: tuple[int, - int] | None): + def is_applicable_for_range(self, compile_range: tuple[int, int] | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 82bca8f1fe1b..08002dc862f6 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -75,7 +75,7 @@ def __call__(self, graph: fx.Graph): pass_(graph) VllmInductorPass.dump_prefix += 1 else: - logger.debug("Skipping %s with shape %s", pass_, shape) + logger.debug("Skipping %s with compile range %s", pass_, compile_range) # post-cleanup goes before fix_functionalization # because it requires a functional graph diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index cec8aca63d80..607d6a80f5cf 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -30,7 +30,6 @@ def __init__( piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend, ): """ @@ -58,8 +57,11 @@ def __init__( log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" logger.debug_once(log_string) - self.is_in_range = lambda x, range: range[0] <= x < range[1] if range[ - 0] < range[1] else x == range[0] + self.is_in_range = ( + lambda x, range: range[0] <= x < range[1] + if range[0] < range[1] + else x == range[0] + ) self.first_run_finished = False @@ -73,22 +75,22 @@ def __init__( # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy 
it - self.to_be_compiled_ranges: set[tuple[int, - int]] = set(self.compile_ranges) + self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. for range in self.compile_ranges: - self.range_entries[range] = RangeEntry(compile_range=range, ) + self.range_entries[range] = RangeEntry( + compile_range=range, + ) def check_for_ending_compilation(self): - if (self.is_last_graph and not self.to_be_compiled_ranges): + if self.is_last_graph and not self.to_be_compiled_ranges: # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, - args) -> Any: + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -101,7 +103,8 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - compile_range=range_entry.compile_range) + compile_range=range_entry.compile_range, + ) # finished compilations for all required shapes self.check_for_ending_compilation() @@ -123,9 +126,10 @@ def __call__(self, *args) -> Any: range_entry = self.range_entries[range] range_found = True break - assert range_found, \ - f"Shape out of considered range: {runtime_shape} " \ - "[1, max_num_batched_tokens]" + assert range_found, ( + f"Shape out of considered range: {runtime_shape} " + "[1, max_num_batched_tokens]" + ) self._maybe_compile_for_range_entry(range_entry, args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 78fd8386f56e..cf47adb4670a 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -502,9 +502,11 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool ): return True tp_size = get_tensor_model_parallel_world_size() - return compile_range is not None and ( - compile_range[0] - == compile_range[1]) and (compile_range[1] % tp_size == 0) + return ( + compile_range is not None + and (compile_range[0] == compile_range[1]) + and (compile_range[1] % tp_size == 0) + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 2aab5cb5f295..278fe5801323 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -333,7 +333,7 @@ class CompilationConfig: """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" - compile_ranges_split_points: Optional[list[int]] = None + compile_ranges_split_points: list[int] | None = None """Split points that represent compile ranges for inductor. 
The compile ranges are [1, split_points[0]), @@ -935,7 +935,8 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: max_split_point = max(compile_ranges_split_points) compile_sizes = set(self.compile_sizes) split_points = sorted( - compile_sizes.union(set(self.compile_ranges_split_points))) + compile_sizes.union(set(self.compile_ranges_split_points)) + ) # filter out split points that are greater # than max_num_batched_tokens + 1 split_points = [x for x in split_points if x <= max_split_point] @@ -946,6 +947,7 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges.append((split_points[i - 1], s)) if s in compile_sizes and s != 1: compile_ranges.append((s, s)) - assert compile_ranges[-1][1] == max_split_point, \ + assert compile_ranges[-1][1] == max_split_point, ( "Last compile range end should be max_split_point" + ) return sorted(compile_ranges) From 3499384c1e183cd851c93d12ea7d77c08de03ed2 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 20:32:36 +0000 Subject: [PATCH 06/49] Upd config Signed-off-by: ilmarkov --- vllm/config/vllm.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 916f258d6586..fd38992e374b 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -426,6 +426,8 @@ def __post_init__(self): "correctness and to realize prefill savings. " ) + self._set_compile_ranges() + disable_chunked_prefill_reasons: list[str] = [] if self.model_config: @@ -796,6 +798,49 @@ def _set_cudagraph_sizes(self): # complete the remaining process. self.compilation_config.post_init_cudagraph_sizes() + def _set_compile_ranges(self): + """ + Set the compile ranges for the compilation config. + """ + compilation_config = self.compilation_config + computed_compile_ranges_split_points = [] + + # The upper bound of the compile ranges is the max_num_batched_tokens + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if max_num_batched_tokens is not None: + # We add 1 because the bounds checks in the compiler are exclusive + # and we want to include the max_num_batched_tokens + # in the compile range + computed_compile_ranges_split_points.append(max_num_batched_tokens + 1) + + # Add the compile ranges for flashinfer + if compilation_config.pass_config.enable_fi_allreduce_fusion: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() + * self.model_config.dtype.itemsize + ) + # We add 1 because the bounds checks in the compiler are + # exclusive and we want to include the max_token_num in the + # compile range + computed_compile_ranges_split_points.append(max_token_num + 1) + + if compilation_config.compile_ranges_split_points is not None: + for x in compilation_config.compile_ranges_split_points: + assert isinstance(x, int) + assert x > 0, f"Invalid compile range split point: {x}" + if ( + max_num_batched_tokens is not None + and x < max_num_batched_tokens + and x > 1 + ): + computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( + computed_compile_ranges_split_points + ) # type: ignore + def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config model_config = self.model_config From 5336ee6ffe1d5b03b69b23f4b346ba10a549c6cd Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 
20:51:01 +0000 Subject: [PATCH 07/49] Fix Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 18 ++++++++++-------- vllm/v1/worker/utils.py | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 3d970ac2964b..7c0a1208d870 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -431,7 +431,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -441,7 +441,9 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range[0] == compile_range[1] and compile_range[1] % tp_size == 0 + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -1100,18 +1102,18 @@ def __init__(self, config: VllmConfig): ) return element_size = 4 if use_fp32_lamport else 2 - max_token_num = max_size // (self.hidden_dim * element_size) + self.max_token_num = max_size // (self.hidden_dim * element_size) # take the min to save workspace size and we'll never use more # than max_num_batched_tokens anyways - max_token_num = min( - max_token_num, config.scheduler_config.max_num_batched_tokens + self.max_token_num = min( + self.max_token_num, config.scheduler_config.max_num_batched_tokens ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_token_num, + max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1124,7 +1126,7 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_token_num, + max_token_num=self.max_token_num, ) self.register_patterns() @@ -1177,12 +1179,12 @@ def register_patterns(self): self.disabled = False - @VllmInductorPass.time_and_log def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: if compile_range is None: return False return compile_range[1] - 1 <= self.max_token_num + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: logger.debug("AllReduceFusionPass disabled") diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 92baf0cb7136..ef953dd2051e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -330,7 +330,7 @@ def is_residual_scattered_for_sp( The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled. 
- This follows the same logic as SequenceParallelismPass.is_applicable(): + This follows the same logic as SequenceParallelismPass.is_applicable_for_range(): - In full-graph compilation mode (no splitting ops or using inductor graph partition), SP is always applied - Otherwise, SP is only applied for specific shapes in compile_sizes From 4958474f77a930f532730a9ec7a395339ea32138 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 17 Oct 2025 11:30:21 +0000 Subject: [PATCH 08/49] Priotitize compile_sizes Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 28 ++++++++++++++++++++------- vllm/config/compilation.py | 18 ++--------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 607d6a80f5cf..7a10fed1d237 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -57,6 +57,10 @@ def __init__( log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" logger.debug_once(log_string) + self.compile_sizes = self.compilation_config.compile_sizes + log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + logger.debug_once(log_string) + self.is_in_range = ( lambda x, range: range[0] <= x < range[1] if range[0] < range[1] @@ -78,6 +82,12 @@ def __init__( self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. + for size in self.compile_sizes: + range = (size, size) + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + for range in self.compile_ranges: self.range_entries[range] = RangeEntry( compile_range=range, @@ -112,20 +122,24 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: def __call__(self, *args) -> Any: if not self.first_run_finished: self.first_run_finished = True + self.check_for_ending_compilation() - # Role of the general is taken by the last range + # Role of the general graph is taken by the last range graph range_entry = self.range_entries[self.compile_ranges[-1]] self._maybe_compile_for_range_entry(range_entry, args) return range_entry.runnable(*args) - runtime_shape = args[self.sym_shape_indices[0]] range_found = False - for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): - range_entry = self.range_entries[range] - range_found = True - break + if runtime_shape in self.compile_sizes: + range_entry = self.range_entries[(runtime_shape, runtime_shape)] + range_found = True + else: + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + range_found = True + break assert range_found, ( f"Shape out of considered range: {runtime_shape} " "[1, max_num_batched_tokens]" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 278fe5801323..c2a6d6d783b9 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -929,25 +929,11 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - compile_ranges_split_points = self.compile_ranges_split_points + split_points = self.compile_ranges_split_points compile_ranges = [] - # max_num_batched_tokens + 1 - max_split_point = max(compile_ranges_split_points) - compile_sizes = set(self.compile_sizes) - split_points = sorted( - compile_sizes.union(set(self.compile_ranges_split_points)) - ) - # filter out split points that are 
greater - # than max_num_batched_tokens + 1 - split_points = [x for x in split_points if x <= max_split_point] for i, s in enumerate(split_points): if i == 0: compile_ranges.append((1, s)) else: compile_ranges.append((split_points[i - 1], s)) - if s in compile_sizes and s != 1: - compile_ranges.append((s, s)) - assert compile_ranges[-1][1] == max_split_point, ( - "Last compile range end should be max_split_point" - ) - return sorted(compile_ranges) + return compile_ranges From 04306ed0dacf3fc11bcfb5ae993095d8d5a506bb Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 28 Oct 2025 13:26:59 +0000 Subject: [PATCH 09/49] Fix inductor config Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 7 ++++++- vllm/compilation/compiler_interface.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index beda9b36f686..30ab91e4ab82 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -225,7 +225,12 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + maybe_key = "artifact_compile_range_" + if compile_range is None: + maybe_key += "dynamic_shape" + else: + maybe_key += f"{compile_range[0]}_{compile_range[1]}" + maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( graph, diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 4e5aa077ddae..d069769fe76f 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -594,8 +594,8 @@ def metrics_context(self) -> contextlib.AbstractContextManager: def set_inductor_config(config, compile_range): - if isinstance(compile_range, tuple): - # for a specific range of batchsizes, tuning triton kernel parameters + if isinstance(compile_range, tuple) and compile_range[0] == compile_range[1]: + # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( From 9dc4eea25b0ec2520d920616002a6f148a1c3801 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 3 Nov 2025 10:53:49 +0000 Subject: [PATCH 10/49] Laith's fix Signed-off-by: ilmarkov --- vllm/compilation/compiler_interface.py | 38 +++++++++++++++++++++----- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index d069769fe76f..3453b8f676e8 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -213,13 +213,37 @@ def compile( from torch._inductor import standalone_compile - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) - + if dynamic_shapes == "from_graph": + # We need to pass fake example_inputs, otherwise torch.compile + # will fakify the example_inputs potentially causing some non dynamic + # dimension to be be duck shaped to other existing shapes that have hints + # matching their values. + # This is problem because it can lead to unintended specializations! + # if the new wrongly dynamic dim is specialized + # it will force specializing the whole shape + # standalone_compile probably should not accept + # non fake tensors as example inputs! 
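# Clarifying note on the comment above, assuming torch.compile's usual
# "duck sizing" of symbolic shapes: dimensions that share a size hint can be
# assigned the same symbol, so a real example tensor whose batch dimension
# happens to equal another already-seen hint may get tied to it, and a later
# specialization of one dimension then specializes the other. Re-using the
# graph's own fake placeholder values (collected below) avoids feeding in such
# accidental hints.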
+ fake_example_inputs = [] + for node in graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(example_inputs) + compiled_graph = standalone_compile( + graph, + fake_example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) + else: + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) From 2c63f0b05c02ce4d93e23093b3838af775d92614 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 4 Nov 2025 10:22:17 +0000 Subject: [PATCH 11/49] Upd Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 6 ++++-- vllm/compilation/collective_fusion.py | 11 ++++++----- vllm/config/compilation.py | 3 +++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 30ab91e4ab82..7cda5d0dee96 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -250,7 +250,9 @@ def compile( if graph_index == 0: # adds some info logging for the first graph if compile_range is None: - logger.info_once("Cache the graph for dynamic shape for later use", scope="local") + logger.info_once( + "Cache the graph for dynamic shape for later use", scope="local" + ) else: logger.info_once( "Cache the graph of compile range %s for later use", @@ -280,7 +282,7 @@ def compile( compilation_config.compilation_time += elapsed if compile_range is None: logger.info_once( - "Compiling a graph for dynamic compile range takes %.2f s", + "Compiling a graph for dynamic compile range takes %.2f s", elapsed, scope="local", ) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7c0a1208d870..9c20db07c267 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -515,12 +515,13 @@ def call_trtllm_fused_allreduce_norm( device_capability = current_platform.get_device_capability().as_version_str() # Get one shot input size limit for the current world size # for the current device capability - max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ - get(device_capability, {}). 
\ - get(world_size, None) + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( + device_capability, {} + ).get(world_size, None) # Use one shot if no max size is specified - use_oneshot = max_one_shot_size is None or \ - current_tensor_size <= max_one_shot_size * MiB + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB + ) assert _FI_WORKSPACE_TENSOR is not None, ( "Flashinfer must be enabled when using flashinfer" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c2a6d6d783b9..e469c8e25a43 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -142,6 +142,9 @@ def flashinfer_max_size(self, world_size: int) -> int | None: max_sizes = { k: int(v * MiB) for k, v in self.fi_allreduce_fusion_max_size_mb.items() } + logger.debug_once( + f"flashinfer_max_size: {max_sizes.get(world_size)}", scope="global" + ) # return None if world size is not supported by flashinfer return max_sizes.get(world_size) From fcebc21fb1708abbfc2622cfeee517aef801c622 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 4 Nov 2025 14:30:18 +0000 Subject: [PATCH 12/49] Add caching Signed-off-by: ilmarkov --- vllm/compilation/compiler_interface.py | 37 +++++--------------------- vllm/compilation/pass_manager.py | 1 + vllm/compilation/piecewise_backend.py | 23 +++++++++++++++- vllm/config/compilation.py | 8 +++--- 4 files changed, 32 insertions(+), 37 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3453b8f676e8..6a57cd4bc578 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -213,37 +213,12 @@ def compile( from torch._inductor import standalone_compile - if dynamic_shapes == "from_graph": - # We need to pass fake example_inputs, otherwise torch.compile - # will fakify the example_inputs potentially causing some non dynamic - # dimension to be be duck shaped to other existing shapes that have hints - # matching their values. - # This is problem because it can lead to unintended specializations! - # if the new wrongly dynamic dim is specialized - # it will force specializing the whole shape - # standalone_compile probably should not accept - # non fake tensors as example inputs! 
- fake_example_inputs = [] - for node in graph.graph.nodes: - # All place holders come first - if node.op == "placeholder": - fake_example_inputs.append(node.meta["example_value"]) - else: - break - assert len(fake_example_inputs) == len(example_inputs) - compiled_graph = standalone_compile( - graph, - fake_example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) - else: - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 08002dc862f6..3e0c9bc99a24 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -155,5 +155,6 @@ def uuid(self): # See [HACK: Bug with Inductor graph partition and torch.compile cache] state["inductor_splitting_ops"].extend(self.inductor_splitting_ops) + state["compile_range"] = get_pass_context().compile_range return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 7a10fed1d237..ad5b49f28550 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -87,6 +87,7 @@ def __init__( self.range_entries[range] = RangeEntry( compile_range=range, ) + self.to_be_compiled_ranges.add(range) for range in self.compile_ranges: self.range_entries[range] = RangeEntry( @@ -100,6 +101,26 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) + def fakify_args(self, args: list[Any]) -> list[Any]: + # We need to pass fake example_inputs, otherwise torch.compile + # will fakify the example_inputs potentially causing some non dynamic + # dimension to be be duck shaped to other existing shapes that have hints + # matching their values. + # This is problem because it can lead to unintended specializations! + # if the new wrongly dynamic dim is specialized + # it will force specializing the whole shape + # torch.compile probably should not accept + # non fake tensors as example inputs! 
+ fake_example_inputs = [] + for node in self.graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(args) + return fake_example_inputs + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: if not range_entry.compiled: range_entry.compiled = True @@ -108,7 +129,7 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: # args are real arguments range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, - args, + self.fakify_args(args), self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 475f4c15afef..fa728c23d145 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -142,11 +142,9 @@ def flashinfer_max_size(self, world_size: int) -> int | None: max_size_mb = self.fi_allreduce_fusion_max_size_mb if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) - logger.debug_once( - f"flashinfer_max_size: {int(max_size_mb * MiB)}", scope="global" - ) - return int(max_size_mb * MiB) - return None + max_size_bytes = int(max_size_mb * MiB) if max_size_mb is not None else None + logger.debug_once(f"flashinfer_max_size: {max_size_bytes}", scope="global") + return max_size_bytes @staticmethod def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: From 65151bcecf8429890f4fa191e7988aedfb2c9aa5 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 12:58:20 +0000 Subject: [PATCH 13/49] Address comments Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 65 +++++++++++++++------------ vllm/compilation/collective_fusion.py | 5 +++ vllm/config/compilation.py | 1 - 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 68389ccfbe14..03f31df1ece7 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -1,19 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from torch import fx as fx from torch import nn from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.inductor_pass import ( + CustomGraphPass, + InductorPass, + get_pass_context, +) from vllm.config import ( - CompilationConfig, - CompilationLevel, VllmConfig, set_current_vllm_config, ) +from vllm.config.compilation import CompilationConfig, CompilationMode +from vllm.config.scheduler import SchedulerConfig from vllm.forward_context import set_forward_context -from vllm.utils import direct_register_custom_op # create a library to hold the custom op silly_lib = Library("silly", "FRAGMENT") # noqa @@ -22,29 +27,6 @@ MLP_SIZE = 128 -def silly_attention( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor -) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor -) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class 
TestModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: @@ -67,12 +49,37 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]) model(torch.randn(batch_size, MLP_SIZE).cuda()) +class PostGradPassManagerCheckRanges(CustomGraphPass): + def __init__(self, ranges: list[tuple[int, int]]): + self.ranges = ranges + + def __call__(self, graph: fx.Graph): + compile_range = get_pass_context().compile_range + assert compile_range in self.ranges, ( + f"Compile range {compile_range} not in {self.ranges}" + ) + + def uuid(self) -> str: + state = { + "ranges": self.ranges, + } + return InductorPass.hash_dict(state) + + def test_compile_ranges(): vllm_config = VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8, 32], - ) + ), + inductor_compile_config={ + "post_grad_custom_post_pass": PostGradPassManagerCheckRanges( + [(1, 8), (8, 32), (32, 2049)] + ) + }, ) with set_current_vllm_config(vllm_config): @@ -82,7 +89,7 @@ def test_compile_ranges(): with compilation_counter.expect( num_graphs_seen=1, num_piecewise_graphs_seen=1, - num_backend_compilations=4, + num_backend_compilations=3, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 9c20db07c267..aaf53c6e5768 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1109,6 +1109,11 @@ def __init__(self, config: VllmConfig): self.max_token_num = min( self.max_token_num, config.scheduler_config.max_num_batched_tokens ) + logger.debug_once( + f"Flashinfer max size: {max_size // (1024 * 1024)} MB" + f", Maximal number of tokens: {self.max_token_num}", + scope="global", + ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index fa728c23d145..6e50493a770c 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -143,7 +143,6 @@ def flashinfer_max_size(self, world_size: int) -> int | None: if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) max_size_bytes = int(max_size_mb * MiB) if max_size_mb is not None else None - logger.debug_once(f"flashinfer_max_size: {max_size_bytes}", scope="global") return max_size_bytes @staticmethod From df22202272995c4a9c99f1ae7c562416d9620e53 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 11:25:17 -0500 Subject: [PATCH 14/49] Update benchmark Signed-off-by: ilmarkov --- benchmarks/kernels/benchmark_fused_collective.py | 16 ++++++++++++---- vllm/config/compilation.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index cec134ff9138..d7fa0580a3e7 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -410,6 +410,7 @@ def run_benchmarks( use_residual: bool, allreduce_params: FlashInferFusedAllReduceParams | None, quant_modes: set[str], + no_oneshot: bool, ): """Run all benchmarks for given configuration. 
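# Reduced, self-contained sketch of the benchmark switch added in this patch:
# --no-oneshot drops the oneshot variant from the sweep that run_benchmarks()
# iterates over. Only the flag name and the use_oneshot_options expression are
# taken from the patch; the rest of the wiring here is illustrative.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-oneshot", action="store_true", help="Skip oneshot benchmarks")
args = parser.parse_args(["--no-oneshot"])
use_oneshot_options = [False] if args.no_oneshot else [True, False]
print(use_oneshot_options)  # -> [False]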
@@ -431,6 +432,7 @@ def run_benchmarks( rms_eps = 1e-6 results = {} vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + use_oneshot_options = [False] if no_oneshot else [True, False] # Create RMSNorm and QuantFP8 layers once for native benchmarks @@ -476,7 +478,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -560,7 +562,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -645,7 +647,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -901,7 +903,7 @@ def save_results_to_file( try: markdown_content = format_results_markdown(all_results, world_size, args) - with open(output_path, "w") as f: + with open(output_path, "a") as f: f.write(markdown_content) except Exception as e: @@ -960,6 +962,12 @@ def main(): """, ) + parser.add_argument( + "--no-oneshot", + action="store_true", + help="Skip oneshot benchmarks", + ) + args = parser.parse_args() # Check if running with torchrun (required for collective operations) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 6e50493a770c..6f35673856df 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -923,7 +923,7 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - split_points = self.compile_ranges_split_points + split_points = set(self.compile_ranges_split_points) compile_ranges = [] for i, s in enumerate(split_points): if i == 0: From a21de2baef2202f2610788027c904f9b377752e9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 16:32:59 +0000 Subject: [PATCH 15/49] Fix Signed-off-by: ilmarkov --- benchmarks/kernels/benchmark_fused_collective.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index d7fa0580a3e7..99213d0c7cc2 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -1076,6 +1076,7 @@ def main(): use_residual, allreduce_params, quant_modes=quant_modes, + no_oneshot=args.no_oneshot, ) # Store results for markdown export From 6766e4f7da7914d7b1a24e6d760f56e181d5fbaa Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 17:15:46 -0500 Subject: [PATCH 16/49] Update fakify for compile sizes Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 9 ++++++++- vllm/config/compilation.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index ad5b49f28550..fe35aaa9e4ae 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -122,14 +122,21 @@ def fakify_args(self, args: list[Any]) -> 
list[Any]: return fake_example_inputs def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: + is_compile_size = lambda range: range[0] == range[1] if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) # args are real arguments + # fakify for range, real args for concrete size + args = ( + self.fakify_args(args) + if not is_compile_size(range_entry.compile_range) + else args + ) range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, - self.fakify_args(args), + args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 740b970669ed..67cd974a13e7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -947,7 +947,7 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - split_points = set(self.compile_ranges_split_points) + split_points = sorted(set(self.compile_ranges_split_points)) compile_ranges = [] for i, s in enumerate(split_points): if i == 0: From af87d7a7996dc857933ce38b8be3badbed95a935 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 6 Nov 2025 09:59:37 -0500 Subject: [PATCH 17/49] Linter fix Signed-off-by: ilmarkov --- vllm/config/compilation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 67cd974a13e7..3a3fdd7f295d 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -947,6 +947,8 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" + if self.compile_ranges_split_points is None: + return [] split_points = sorted(set(self.compile_ranges_split_points)) compile_ranges = [] for i, s in enumerate(split_points): From b4c1b1d66d6ce3288c65c57251d0492f2e9f475b Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 10 Nov 2025 12:31:48 +0000 Subject: [PATCH 18/49] Address the review Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 50 +++++++++++++----------- vllm/compilation/backends.py | 12 +++--- vllm/compilation/collective_fusion.py | 9 +++-- vllm/compilation/compiler_interface.py | 21 +++++----- vllm/compilation/inductor_pass.py | 9 +++-- vllm/compilation/pass_manager.py | 4 +- vllm/compilation/piecewise_backend.py | 27 ++++++------- vllm/compilation/sequence_parallelism.py | 7 ++-- vllm/config/compilation.py | 8 ++-- vllm/config/utils.py | 36 ++++++++++++++++- vllm/config/vllm.py | 6 ++- vllm/v1/worker/gpu_worker.py | 19 ++++++++- 12 files changed, 137 insertions(+), 71 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 03f31df1ece7..564690f18192 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -3,12 +3,11 @@ import torch from torch import fx as fx from torch import nn -from torch.library import Library +import tests.compile.silly_attention # noqa from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.compilation.inductor_pass import ( - CustomGraphPass, InductorPass, get_pass_context, ) @@ -18,11 +17,9 @@ ) from vllm.config.compilation import CompilationConfig, CompilationMode from vllm.config.scheduler import SchedulerConfig +from vllm.config.utils import 
Range from vllm.forward_context import set_forward_context -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - BATCH_SIZE = 64 MLP_SIZE = 128 @@ -49,24 +46,34 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]) model(torch.randn(batch_size, MLP_SIZE).cuda()) -class PostGradPassManagerCheckRanges(CustomGraphPass): - def __init__(self, ranges: list[tuple[int, int]]): +class PostGradPassManagerCheckRanges(InductorPass): + def __init__(self, ranges: list[Range]): self.ranges = ranges + self.num_calls = 0 def __call__(self, graph: fx.Graph): compile_range = get_pass_context().compile_range assert compile_range in self.ranges, ( f"Compile range {compile_range} not in {self.ranges}" ) + self.num_calls += 1 def uuid(self) -> str: state = { - "ranges": self.ranges, + "ranges": [str(range) for range in self.ranges], + "current_compile_range": str(get_pass_context().compile_range), } return InductorPass.hash_dict(state) def test_compile_ranges(): + post_grad_pass_manager = PostGradPassManagerCheckRanges( + [ + Range(start=1, end=8), + Range(start=8, end=32), + Range(start=32, end=8193), + ] + ) vllm_config = VllmConfig( scheduler_config=SchedulerConfig( max_num_batched_tokens=8192, @@ -74,22 +81,21 @@ def test_compile_ranges(): compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8, 32], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_pass_manager + }, ), - inductor_compile_config={ - "post_grad_custom_post_pass": PostGradPassManagerCheckRanges( - [(1, 8), (8, 32), (32, 2049)] - ) - }, ) with set_current_vllm_config(vllm_config): model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() - batch_sizes = [1, 16, 48] - # A has support_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=1, - num_backend_compilations=3, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ): - run_model(vllm_config, model, batch_sizes) + batch_sizes = [1, 16, 48] + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=3, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) + assert post_grad_pass_manager.num_calls == 3 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7a1d851ebe42..0d7ef88c8e6a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -22,6 +22,7 @@ resolve_defined_ops, ) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig +from vllm.config.utils import Range from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -83,7 +84,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[tuple[int, int] | None, int, str], Any] = dict() + self.cache: dict[tuple[Range | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -92,7 +93,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, compile_range: tuple[int, int] | None = None): + def compile_context(self, compile_range: Range | None = None): """Provide compilation context for the duration of 
compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" @@ -152,7 +153,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable | None: if (compile_range, graph_index, self.compiler.name) not in self.cache: return None @@ -187,7 +188,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -206,6 +207,7 @@ def compile( # there can be multiple graphs due to piecewise compilation. now = time.time() elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " @@ -231,7 +233,7 @@ def compile( if compile_range is None: maybe_key += "dynamic_shape" else: - maybe_key += f"{compile_range[0]}_{compile_range[1]}" + maybe_key += f"{compile_range.start}_{compile_range.end}" maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index dbe17f984808..81e881373e45 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch.distributed._symmetric_memory import enable_symm_mem_for_group from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -431,7 +432,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. 
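# Self-contained sketch of the applicability rule used here, with a stand-in
# for the Range helper added in this patch (vllm/config/utils.py); the
# full-graph early-return branch is omitted.
from dataclasses import dataclass


@dataclass(frozen=True)
class _Range:
    start: int  # inclusive
    end: int  # exclusive

    def is_single_size(self) -> bool:
        return self.start == self.end


def sp_style_pass_applicable(compile_range: "_Range | None", tp_size: int) -> bool:
    # Sequence-parallel style fusions only fire when compiling for one exact
    # token count that splits evenly across the tensor-parallel ranks.
    return (
        compile_range is not None
        and compile_range.is_single_size()
        and compile_range.end % tp_size == 0
    )


assert sp_style_pass_applicable(_Range(64, 64), tp_size=8)
assert not sp_style_pass_applicable(_Range(1, 64), tp_size=8)
assert not sp_style_pass_applicable(None, tp_size=8)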
@@ -442,7 +443,7 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool return True tp_size = get_tensor_model_parallel_world_size() return compile_range is not None and ( - compile_range[0] == compile_range[1] and compile_range[1] % tp_size == 0 + compile_range.is_single_size() and compile_range.end % tp_size == 0 ) @VllmInductorPass.time_and_log @@ -1188,10 +1189,10 @@ def register_patterns(self): self.disabled = False - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: if compile_range is None: return False - return compile_range[1] - 1 <= self.max_token_num + return compile_range.end - 1 <= self.max_token_num @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 6124a5428f6c..b95067aba191 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,6 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -63,7 +64,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ @@ -99,7 +100,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: """ Load the compiled function from the handle. 
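# Rough sketch of the dynamic_shapes selection in the hunks that follow: an
# exact batch size may specialize on the example inputs, while a genuine range
# keeps shapes symbolic. _Range stands in for the Range helper added in this
# patch; treating a missing range as "from_graph" is an assumption here.
class _Range:
    def __init__(self, start: int, end: int) -> None:
        self.start, self.end = start, end

    def is_single_size(self) -> bool:
        return self.start == self.end


def pick_dynamic_shapes(compile_range: "_Range | None") -> str:
    if compile_range is not None and compile_range.is_single_size():
        return "from_example_inputs"
    return "from_graph"


assert pick_dynamic_shapes(_Range(16, 16)) == "from_example_inputs"
assert pick_dynamic_shapes(_Range(8, 32)) == "from_graph"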
@@ -213,7 +214,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -223,8 +224,8 @@ def compile( set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(compile_range, tuple): - if compile_range[0] == compile_range[1]: + if compile_range is not None: + if compile_range.is_single_size(): dynamic_shapes = "from_example_inputs" else: dynamic_shapes = "from_graph" @@ -254,7 +255,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -318,7 +319,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -515,7 +516,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -612,7 +613,7 @@ def metrics_context(self) -> contextlib.AbstractContextManager: def set_inductor_config(config, compile_range): - if isinstance(compile_range, tuple) and compile_range[0] == compile_range[1]: + if compile_range is not None and compile_range.is_single_size(): # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE @@ -633,7 +634,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 599fa776b6c0..008eba4629a3 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,6 +14,7 @@ from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): @@ -28,8 +29,8 @@ class PassContext: - def __init__(self, compile_range: tuple[int, int] | None): - self.compile_range = compile_range + def __init__(self, compile_range: Range | None): + self.compile_range: Range | None = compile_range def get_pass_context() -> PassContext: @@ -39,7 +40,7 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(compile_range: tuple[int, int] | None): +def pass_context(compile_range: Range | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. 
""" @@ -96,7 +97,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: tuple[int, int] | None): + def is_applicable_for_range(self, compile_range: Range | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 5984f968da35..4664d0d9aefd 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -127,6 +127,8 @@ def uuid(self): for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) - state["compile_range"] = get_pass_context().compile_range + compile_range = get_pass_context().compile_range + if compile_range is not None: + state["compile_range"] = str(compile_range) return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index fe35aaa9e4ae..10844b69c455 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -10,6 +10,7 @@ from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.logger import init_logger logger = init_logger(__name__) @@ -17,7 +18,7 @@ @dataclasses.dataclass class RangeEntry: - compile_range: tuple[int, int] + compile_range: Range compiled: bool = False runnable: Callable = None # type: ignore @@ -61,12 +62,6 @@ def __init__( log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" logger.debug_once(log_string) - self.is_in_range = ( - lambda x, range: range[0] <= x < range[1] - if range[0] < range[1] - else x == range[0] - ) - self.first_run_finished = False self.sym_shape_indices = sym_shape_indices @@ -75,15 +70,15 @@ def __init__( # self.concrete_size_entries: dict[int, RangeEntry] = {} # the entries for ranges that we need to either - self.range_entries: dict[tuple[int, int], RangeEntry] = {} + self.range_entries: dict[Range, RangeEntry] = {} # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) + self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) # We only keep compilation management inside this class directly. 
for size in self.compile_sizes: - range = (size, size) + range = Range(start=size, end=size) self.range_entries[range] = RangeEntry( compile_range=range, ) @@ -122,7 +117,6 @@ def fakify_args(self, args: list[Any]) -> list[Any]: return fake_example_inputs def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: - is_compile_size = lambda range: range[0] == range[1] if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -131,7 +125,7 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: # fakify for range, real args for concrete size args = ( self.fakify_args(args) - if not is_compile_size(range_entry.compile_range) + if not range_entry.compile_range.is_single_size() else args ) range_entry.runnable = self.vllm_backend.compiler_manager.compile( @@ -158,13 +152,18 @@ def __call__(self, *args) -> Any: return range_entry.runnable(*args) runtime_shape = args[self.sym_shape_indices[0]] + # First we try to find the range entry for the concrete compile size + # If not found, we search for the range entry + # that contains the runtime shape. range_found = False if runtime_shape in self.compile_sizes: - range_entry = self.range_entries[(runtime_shape, runtime_shape)] + range_entry = self.range_entries[ + Range(start=runtime_shape, end=runtime_shape) + ] range_found = True else: for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): + if range.contains(runtime_shape): range_entry = self.range_entries[range] range_found = True break diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index cf47adb4670a..6a5ee5a0efb7 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -482,7 +483,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. 
However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -504,8 +505,8 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool tp_size = get_tensor_model_parallel_world_size() return ( compile_range is not None - and (compile_range[0] == compile_range[1]) - and (compile_range[1] % tp_size == 0) + and (compile_range.is_single_size()) + and (compile_range.end % tp_size == 0) ) @VllmInductorPass.time_and_log diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 2ae93c59ddfb..298fe4242a83 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -14,7 +14,7 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config.utils import config +from vllm.config.utils import Range, config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -945,7 +945,7 @@ def custom_op_log_check(self): op, ) - def get_compile_ranges(self) -> list[tuple[int, int]]: + def get_compile_ranges(self) -> list[Range]: """Get the compile ranges for the compilation config.""" if self.compile_ranges_split_points is None: return [] @@ -953,7 +953,7 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges = [] for i, s in enumerate(split_points): if i == 0: - compile_ranges.append((1, s)) + compile_ranges.append(Range(start=1, end=s)) else: - compile_ranges.append((split_points[i - 1], s)) + compile_ranges.append(Range(start=split_points[i - 1], end=s)) return compile_ranges diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 7e0878d96bbd..7270caf02740 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -6,7 +6,7 @@ import inspect import textwrap from collections.abc import Iterable -from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -176,3 +176,37 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: ) processed_overrides[field_name] = value return replace(config, **processed_overrides) + + +@dataclass +class Range: + """ + A range of numbers. + Inclusive of start, exclusive of end. 
+ """ + + start: int + end: int + + def is_single_size(self) -> bool: + return self.start == self.end + + def contains(self, size: int) -> bool: + # Inclusive of start, exclusive of end + if self.is_single_size(): + return size == self.start + return self.start <= size < self.end + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Range): + return False + return self.start == other.start and self.end == other.end + + def __hash__(self) -> int: + return hash((self.start, self.end)) + + def __str__(self) -> str: + return f"(start={self.start}, end={self.end})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 43a3b51b3a0a..a217b3c48f81 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -889,7 +889,11 @@ def _set_compile_ranges(self): # We add 1 because the bounds checks in the compiler are # exclusive and we want to include the max_token_num in the # compile range - computed_compile_ranges_split_points.append(max_token_num + 1) + if ( + max_num_batched_tokens is not None + and max_token_num < max_num_batched_tokens + ): + computed_compile_ranges_split_points.append(max_token_num + 1) if compilation_config.compile_ranges_split_points is not None: for x in compilation_config.compile_ranges_split_points: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f13ff4e726bd..42f9bdeab97e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,7 +14,7 @@ import torch.nn as nn import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -398,12 +398,27 @@ def compile_or_warm_up_model(self) -> None: # but users still want to compile for better performance, # e.g. for the max-num-batched token size in chunked prefill. warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() - if not self.model_config.enforce_eager: + + if ( + not self.model_config.enforce_eager + or self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + ): warmup_sizes = [ x for x in warmup_sizes if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the start of the range to ensure compilation/warmup. 
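# Illustrative sketch (hypothetical sizes) of the range-warmup logic added
# below: the start of any compile range not already covered by a warmup or
# cudagraph capture size is appended, so every range gets compiled at least
# once before serving.
def _uncovered_range_starts(compile_ranges, sizes):
    # compile_ranges given as (start, end) pairs, inclusive start / exclusive end
    return [s for s, e in compile_ranges if not any(s <= x < e for x in sizes)]


# With capture sizes {1, 2, 4, 8, 16} and ranges [1, 8), [8, 32), [32, 8193),
# only the last range is uncovered, so 32 would be added to warmup_sizes:
assert _uncovered_range_starts([(1, 8), (8, 32), (32, 8193)], {1, 2, 4, 8, 16}) == [32]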
+ all_sizes = set(self.vllm_config.compilation_config.cudagraph_capture_sizes) + all_sizes.update(warmup_sizes) + for compile_range in compile_ranges: + if not any(compile_range.contains(x) for x in all_sizes): + warmup_sizes.append(compile_range.start) + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) From f080a83511511a9c0a222451a752a1623aec095d Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 10 Nov 2025 17:20:53 +0100 Subject: [PATCH 19/49] [RFC][ROCm][AITER] Keep all AITER kernels in `_aiter_ops` class like `_custom_ops` and `_ipex_ops` (#24490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: vllmellm Co-authored-by: Luka Govedič --- docs/design/moe_kernel_features.md | 2 +- tests/kernels/moe/test_moe.py | 11 +- .../model_executor/test_enabled_custom_ops.py | 41 +- vllm/_aiter_ops.py | 941 ++++++++++++++++++ vllm/attention/ops/rocm_aiter_mla.py | 105 -- vllm/envs.py | 8 +- .../layers/fused_moe/fused_moe.py | 15 +- vllm/model_executor/layers/fused_moe/layer.py | 83 +- .../layers/fused_moe/rocm_aiter_fused_moe.py | 329 +----- vllm/model_executor/layers/layernorm.py | 90 +- .../compressed_tensors_moe.py | 12 +- .../schemes/compressed_tensors_w8a8_fp8.py | 4 +- .../model_executor/layers/quantization/fp8.py | 16 +- .../quantization/kernels/scaled_mm/aiter.py | 48 +- .../layers/quantization/quark/quark_moe.py | 47 +- .../quark/schemes/quark_ocp_mx.py | 7 + .../layers/quantization/utils/fp8_utils.py | 124 +-- .../layers/quantization/utils/w8a8_utils.py | 2 +- .../layers/rotary_embedding/base.py | 13 +- .../rotary_embedding/deepseek_scaling_rope.py | 9 + .../rotary_embedding/rocm_aiter_rope_ops.py | 94 -- vllm/model_executor/models/deepseek_v2.py | 27 +- vllm/platforms/rocm.py | 27 +- vllm/v1/attention/backends/mla/common.py | 55 +- .../attention/backends/mla/rocm_aiter_mla.py | 9 +- 25 files changed, 1194 insertions(+), 925 deletions(-) create mode 100644 vllm/_aiter_ops.py delete mode 100644 vllm/attention/ops/rocm_aiter_mla.py delete mode 100644 vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 633e23eea33e..ee224e6922fb 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | -| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] | +| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | | naive batched4 | batched | int8,
fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 014df1fa111f..c27cf2468ede 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -6,6 +6,8 @@ """ import functools +import importlib +import sys from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -20,6 +22,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.moe.utils import fused_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context @@ -412,14 +415,12 @@ def test_mixtral_moe( huggingface.""" # clear the cache before every test - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) + # Force reload aiter_ops to pick up the new environment variables. + if "rocm_aiter_ops" in sys.modules: + importlib.reload(rocm_aiter_ops) - is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 41419553aa83..9121284de85b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -4,6 +4,7 @@ import pytest import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import ( @@ -15,9 +16,6 @@ dispatch_topk_func, vllm_topk_softmax, ) -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, -) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_rocm_rmsnorm_func, @@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - topk_func = dispatch_topk_func() - is_rocm_aiter_moe_enabled.cache_clear() - if current_platform.is_rocm() and int(use_rocm_aiter): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax, - ) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_topk_dispatch(use_rocm_aiter: bool): + topk_func = dispatch_topk_func(use_rocm_aiter) - assert topk_func == rocm_aiter_topk_softmax + if current_platform.is_rocm() and use_rocm_aiter: + assert topk_func == rocm_aiter_ops.topk_softmax else: assert topk_func == vllm_topk_softmax @pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) +@pytest.mark.parametrize("use_rocm_aiter", [True, False]) @pytest.mark.skipif( not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" ) def test_rms_norm_dispatch( - 
add_residual: bool, - dtype: torch.dtype, - use_rocm_aiter: str, - use_rocm_aiter_norm: str, - monkeypatch, + add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool ): - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) - rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) + rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter) should_use_rocm_aiter = ( current_platform.is_rocm() - and int(use_rocm_aiter) - and int(use_rocm_aiter_norm) + and use_rocm_aiter and dtype in RMS_NORM_SUPPORTED_DTYPES ) if add_residual and should_use_rocm_aiter: - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add elif should_use_rocm_aiter: - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm + assert rms_norm_func == rocm_aiter_ops.rms_norm elif add_residual: assert rms_norm_func == fused_add_rms_norm else: diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py new file mode 100644 index 000000000000..9a4b5f3399be --- /dev/null +++ b/vllm/_aiter_ops.py @@ -0,0 +1,941 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from collections.abc import Callable + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer + + +def is_aiter_found() -> bool: + from importlib.util import find_spec + + return find_spec("aiter") is not None + + +# `find_spec` is not torch.compile compatible. +# In cases where aiter availability might have +# been checked in forward passes that are torch compiled. +# we keep this global outside to not cause torch compile breaks. +IS_AITER_FOUND = is_aiter_found() + + +def if_aiter_supported(func: Callable) -> Callable: + """Decorator that only executes the function if + ROCm AITER package is supported on gfx9 archs. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # checks the platform, device arch and aiter library existance. 
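+        #
+        # Illustrative sketch only (the decorated helper below is
+        # hypothetical and not part of this patch):
+        #
+        #     @if_aiter_supported
+        #     def my_flag() -> bool:
+        #         return True
+        #
+        # my_flag() would return True only on ROCm gfx9 hosts where the
+        # aiter package is importable; on any other platform the wrapper
+        # short-circuits and returns None without calling the function.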
+ + from vllm.platforms.rocm import on_gfx9 + + if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND: + return func(*args, **kwargs) + else: + # Return None or do nothing if not supported + return None + + return wrapper + + +def _rocm_aiter_fused_moe_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, +) -> torch.Tensor: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + activation = ActivationType(activation_method) + quant_type = QuantType(quant_method) + + return fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation, + quant_type, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + +def _rocm_aiter_fused_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _rocm_aiter_asm_moe_tkw1_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, +) -> torch.Tensor: + from aiter import ActivationType + from aiter.fused_moe_bf16_asm import asm_moe_tkw1 + + activation = ActivationType(activation_method) + + return asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation, + ) + + +def _rocm_aiter_asm_moe_tkw1_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _rocm_aiter_topk_softmax_impl( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: + from aiter import topk_softmax + + topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) + + +def _rocm_aiter_topk_softmax_fake( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + 
token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: + pass + + +def _rocm_aiter_biased_grouped_topk_impl( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + from aiter import biased_grouped_topk + + biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) + + +def _rocm_aiter_biased_grouped_topk_fake( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + pass + + +def _rocm_aiter_grouped_topk_impl( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + is_softmax = scoring_func == "softmax" + from aiter import grouped_topk + + grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + is_softmax, + routed_scaling_factor, + ) + + +def _rocm_aiter_grouped_topk_fake( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + pass + + +def _rocm_aiter_mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + from aiter.mla import mla_decode_fwd + + mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) + + +def _rocm_aiter_mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +def _rocm_aiter_gemm_w8a8_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter import gemm_a8w8_CK + + # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects + # a to be [M, K] + # b to be [N, K] + # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) + + +def _rocm_aiter_gemm_w8a8_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_gemm_w8a8_blockscale_impl( + A: 
torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + +def _rocm_aiter_gemm_w8a8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_rms_norm_impl( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + from aiter import rms_norm + + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + + return rms_norm(x, weight, variance_epsilon) + + +def _rocm_aiter_rms_norm_fake( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + return torch.empty_like(x) + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter import rmsnorm2d_fwd_with_add + + residual_out = torch.empty_like(residual) + output = torch.empty_like(x) + rmsnorm2d_fwd_with_add( + output, # output + x, # input + residual, # residual input + residual_out, # residual output + weight, + variance_epsilon, + ) + return output, residual_out + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(residual) + + +# Global flag to ensure ops are registered only once +_OPS_REGISTERED = False + + +class rocm_aiter_ops: + _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER + _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM + _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE + _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA + _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA + _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE + _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + + @classmethod + @if_aiter_supported + def is_enabled(cls) -> bool: + """Verifies device specs and availability of aiter main env variable.""" + return cls._AITER_ENABLED + + @classmethod + @if_aiter_supported + def is_linear_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._LINEAR_ENABLED + + @classmethod + @if_aiter_supported + def is_linear_fp8_enaled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() + + @classmethod + @if_aiter_supported + def is_rmsnorm_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._RMSNORM_ENABLED + + @classmethod + @if_aiter_supported + def is_fused_moe_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and 
cls._FMOE_ENABLED + + @classmethod + @if_aiter_supported + def is_fusion_moe_shared_experts_enabled(cls) -> bool: + return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED + + @classmethod + @if_aiter_supported + def is_mla_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._MLA_ENABLED + + @classmethod + @if_aiter_supported + def is_mha_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._MHA_ENABLED + + @classmethod + @if_aiter_supported + def is_pa_attn_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED + + @classmethod + @if_aiter_supported + def is_triton_unified_attn_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED + + @classmethod + @if_aiter_supported + def is_fp8bmm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + + @classmethod + @if_aiter_supported + def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM + + @classmethod + @if_aiter_supported + def is_triton_rotary_embed_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED + + @staticmethod + @if_aiter_supported + def register_ops_once() -> None: + global _OPS_REGISTERED + if not _OPS_REGISTERED: + tags = ( + tuple() + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ) + + # register all the custom ops here + direct_register_custom_op( + op_name="rocm_aiter_asm_moe_tkw1", + op_func=_rocm_aiter_asm_moe_tkw1_impl, + mutates_args=[], + fake_impl=_rocm_aiter_asm_moe_tkw1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fused_moe", + op_func=_rocm_aiter_fused_moe_impl, + mutates_args=[], + fake_impl=_rocm_aiter_fused_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_topk_softmax", + op_func=_rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=_rocm_aiter_topk_softmax_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_biased_grouped_topk", + op_func=_rocm_aiter_biased_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=_rocm_aiter_biased_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_grouped_topk", + op_func=_rocm_aiter_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=_rocm_aiter_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_mla_decode_fwd", + op_func=_rocm_aiter_mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=_rocm_aiter_mla_decode_fwd_fake, + tags=tags, + ) + + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8", + op_func=_rocm_aiter_gemm_w8a8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_w8a8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8_blockscale", + op_func=_rocm_aiter_gemm_w8a8_blockscale_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake, + 
dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=_rocm_aiter_rms_norm_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rms_norm_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add", + op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, + dispatch_key=current_platform.dispatch_key, + ) + + _OPS_REGISTERED = True + + @staticmethod + def rms_norm2d_with_add( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add( + x, residual, weight, variance_epsilon + ) + + @staticmethod + def rms_norm( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) + + @staticmethod + def gemm_w8a8( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype) + + @staticmethod + def gemm_w8a8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( + A, B, As, Bs, output_dtype + ) + + @staticmethod + def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation_method, + quant_method, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + @staticmethod + def asm_moe_tkw1( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale, + fc2_scale, + fc1_smooth_scale, + fc2_smooth_scale, + a16, + per_tensor_quant_scale, + expert_mask, + activation_method, + ) + + @staticmethod + def topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, + ) -> tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) + return topk_weights, topk_indices + + @staticmethod + def biased_grouped_topk( + gating_output: torch.Tensor, + correction_bias: 
torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, + ) -> None: + torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) + + @staticmethod + def grouped_topk( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + ) -> None: + torch.ops.vllm.rocm_aiter_grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + scoring_func, + routed_scaling_factor, + ) + + @staticmethod + def mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + logit_cap: float = 0.0, + ): + torch.ops.vllm.rocm_aiter_mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) + + @staticmethod + def triton_fp4_gemm_dynamic_qaunt( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype | None = torch.bfloat16, + x_scales: torch.Tensor | None = None, + ) -> torch.Tensor: + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + if x_scales is None: + x_q, x_s = dynamic_mxfp4_quant(x) + else: + x_q = x + x_s = x_scales + + y = torch.empty( + x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) + return y + + @staticmethod + def triton_rotary_embed( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_size: int, + rotary_dim: int, + is_neox_style: bool, + ): + from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace + + num_tokens = positions.numel() + cos, sin = cos_sin_cache.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + rotate_style = 0 if is_neox_style else 1 + + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + query_ = query[..., :rotary_dim] + key_ = key[..., :rotary_dim] + positions = positions.view(*query.shape[:1]) + rope_cached_thd_positions_2c_fwd_inplace( + positions, + sin, + cos, + query_, + key_, + rotate_style, + reuse_freqs_front_part=True, + is_nope_first=False, + ) + query = query.view(query_shape) + key = key.view(key_shape) + + @staticmethod + def triton_fp8_bmm( + X: torch.Tensor, + WQ: torch.Tensor, + w_scale: torch.Tensor, + group_size: int = 128, + bias: torch.Tensor | None = None, + dtype: torch.dtype | None = torch.bfloat16, + splitK: int | None = None, + YQ: torch.Tensor | None = None, + transpose_bm: bool | None = False, + config: dict | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, + ) + + return aiter_triton_fp8_bmm( + X, + WQ, 
+ w_scale, + group_size=group_size, + bias=bias, + dtype=dtype, + splitK=splitK, + YQ=YQ, + transpose_bm=transpose_bm, + config=config, + ) + + @staticmethod + def triton_gemm_a8w8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + @staticmethod + def per_1x128_fp8_quant( + input_2d: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + """Only applies quantization method for fp8 data type only.""" + from aiter import QuantType, dtypes, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) + return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8) + + @staticmethod + def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool: + return (n, k) in [ + (1024, 8192), + (2112, 7168), + (3072, 1536), + (32768, 8192), + (4096, 7168), + (4608, 7168), + (512, 7168), + (7168, 2048), + (7168, 256), + (8192, 1024), + (8192, 32768), + ] + + @staticmethod + def shuffle_weight( + self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) + ) -> torch.Tensor: + from aiter.ops.shuffle import shuffle_weight + + return shuffle_weight(tensor, layout=layout) + + @staticmethod + def shuffle_weights( + *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) + ) -> tuple[torch.Tensor, ...]: + """ + Applies shuffle_weight function from AITER to each + input tensor and returns them. + + Rearranges (shuffles) the input tensor/s + into a specified block layout for optimized computation. + + Args: + *tensors: Variable number of torch.Tensor objects. + layout: A pair of integers specifying the block sizes used to divide + the tensors during shuffling. Default is (16, 16). + + Returns: + A Tuple of shuffled tensors. 
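+
+        Example (a minimal sketch; ``w13`` and ``w2`` stand in for
+        existing MoE weight tensors and are not defined in this module):
+
+            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
+                w13, w2, layout=(16, 16)
+            )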
+ """ + from aiter.ops.shuffle import shuffle_weight + + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) + + +rocm_aiter_ops.register_ops_once() diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py deleted file mode 100644 index 6308f63cc4e7..000000000000 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - - -import torch - -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer - - -def get_aiter_mla_metadata( - max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device -) -> tuple[torch.Tensor, ...]: - paged_kv_indices = torch.zeros( - max_batch_size * max_block_per_batch, dtype=torch.int32, device=device - ) - paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device) - paged_kv_last_page_lens = torch.full( - (max_batch_size,), block_size, dtype=torch.int32 - ) - qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) - return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr - - -def aiter_mla_decode_fwd( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - sm_scale: float, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - logit_cap: float = 0.0, -): - torch.ops.vllm.rocm_aiter_mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - max_seqlen_qo, - kv_indptr, - kv_indices, - kv_last_page_lens, - sm_scale=sm_scale, - logit_cap=logit_cap, - ) - - -def mla_decode_fwd_impl( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - sm_scale: float = 1.0, - logit_cap: float = 0.0, -) -> None: - from aiter.mla import mla_decode_fwd - - mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - max_seqlen_qo, - sm_scale=sm_scale, - logit_cap=logit_cap, - ) - - -def mla_decode_fwd_fake( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - sm_scale: float = 1.0, - logit_cap: float = 0.0, -) -> None: - pass - - -if current_platform.is_rocm(): - if is_torch_equal_or_newer("2.7.0"): - tags = () - else: - tags = ((torch.Tag.needs_fixed_stride_order,),) - direct_register_custom_op( - op_name="rocm_aiter_mla_decode_fwd", - op_func=mla_decode_fwd_impl, - mutates_args=["o"], - fake_impl=mla_decode_fwd_fake, - tags=tags, - ) diff --git a/vllm/envs.py b/vllm/envs.py index 078e5c38f0f4..30c62e90e9fb 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -109,7 +109,7 @@ VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False - VLLM_ROCM_USE_TRITON_ROPE: bool = False + VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True @@ -926,8 +926,8 
@@ def get_vllm_port() -> int | None: ), # Whether to use aiter rope. # By default is disabled. - "VLLM_ROCM_USE_TRITON_ROPE": lambda: ( - os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1") + "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1") ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. @@ -1589,7 +1589,7 @@ def compute_hash() -> str: "VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MHA", "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", - "VLLM_ROCM_USE_TRITON_ROPE", + "VLLM_ROCM_USE_AITER_TRITON_ROPE", "VLLM_ROCM_USE_AITER_FP8BMM", "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ROCM_USE_AITER_TRITON_GEMM", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7ad3ce1397b3..2e042d85fcfc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -14,6 +14,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -55,8 +56,6 @@ from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer -from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled - logger = init_logger(__name__) @@ -1089,11 +1088,11 @@ def vllm_topk_softmax( return topk_weights, topk_indices -def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: - if is_rocm_aiter_moe_enabled(): - from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax - - return rocm_aiter_topk_softmax +def dispatch_topk_func( + use_rocm_aiter: bool = False, +) -> Callable[..., tuple[torch.Tensor, ...]]: + if use_rocm_aiter: + return rocm_aiter_ops.topk_softmax return vllm_topk_softmax @@ -1121,7 +1120,7 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - topk_func = dispatch_topk_func() + topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) topk_weights, topk_ids = topk_func( topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e69ead074c50..45b0f50a7997 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,6 +13,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import ( @@ -41,8 +42,6 @@ ) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, - is_rocm_aiter_fusion_shared_expert_enabled, - is_rocm_aiter_moe_enabled, ) from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.quantization.base_config import ( @@ -92,13 +91,11 @@ def _eplb_map_to_physical_and_record( return topk_ids eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record +from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + 
rocm_aiter_grouped_topk, +) -if is_rocm_aiter_moe_enabled(): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk_aiter, - ) -else: - from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas else: @@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -620,13 +618,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Padding the weight for better performance on ROCm layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - # Lazy import to avoid importing triton. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - shuffle_weights, - ) if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -1002,6 +996,7 @@ def determine_expert_map( global_num_experts: int, expert_placement_strategy: ExpertPlacementStrategy = "linear", num_fused_shared_experts: int = 0, + return_expert_mask: bool = False, ) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: """ Calculates how many experts should be assigned to each rank for EP and @@ -1064,7 +1059,7 @@ def determine_expert_map( ) expert_mask = None - if is_rocm_aiter_moe_enabled(): + if return_expert_mask: expert_mask = torch.ones( (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 ) @@ -1292,14 +1287,18 @@ def __init__( self.logical_replica_count: torch.Tensor | None = None # ROCm aiter shared experts fusion + self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + self.aiter_fmoe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + self.num_fused_shared_experts = ( n_shared_experts - if n_shared_experts is not None - and is_rocm_aiter_fusion_shared_expert_enabled() + if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled else 0 ) if ( - not is_rocm_aiter_fusion_shared_expert_enabled() + not self.aiter_fmoe_shared_expert_enabled and self.num_fused_shared_experts != 0 ): raise ValueError( @@ -1346,6 +1345,7 @@ def __init__( global_num_experts=self.global_num_experts, expert_placement_strategy=expert_placement_strategy, num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) @@ -1570,13 +1570,16 @@ def update_expert_map(self): ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) self.register_buffer("expert_mask", expert_mask) - self._init_aiter_shared_experts_topK_buffer( - vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size - ) + if self.aiter_fmoe_shared_expert_enabled: + self._init_aiter_shared_experts_topK_buffer( + 
vllm_config=get_current_vllm_config(), + dp_size=get_dp_group().world_size, + ) def _load_per_tensor_weight_scale( self, @@ -1753,20 +1756,19 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def _init_aiter_shared_experts_topK_buffer( self, vllm_config: VllmConfig, dp_size: int ): - if is_rocm_aiter_fusion_shared_expert_enabled(): - if self.num_fused_shared_experts > 0: - init_aiter_topK_meta_data( - n_routed_experts=self.global_num_experts, - n_shared_experts=self.num_fused_shared_experts, - top_k=self.top_k, - tp_rank=self.ep_rank if self.use_ep else self.tp_rank, - tp_size=self.ep_size if self.use_ep else self.tp_size, - shared_experts_score=1.0, - max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens - * dp_size, - is_EP=self.use_ep, - ) - self.local_num_experts += self.num_fused_shared_experts + if self.num_fused_shared_experts > 0: + init_aiter_topK_meta_data( + n_routed_experts=self.global_num_experts, + n_shared_experts=self.num_fused_shared_experts, + top_k=self.top_k, + tp_rank=self.ep_rank if self.use_ep else self.tp_rank, + tp_size=self.ep_size if self.use_ep else self.tp_size, + shared_experts_score=1.0, + max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens + * dp_size, + is_EP=self.use_ep, + ) + self.local_num_experts += self.num_fused_shared_experts @overload def weight_loader( @@ -2208,15 +2210,16 @@ def select_experts( elif use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - if is_rocm_aiter_moe_enabled(): - if not is_rocm_aiter_fusion_shared_expert_enabled(): + if rocm_aiter_ops.is_fused_moe_enabled(): + if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): assert num_fused_shared_experts == 0 grouped_topk_impl = partial( - grouped_topk_aiter, + rocm_aiter_grouped_topk, num_fused_shared_experts=num_fused_shared_experts, ) else: grouped_topk_impl = grouped_topk + topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states, gating_output=router_logits, @@ -2448,7 +2451,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map - if not is_rocm_aiter_moe_enabled() + if not self.rocm_aiter_fmoe_enabled else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, @@ -2612,7 +2615,7 @@ def forward_impl( use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map - if not is_rocm_aiter_moe_enabled() + if not self.rocm_aiter_fmoe_enabled else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index e18514ad43f6..8f05828d74f5 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,17 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import IntEnum -from functools import cache, lru_cache +from functools import lru_cache import torch -from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, ) -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op class 
QuantMethod(IntEnum): @@ -37,27 +35,6 @@ class ActivationMethod(IntEnum): GELU = 1 -@cache -def is_rocm_aiter_moe_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_MOE - and envs.VLLM_ROCM_USE_AITER - ) - - -@cache -def use_mxfp4_aiter_moe() -> bool: - return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER - - -@cache -def is_rocm_aiter_fusion_shared_expert_enabled() -> bool: - return ( - envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled() - ) - - aiter_topK_meta_data = None @@ -114,250 +91,6 @@ def init_aiter_topK_meta_data( aiter_topK_meta_data = (total_topk_weights, total_topk_ids) -def rocm_aiter_asm_moe_tkw1_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: torch.Tensor | None = None, - fc2_scale: torch.Tensor | None = None, - fc1_smooth_scale: torch.Tensor | None = None, - fc2_smooth_scale: torch.Tensor | None = None, - a16: bool = False, - per_tensor_quant_scale: torch.Tensor | None = None, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, -) -> torch.Tensor: - from aiter import ActivationType - from aiter.fused_moe_bf16_asm import asm_moe_tkw1 - - activation = ActivationType(activation_method) - - return asm_moe_tkw1( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - fc1_scale=fc1_scale, - fc2_scale=fc2_scale, - fc1_smooth_scale=fc1_smooth_scale, - fc2_smooth_scale=fc2_smooth_scale, - a16=a16, - per_tensor_quant_scale=per_tensor_quant_scale, - expert_mask=expert_mask, - activation=activation, - ) - - -def rocm_aiter_asm_moe_tkw1_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: torch.Tensor | None = None, - fc2_scale: torch.Tensor | None = None, - fc1_smooth_scale: torch.Tensor | None = None, - fc2_smooth_scale: torch.Tensor | None = None, - a16: bool = False, - per_tensor_quant_scale: torch.Tensor | None = None, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def rocm_aiter_topk_softmax_impl( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> None: - from aiter import topk_softmax - - topk_softmax( - topk_weights, topk_indices, token_expert_indices, gating_output, renormalize - ) - - -def rocm_aiter_topk_softmax_fake( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> None: - pass - - -def rocm_aiter_biased_grouped_topk_impl( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - from aiter import biased_grouped_topk - - biased_grouped_topk( - gating_output, - correction_bias, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - need_renorm, - routed_scaling_factor, - ) - - -def rocm_aiter_biased_grouped_topk_fake( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 
1.0, # mul to topk_weights -) -> None: - pass - - -def rocm_aiter_grouped_topk_impl( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - from aiter import grouped_topk - - grouped_topk( - gating_output, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - need_renorm, - scoring_func, - routed_scaling_factor, - ) - - -def rocm_aiter_grouped_topk_fake( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - pass - - -def rocm_aiter_fused_moe_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, - quant_method: int = QuantMethod.NO.value, - doweight_stage1: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, -) -> torch.Tensor: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - - activation = ActivationType(activation_method) - quant_type = QuantType(quant_method) - - return fused_moe( - hidden_states, - w1, - w2, - topk_weight, - topk_ids, - expert_mask, - activation, - quant_type, - doweight_stage1, - w1_scale, - w2_scale, - a1_scale, - a2_scale, - ) - - -def rocm_aiter_fused_moe_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, - quant_method: int = QuantMethod.NO.value, - doweight_stage1: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_asm_moe_tkw1", - op_func=rocm_aiter_asm_moe_tkw1_impl, - fake_impl=rocm_aiter_asm_moe_tkw1_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_fused_moe", - op_func=rocm_aiter_fused_moe_impl, - fake_impl=rocm_aiter_fused_moe_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_topk_softmax", - op_func=rocm_aiter_topk_softmax_impl, - mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], - fake_impl=rocm_aiter_topk_softmax_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_biased_grouped_topk", - op_func=rocm_aiter_biased_grouped_topk_impl, - mutates_args=["topk_weights", "topk_ids"], - fake_impl=rocm_aiter_biased_grouped_topk_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_grouped_topk", - op_func=rocm_aiter_grouped_topk_impl, - mutates_args=["topk_weights", "topk_ids"], - fake_impl=rocm_aiter_grouped_topk_fake, - ) - - def rocm_aiter_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk( ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] device = hidden_states.device - if is_rocm_aiter_fusion_shared_expert_enabled() and 
num_fused_shared_experts > 0: + if ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + and num_fused_shared_experts > 0 + ): assert aiter_topK_meta_data is not None, ( "AITER topK meta data is not initialized. " "Please ensure that init_aiter_topK_meta_data " @@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk( topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) if e_score_correction_bias is not None: - torch.ops.vllm.rocm_aiter_biased_grouped_topk( + rocm_aiter_ops.biased_grouped_topk( gating_output, e_score_correction_bias.to(gating_output.dtype), topk_weights, @@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk( ) else: assert scoring_func == "softmax" or scoring_func == "sigmoid" - torch.ops.vllm.rocm_aiter_grouped_topk( + rocm_aiter_ops.grouped_topk( gating_output, topk_weights, topk_ids, @@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk( routed_scaling_factor=routed_scaling_factor, ) - if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + if ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + and num_fused_shared_experts > 0 + ): return total_topk_weights, total_topk_ids return topk_weights, topk_ids @@ -464,7 +203,7 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) - return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + return rocm_aiter_ops.asm_moe_tkw1( hidden_states, w1, w2, @@ -482,7 +221,9 @@ def rocm_aiter_fused_experts( else: quant_method = QuantMethod.NO.value - + # quark moe for mxfp4 w_dtype + if quant_config.use_mxfp4_w4a16: + quant_method = QuantMethod.BLOCK_1X32.value # w8a8 block-scaled if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: assert not apply_router_weight_on_input, ( @@ -507,7 +248,7 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) - return torch.ops.vllm.rocm_aiter_fused_moe( + return rocm_aiter_ops.fused_moe( hidden_states, w1, w2, @@ -522,39 +263,3 @@ def rocm_aiter_fused_experts( a2_scale=quant_config.a2_scale, doweight_stage1=apply_router_weight_on_input, ) - - -def rocm_aiter_topk_softmax( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> tuple[torch.Tensor, ...]: - torch.ops.vllm.rocm_aiter_topk_softmax( - topk_weights, topk_indices, token_expert_indices, gating_output, renormalize - ) - return topk_weights, topk_indices - - -def shuffle_weights( - *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) -) -> tuple[torch.Tensor, ...]: - """ - Applies shuffle_weight function from AITER to each - input tensor and returns them. - - Rearranges (shuffles) the input tensor/s - into a specified block layout for optimized computation. - - Args: - *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the block sizes used to divide - the tensors during shuffling. Default is (16, 16). - - Returns: - A Tuple of shuffled tensors. 
- """ - from aiter.ops.shuffle import shuffle_weight - - return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a883ac81f41e..8cc374ac9155 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -6,18 +6,13 @@ import torch.nn as nn import torch.nn.functional as F -import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.batch_invariant import ( rms_norm_batch_invariant, vllm_is_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - - -def is_rocm_aiter_rmsnorm_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER def rms_norm( @@ -58,80 +53,34 @@ def fused_add_rms_norm( return x, residual -def rocm_aiter_rms_norm_impl( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +def poly_norm( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: - import aiter as rocm_aiter - - if x.dim() > 2: - x_original_shape = x.shape - x = x.reshape(-1, x_original_shape[-1]) - x = rocm_aiter.rms_norm(x, weight, variance_epsilon) - return x.reshape(x_original_shape) - - return rocm_aiter.rms_norm(x, weight, variance_epsilon) - + from vllm import _custom_ops as ops -def rocm_aiter_rmsnorm2d_fwd_with_add_impl( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - import aiter as rocm_aiter - - residual_out = torch.empty_like(residual) - output = torch.empty_like(x) - rocm_aiter.rmsnorm2d_fwd_with_add( - output, # output - x, # input - residual, # residual input - residual_out, # residual output + out = torch.empty_like(x) + ops.poly_norm( + out, + x, weight, + bias, variance_epsilon, ) - return output, residual_out - - -def rocm_aiter_rms_norm_fake( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float -) -> torch.Tensor: - return torch.empty_like(x) - - -def rocm_aiter_rmsnorm2d_fwd_with_add_fake( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(x), torch.empty_like(residual) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_rms_norm", - op_func=rocm_aiter_rms_norm_impl, - fake_impl=rocm_aiter_rms_norm_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_rmsnorm2d_fwd_with_add", - op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, - fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, - ) + return out -def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): - use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ +def dispatch_rocm_rmsnorm_func( + with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False +): + use_aiter = use_aiter and dtype in [ torch.float16, torch.bfloat16, ] if use_aiter and with_fused_add: - return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + return rocm_aiter_ops.rms_norm2d_with_add if use_aiter: - return torch.ops.vllm.rocm_aiter_rms_norm + return rocm_aiter_ops.rms_norm # fall back to CUDA implementation if with_fused_add: @@ -169,11 +118,14 @@ def __init__( self.weight = nn.Parameter(self.weight) if current_platform.is_rocm(): + aiter_rmsnorm_enabled = 
rocm_aiter_ops.is_rmsnorm_enabled() self.rocm_norm_func = dispatch_rocm_rmsnorm_func( - with_fused_add=False, dtype=weight_dtype + with_fused_add=False, + dtype=weight_dtype, + use_aiter=aiter_rmsnorm_enabled, ) self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( - with_fused_add=True, dtype=weight_dtype + with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled ) @staticmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d95d49eddfe3..d32ae6674ee6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -12,6 +12,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -582,11 +583,8 @@ def __init__( # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # cutlass path self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( @@ -829,12 +827,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - shuffle_weights, - ) - # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ee431c9148b8..6da136cbc8f6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -7,12 +7,12 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, - check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, @@ -61,7 +61,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) ) self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() if self.weight_block_size is not None: assert not self.is_static_input_scheme diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ce40645782e5..e4e1cbff712f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -12,6 +12,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( @@ -56,7 +57,6 @@ ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, - check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, @@ -369,7 +369,7 @@ def __init__(self, quant_config: Fp8Config): if vllm_is_batch_invariant(): self.use_marlin = False - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() self.use_deep_gemm = is_deep_gemm_supported() self.weight_block_size = self.quant_config.weight_block_size @@ -869,12 +869,8 @@ def create_weights( def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - shuffle_weights, - ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # TODO (rob): refactor block quant into separate class. if self.block_quant: @@ -916,7 +912,7 @@ def process_weights_after_loading(self, layer: Module) -> None: ) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -962,7 +958,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) @@ -1042,7 +1038,7 @@ def process_weights_after_loading(self, layer: Module) -> None: start += shard_size if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a19396a162bc..f5cd91469b78 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -4,54 +4,14 @@ import torch -import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig -def rocm_aiter_gemm_w8a8_impl( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - from aiter import gemm_a8w8_CK - - # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects - # a to be [M, K] - # b to be [N, K] - # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) - - -def rocm_aiter_gemm_w8a8_fake( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - m = A.shape[0] - n = B.shape[0] - Y = torch.empty(m, n, dtype=output_dtype, device=A.device) - return Y - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8", - op_func=rocm_aiter_gemm_w8a8_impl, - fake_impl=rocm_aiter_gemm_w8a8_fake, - ) - - class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -75,7 +35,7 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + "installed on ROCm.", ) # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled - if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): + if not (rocm_aiter_ops.is_linear_enabled()): return ( False, "AiterScaledMMLinearKernel is disabled. 
" @@ -157,6 +117,4 @@ def apply_weights( # a to be [M, K] # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return torch.ops.vllm.rocm_aiter_gemm_w8a8( - x_q, w_q.t(), x_s, w_s, bias, out_dtype - ) + return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index eca6b0cb1d8e..30772c3665b0 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, @@ -21,10 +22,6 @@ ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - use_mxfp4_aiter_moe, -) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin, ) @@ -122,7 +119,7 @@ def __init__( if current_platform.is_rocm(): self.use_marlin = False - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() def create_weights( self, @@ -309,12 +306,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - shuffle_weights, - ) - # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -470,13 +463,15 @@ def __init__( "not implemented. Please open an issue." ) + self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() + self.emulate = not current_platform.supports_mx() or not ( - use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" + self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" ) if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " - f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, " + f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, " f"ocp_mx_scheme={self.ocp_mx_scheme}) " "does not support native MXFP4/MXFP6 " "computation. 
Simulated weight dequantization and activation " @@ -656,28 +651,18 @@ def apply( ) if not self.emulate: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - - aiter_acts = { - ActivationType.No.name.lower(): ActivationType.No, - ActivationType.Silu.name.lower(): ActivationType.Silu, - ActivationType.Gelu.name.lower(): ActivationType.Gelu, - } - assert activation in aiter_acts, ( - f"Aiter CK fp4 MoE doesn't support activation {activation}" - ) - out = fused_moe( + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, + ) + + out = rocm_aiter_fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights, - topk_ids, - quant_type=QuantType.per_1x32, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - activation=aiter_acts[activation], - doweight_stage1=False, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + quant_config=self.moe_quant_config, ) else: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index c25c522dea55..007e78e68d5c 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -31,6 +31,13 @@ logger = init_logger(__name__) +# TODO: move registration of custom op to aiter_ops.py +# `from vllm._aiter_ops import rocm_aiter_ops` +# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()` +# for envs checks which does not require @cache anymore. +# triton kernel is torch compile compatible. +# does not require direct registeration. +# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`. 
@cache def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: return ( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7fecda2166ef..63726c07b7d1 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -68,78 +69,6 @@ def cutlass_scaled_mm( ) -def rocm_aiter_gemm_w8a8_blockscale_impl( - input_2d: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - def is_aiter_triton_kernel_tuned(n, k): - return (n, k) in [ - (1024, 8192), - (2112, 7168), - (3072, 1536), - (32768, 8192), - (4096, 7168), - (4608, 7168), - (512, 7168), - (7168, 2048), - (7168, 256), - (8192, 1024), - (8192, 32768), - ] - - n, k = weight.shape - if input_scale is not None: - q_input = input_2d - elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k): - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale - - # MI350 case uses triton kernel - q_input, input_scale = per_token_group_quant_fp8( - input_2d, - group_size, - column_major_scales=False, - use_ue8m0=False, - ) - else: - # MI300 uses tuned AITER ASM/C++ kernel - import aiter as rocm_aiter - from aiter import gemm_a8w8_blockscale, get_hip_quant - - aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) - q_input, input_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8 - ) - - return gemm_a8w8_blockscale( - q_input, weight, input_scale, weight_scale, dtype=output_dtype - ) - - -def rocm_aiter_gemm_w8a8_blockscale_fake( - input_2d: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - m = input_2d.shape[0] - n = weight.shape[0] - return torch.empty(m, n, dtype=output_dtype, device=input_2d.device) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8_blockscale", - op_func=rocm_aiter_gemm_w8a8_blockscale_impl, - fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, - ) - - # TODO we should be able to change the type of block_size to GroupShape # after we resolve GroupShape compilation issue # https://github.com/vllm-project/vllm/issues/25270 @@ -385,14 +314,40 @@ def _run_aiter( input_scale: torch.Tensor | None = None, ) -> torch.Tensor: assert self.act_quant_group_shape == GroupShape(1, 128) - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( - input_2d, - weight, - input_scale, - weight_scale, - self.act_quant_group_shape.col, - input_2d.dtype, - ) + + n, k = weight.shape + if input_scale is not None: + q_input = input_2d + + # MI350 case uses triton kernel + if ( + not current_platform.is_fp8_fnuz() + and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) + ): + q_input, input_scale = per_token_group_quant_fp8( + input_2d, + self.act_quant_group_shape.col, + column_major_scales=False, + use_ue8m0=False, + ) + return rocm_aiter_ops.triton_gemm_a8w8_blockscale( + q_input, + weight, + input_scale, + weight_scale, + 
input_2d.dtype, + ) + + # MI300 uses tuned AITER ASM/C++ kernel + else: + q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d) + return rocm_aiter_ops.gemm_w8a8_blockscale( + q_input, + weight, + input_scale, + weight_scale, + input_2d.dtype, + ) def _run_triton( self, @@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace( s_old.copy_(s_requant) -def check_aiter_fp8_linear_support() -> bool: - """AITER is only supported on ROCm for MI3XX""" - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - ) - - def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: """Pad the weight tensor. This is an optimization on ROCm platform, which can benefit from tensors located far enough from one another in memory""" diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 380431e86435..7fe902807a74 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -472,7 +472,7 @@ def apply( # Example: # When the number of token is 1, per-token scale is [[1]] # When per-tensor scale is [1] or (). - per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 + per_tensor_weights = weight_scale.numel() == 1 per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 # TODO(luka) do this dispatch during init (after ScaledMM refactor) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 91276320df4d..2ef54e75df44 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -4,13 +4,10 @@ import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_torch -from .rocm_aiter_rope_ops import ( - is_rocm_triton_rotary_embedding_enabled, - rocm_aiter_rotary_emb, -) @CustomOp.register("rotary_embedding") @@ -48,8 +45,8 @@ def __init__( cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.is_rocm_triton_rotary_embedding_enabled = ( - is_rocm_triton_rotary_embedding_enabled() + self.is_rocm_triton_rotary_embed_enabled = ( + rocm_aiter_ops.is_triton_rotary_embed_enabled() ) def _compute_inv_freq(self, base: float) -> torch.Tensor: @@ -169,9 +166,9 @@ def forward_hip( query: torch.Tensor, key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - if self.is_rocm_triton_rotary_embedding_enabled: + if self.is_rocm_triton_rotary_embed_enabled: self._match_cos_sin_cache_dtype(query) - rocm_aiter_rotary_emb( + rocm_aiter_ops.triton_rotary_embed( positions, query, key, diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index d9134f05fddf..e72834e473c1 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -146,6 +146,15 @@ def forward_native( key = key_rot return query, key + def forward_hip( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return self.forward_native(positions, query, key, offsets) + def forward_cuda( 
self, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py deleted file mode 100644 index a01d14f7b3a1..000000000000 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.envs as envs -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - - -def is_rocm_triton_rotary_embedding_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_TRITON_ROPE - ) - - -def rocm_aiter_rotary_emb_with_key_forward_triton_impl( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - import aiter.ops.triton.rope as ops - - ops.rope_cached_thd_positions_2c_fwd_inplace( - query, - key, - cos, - sin, - positions, - rotate_style, - reuse_freqs_front_part=True, - nope_first=is_nope_first, - ) - - -def rocm_aiter_rotary_emb_with_key_forward_triton_fake( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - pass - - -if is_rocm_triton_rotary_embedding_enabled(): - direct_register_custom_op( - op_name="rocm_aiter_rotary_emb_with_key_forward_triton", - op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl, - mutates_args=["key", "query"], - fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake, - dispatch_key=current_platform.dispatch_key, - ) - - -def rocm_aiter_rotary_emb( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - cos_sin_cache: torch.Tensor, - head_size: int, - rotary_dim: int, - is_neox_style: bool, -): - num_tokens = positions.numel() - cos, sin = cos_sin_cache.chunk(2, dim=-1) - query_shape = query.shape - key_shape = key.shape - rotate_style = 0 if is_neox_style else 1 - - query = query.view(num_tokens, -1, head_size) - key = key.view(num_tokens, -1, head_size) - query_ = query[..., :rotary_dim] - key_ = key[..., :rotary_dim] - positions = positions.view(*query.shape[:1]) - torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton( - positions, - sin, - cos, - query_, - key_, - rotate_style, - False, - ) - query = query.view(query_shape) - key = key.view(key_shape) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 63eaf63cc3c4..38189e17f7d8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -33,6 +33,7 @@ from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton @@ -50,10 +51,6 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_fusion_shared_expert_enabled, - is_rocm_aiter_moe_enabled, -) from 
vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -294,10 +291,8 @@ def __init__( self.physical_expert_start + self.n_local_physical_experts ) - if ( - config.n_shared_experts is None - or is_rocm_aiter_fusion_shared_expert_enabled() - ): + self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled: self.shared_experts = None else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -330,14 +325,14 @@ def __init__( # we do scaling outside, set factor to 1.0 to avoid double mul # aiter applies routed_scaling_factor internally routed_scaling_factor=1.0 - if not is_rocm_aiter_moe_enabled() + if not self.is_rocm_aiter_moe_enabled else self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, n_shared_experts=config.n_shared_experts - if is_rocm_aiter_fusion_shared_expert_enabled() + if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() else None, ) @@ -371,7 +366,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. if hidden_states.dtype != torch.float16: - if not is_rocm_aiter_moe_enabled(): + if not self.is_rocm_aiter_moe_enabled: final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None @@ -1428,6 +1423,9 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rocm_aiter_moe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -1456,7 +1454,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: num_experts=self.config.n_routed_experts + ( self.config.n_shared_experts - if is_rocm_aiter_fusion_shared_expert_enabled() + if rocm_aiter_moe_shared_expert_enabled else 0 ), num_redundant_experts=self.num_redundant_experts, @@ -1472,9 +1470,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if spec_layer is not None: continue # skip spec decode layers for main model - is_fuse_shared_experts_layer = ( - is_rocm_aiter_fusion_shared_expert_enabled() - and ("mlp.shared_experts" in name) + is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and ( + "mlp.shared_experts" in name ) for param_name, weight_name, shard_id in stacked_params_mapping: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1abd6300036d..e6536a02a73d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention( alibi_slopes: torch.Tensor | None = None, sinks: torch.Tensor | None = None, ) -> bool: + from vllm._aiter_ops import rocm_aiter_ops + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and 
envs.VLLM_ROCM_USE_AITER) + and not (rocm_aiter_ops.is_pa_attn_enabled()) and sinks is None ) @@ -202,12 +204,15 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: from importlib.util import find_spec + from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import _Backend - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): + if rocm_aiter_ops.is_mha_enabled(): + # Note: AITER FA is only supported for Qwen-VL models. + # TODO: Add support for other VL models in their model class. return _Backend.ROCM_AITER_FA if on_gfx9() and find_spec("flash_attn") is not None: @@ -228,19 +233,23 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: + from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import _Backend if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") - if use_mla: - from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( - is_aiter_mla_enabled, + + if not use_v1: + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." ) + if use_mla: if selected_backend is None: selected_backend = ( _Backend.ROCM_AITER_MLA - if is_aiter_mla_enabled() or block_size == 1 + if rocm_aiter_ops.is_mla_enabled() or block_size == 1 else _Backend.TRITON_MLA ) @@ -265,12 +274,12 @@ def get_attn_backend_cls( logger.info("Using FlexAttention backend.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() + rocm_aiter_ops.is_mha_enabled() ) or selected_backend == _Backend.ROCM_AITER_FA: logger.info("Using Aiter Flash Attention backend.") return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + rocm_aiter_ops.is_triton_unified_attn_enabled() ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: logger.info("Using Aiter Unified Attention backend.") return ( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 40ce12c4bd75..e38f7bcfa44e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -198,6 +198,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import ( AttentionBackend, AttentionLayer, @@ -270,28 +271,15 @@ class QueryLenSupport(Enum): flashinfer_available = False -def is_rocm_aiter_fp8bmm_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_FP8BMM - and envs.VLLM_ROCM_USE_AITER - ) - - -if is_rocm_aiter_fp8bmm_enabled(): - from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501 - ) - - def dynamic_per_batched_tensor_quant( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn - ): - DTYPE_MAX = torch.finfo(dtype).max - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) - scale = DTYPE_MAX / amax - x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) - return x_scl_sat.to(dtype).contiguous(), 
scale.float().reciprocal() +def dynamic_per_batched_tensor_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +): + DTYPE_MAX = torch.finfo(dtype).max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) + scale = DTYPE_MAX / amax + x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() logger = init_logger(__name__) @@ -1109,6 +1097,7 @@ def __init__( self.kv_b_proj = kv_b_proj self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads + self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): @@ -1158,7 +1147,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1187,7 +1176,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_K.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True ) @@ -1196,7 +1185,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_V.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) else: @@ -1208,10 +1197,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm( + x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) # Convert from (B, N, V) to (B, N * V) @@ -1571,7 +1559,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1600,7 +1588,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_K.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True ) @@ -1609,7 +1597,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_V.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) else: @@ -1958,7 +1946,6 @@ def forward( # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) - # Pads the head_dim if necessary (for the underlying kernel) if self.q_pad_num_heads is not None: B, N, L = decode_q_pe.shape decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) @@ -1966,9 +1953,9 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: # Multiply+Transpose (N, B, P)x(N, P, L)->(N, 
B, L)->(B, N, L) - decode_ql_nope = aiter_triton_fp8_bmm( + decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( decode_q_nope, self.W_K, self.W_K_scale, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 4ad7236eb1be..5757aeadba05 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,9 +6,8 @@ import torch -import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionLayer -from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import ( @@ -22,10 +21,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec -def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA - - class AiterMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: @@ -284,7 +279,7 @@ def _forward_decode( # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP max_seqlen_qo = 1 - aiter_mla_decode_fwd( + rocm_aiter_ops.mla_decode_fwd( q, kv_buffer, o, From d0e186c16f0d62af8c128e2dc7c94cde1387ac02 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 11 Nov 2025 00:30:06 +0800 Subject: [PATCH 20/49] [V0 Deprecation] Remove unused `context_len` and `seq_len` from M-RoPE (#28395) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/ernie45_vl.py | 3 --- vllm/model_executor/models/glm4_1v.py | 3 --- vllm/model_executor/models/glm4v.py | 3 --- vllm/model_executor/models/interfaces.py | 4 ---- vllm/model_executor/models/keye.py | 3 --- vllm/model_executor/models/keye_vl1_5.py | 3 --- vllm/model_executor/models/paddleocr_vl.py | 3 --- vllm/model_executor/models/qwen2_5_omni_thinker.py | 3 --- vllm/model_executor/models/qwen2_5_vl.py | 3 --- vllm/model_executor/models/qwen2_vl.py | 3 --- vllm/model_executor/models/qwen3_omni_moe_thinker.py | 2 -- vllm/model_executor/models/qwen3_vl.py | 4 +--- vllm/model_executor/models/transformers/multimodal.py | 4 +--- 13 files changed, 2 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 7c1eba103ae7..f287cff12086 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -1435,8 +1435,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -1569,7 +1567,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 121e84469c52..b9cd3545ec45 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -1622,8 +1622,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, 
second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1754,7 +1752,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 2de1e4810952..ebf6934dddea 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -625,8 +625,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -758,7 +756,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index b634c7ec7d67..d6a8f86d998b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -995,8 +995,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1012,8 +1010,6 @@ def get_mrope_input_positions( image_grid_thw: Image grid dimensions (t, h, w) video_grid_thw: Video grid dimensions (t, h, w) second_per_grid_ts: Seconds per grid timestep for videos - context_len: Context length - seq_len: Sequence length audio_feature_lengths: Audio feature lengths for multimodal models use_audio_in_video: Whether to use audio in video for interleaving diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 5f8659a3064e..42f16ad9f3b3 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1630,8 +1630,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -1759,6 +1757,5 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 
13e5b2d5f157..6f95a59d36d2 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -600,8 +600,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -729,6 +727,5 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 377b41a35578..631475c964c0 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -1179,8 +1179,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1293,7 +1291,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 7e970ebbe2bb..fac281d2caf4 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -927,8 +927,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1125,7 +1123,6 @@ def get_mrope_input_positions( mrope_position_delta = ( torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) ) - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index d337f1606943..48834ba699e4 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1118,8 +1118,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1232,7 +1230,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py 
index 9206ac8f9d03..b3999e6c934e 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1240,8 +1240,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1360,7 +1358,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index f20e67902721..da489a812f55 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -1417,8 +1417,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 2d8f431bb8fa..fe0124ef3258 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1419,8 +1419,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -1519,7 +1517,7 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions, mrope_position_delta def get_language_model(self) -> torch.nn.Module: diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 10abd8659536..476074542e6a 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -371,8 +371,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -390,7 +388,7 @@ def get_mrope_input_positions( video_grid_thw=video_grid_thw, ) - mrope_positions = mrope_positions[:, 0, context_len:seq_len] + mrope_positions = mrope_positions[:, 0] mrope_position_delta = mrope_position_delta[0].item() return mrope_positions, mrope_position_delta From b039bfda8f72b442d42dbdac40f51572bf045ad1 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 10 Nov 2025 12:21:52 -0500 Subject: [PATCH 21/49] [Bugfix] Fix 
persistent_masked_m_silu_mul_quant tests (#28366) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- csrc/quantization/activation_kernels.cu | 15 ++++++++++----- .../moe/test_silu_mul_fp8_quant_deep_gemm.py | 5 ++++- .../layers/fused_moe/batched_deep_gemm_moe.py | 3 ++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 6fcd246f63c5..2521b2797e2c 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant( // This kernel currently only supports H % 128 == 0 and assumes a // fixed GROUP_SIZE of 128. + static constexpr int GROUP_SIZE = 128; + TORCH_CHECK(input.dtype() == torch::kBFloat16); TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || y_q.dtype() == torch::kFloat8_e4m3fnuz); TORCH_CHECK(y_s.dtype() == torch::kFloat32); - TORCH_CHECK(input.size(-1) % 256 == 0); + TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0); using Idx_t = int64_t; @@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant( Idx_t stride_counts_e = tokens_per_expert.stride(0); - static constexpr int GROUP_SIZE = 128; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ @@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant( static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + int const NUM_GROUPS = H / GROUP_SIZE; if (!use_ue8m0) { - if (H >= 4096) { + if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { + /* 8 warps config */ static constexpr int NUM_STAGES = 4; static constexpr int THREAD_COUNT = 256; KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); } else { + /* 1 warp config */ static constexpr int THREAD_COUNT = 32; KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2); } } else { - if (H >= 4096) { + if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { + /* 8 warps config */ static constexpr int NUM_STAGES = 4; static constexpr int THREAD_COUNT = 256; KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES); } else { + /* 1 warp config */ static constexpr int THREAD_COUNT = 32; KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2); } diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 97a55c37b9a3..420dbbffaac0 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -25,6 +25,7 @@ (8, 16, 128 * 2, fp8_dtype), (8, 16, 128 * 3, fp8_dtype), (8, 64, 7168, fp8_dtype), + (8, 128, 128 * 33, fp8_dtype), (8, 128, 7168, fp8_dtype), (8, 512, 7168, fp8_dtype), (8, 1024, 7168, fp8_dtype), @@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): ) # Run the SiLU V2 kernel + # TODO (varun): use_e8m0 is set to false as the reference impl does + # not handle that case. 
y_q, y_s = persistent_masked_m_silu_mul_quant( - y, tokens_per_expert, group_size=group_size + y, tokens_per_expert, group_size=group_size, use_ue8m0=False ) torch.cuda.synchronize() diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 095ec966ea7e..b8a97e92ab79 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant( tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert num_parallel_tokens=16, group_size: int = 128, + use_ue8m0: bool | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). The first half of the last dimension is @@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant( device=y.device, ) - use_ue8m0 = is_deep_gemm_e8m0_used() + use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used() cuda_arch = current_platform.get_device_capability( device_id=y.device.index From 34553b9d2702dd2a27a578fec819e88e76dcbfb4 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 10 Nov 2025 09:34:57 -0800 Subject: [PATCH 22/49] [Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen-3next (#27492) Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../model_executor/layers/fused_moe/config.py | 21 +++++++++++++++ .../layers/fused_moe/flashinfer_trtllm_moe.py | 26 +++++++++---------- vllm/model_executor/layers/fused_moe/layer.py | 20 ++++++++++++++ .../model_executor/layers/quantization/fp8.py | 14 +++++----- .../quantization/utils/flashinfer_utils.py | 23 +++++++++------- vllm/model_executor/models/qwen3_moe.py | 2 ++ vllm/model_executor/models/qwen3_next.py | 2 ++ 7 files changed, 78 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index cbc3caafcf2f..a7bd64b1c65e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from enum import IntEnum from typing import Optional, Union import torch @@ -91,6 +92,26 @@ def _quant_flags_to_group_shape( return a_shape, w_shape +# The type of method in top-K routing +# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h +class RoutingMethodType(IntEnum): + # Default: Softmax -> TopK + Default = (0,) + # Renormalize: TopK -> Softmax + Renormalize = (1,) + # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups + # -> Top8 experts from the Top4 groups + DeepSeekV3 = (2,) + # Llama4: Top1 -> Sigmoid + Llama4 = (3,) + # RenormalizeNaive: Softmax -> TopK -> Renormalize + RenormalizeNaive = (4,) + # TopK: TopK (no softmax) + TopK = (5,) + # Unspecified + Unspecified = 6.0 + + @dataclass class FusedMoEQuantDesc: """ diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index f21fe16c5108..51e06ac54f49 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -3,6 +3,7 @@ import torch +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8( w2_weight_scale_inv: torch.Tensor, global_num_experts: int, top_k: int, - num_expert_group: int, - topk_group: int, + num_expert_group: int | None, + topk_group: int | None, intermediate_size: int, expert_offset: int, local_num_experts: int, block_shape: list[int], - routed_scaling: float = 1.0, + routing_method_type: int = RoutingMethodType.DeepSeekV3, + routed_scaling: float | None = 1.0, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + topk_group = topk_group if topk_group is not None else 0 assert top_k <= global_num_experts - assert top_k <= 8 - assert topk_group <= 4 - assert global_num_experts > num_expert_group - assert global_num_experts % num_expert_group == 0 + assert top_k <= 10 assert global_num_experts % 4 == 0 - assert top_k < (topk_group * global_num_experts / num_expert_group) assert block_shape == [128, 128] - # Routing kernel expects #experts <= #threads 256 - assert global_num_experts <= 256 + # Routing kernel expects #experts <= #threads 512 + assert global_num_experts <= 512 a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) # NOTE: scales of hidden states have to be transposed! @@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8( local_expert_offset=expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling, - tile_tokens_dim=calculate_tile_tokens_dim( - x.shape[0], top_k, global_num_experts - ), - routing_method_type=2, # DeepSeek-styled routing method + tile_tokens_dim=None, + routing_method_type=routing_method_type, use_shuffled_weight=False, ) @@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake( expert_offset: int, local_num_experts: int, block_shape: list[int], + routing_method_type: int, routed_scaling: float = 1.0, ) -> torch.Tensor: return torch.empty_like(x) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 45b0f50a7997..f86a93e30003 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -31,6 +31,7 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, biased_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton @@ -1213,6 +1214,7 @@ def __init__( zero_expert_type: str | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, + routing_method_type: int | None = None, ): super().__init__() @@ -1397,6 +1399,24 @@ def __init__( "Only softmax scoring function is supported for non-grouped topk." 
) + # ToDo: Better logic to determine the routing method type + if routing_method_type is not None: + self.routing_method_type = routing_method_type + else: + if scoring_func == "sigmoid": + if self.use_grouped_topk: + self.routing_method_type = RoutingMethodType.DeepSeekV3 + elif self.top_k == 1: + self.routing_method_type = RoutingMethodType.Llama4 + elif self.scoring_func == "softmax": + self.routing_method_type = ( + RoutingMethodType.Renormalize + if not self.renormalize + else RoutingMethodType.RenormalizeNaive + ) + else: + self.routing_method_type = RoutingMethodType.TopK + self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e4e1cbff712f..f5fc750baaea 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -28,6 +28,7 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, + RoutingMethodType, fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe @@ -1222,22 +1223,20 @@ def apply( assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) - assert scoring_func == "sigmoid", ( - f"Expected 'sigmoid' scoring func but got {scoring_func}" - ) + if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - assert ( - renormalize and use_grouped_topk and custom_routing_function is None - ) e_score_correction_bias = ( e_score_correction_bias.to(x.dtype) if e_score_correction_bias is not None else None ) + routing_method_type = layer.routing_method_type return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32), + routing_logits=router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits, routing_bias=e_score_correction_bias, x=x, w13_weight=layer.w13_weight, @@ -1252,6 +1251,7 @@ def apply( expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, + routing_method_type=routing_method_type, routed_scaling=routed_scaling_factor, ) else: diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 50ea049c3d5a..e49d374f154d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum): def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): + from flashinfer import next_positive_power_of_2 + # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now. # TODO: Revert this to dynamic calculation once a new version of FlashInfer # with the necessary kernels is released. tile_tokens_dim = 8 - # from flashinfer import next_positive_power_of_2 - - # # Guess tokens per expert assuming perfect expert distribution first. - # num_tokens_per_expert = (num_tokens * top_k) // num_experts - # # And pad the number to the next power of 2. - # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # # Cap to 8-64 tokens per CTA tile as it's the range supported by the - # # kernel. 
- # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + # A factor considering tokens are not perfectly balanced among experts. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-max_tile_tokens_dim tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) return tile_tokens_dim diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index a7e6772bb708..d57b82cb0227 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -43,6 +43,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -171,6 +172,7 @@ def __init__( enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, ) self.gate = ReplicatedLinear( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 55bbad7a8b27..aa7de5aa5f29 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -34,6 +34,7 @@ fused_recurrent_gated_delta_rule, ) from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm, ) @@ -173,6 +174,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: From 6d54336ae550528408e0c84cffb7857c426509f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Delacourt?= <54138269+Flechman@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:53:32 +0100 Subject: [PATCH 23/49] [Bugfix] Fix llguidance backend, rollback when EOS was encountered (#25905) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rémi Delacourt Signed-off-by: remi Co-authored-by: Russell Bryant --- .../test_backend_guidance.py | 118 ++++++++++++++++++ vllm/v1/structured_output/backend_guidance.py | 10 +- 2 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tests/v1/structured_output/test_backend_guidance.py diff --git a/tests/v1/structured_output/test_backend_guidance.py b/tests/v1/structured_output/test_backend_guidance.py new file mode 100644 index 000000000000..771076186a3b --- /dev/null +++ b/tests/v1/structured_output/test_backend_guidance.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import AutoTokenizer + +from vllm.config import StructuredOutputsConfig, 
VllmConfig +from vllm.config.model import ModelConfig +from vllm.config.speculative import SpeculativeConfig +from vllm.sampling_params import SamplingParams, StructuredOutputsParams +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.structured_output.backend_guidance import GuidanceBackend +from vllm.v1.structured_output.backend_types import StructuredOutputOptions + +TOKENIZER = "gpt2" + + +def test_backend_guidance_rollback_terminated(): + # Test that the backend guidance successfully rollbacks from a + # terminated state. This can happen with speculative decoding, + # where the draft model proposes EOS and it is verified by the + # guidance backend. In that case we are in a stopped state, but + # it should be reverted in case EOS is not accepted by the target + # model. + vllm_config = VllmConfig( + decoding_config=StructuredOutputsConfig( + backend="guidance", + ) + ) + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) + + backend = GuidanceBackend( + vllm_config, + tokenizer=tokenizer, + vocab_size=50257, + ) + + grammar = backend.compile_grammar( + StructuredOutputOptions.JSON, '{"type": "object"}' + ) + + prompt = tokenizer.encode('{"a": "b"}') + assert len(prompt) > 1 + dummy_wrong = tokenizer.encode('{"a"}') + for token in prompt: + assert grammar.accept_tokens("", [token]) + assert not grammar.is_terminated() + assert grammar.accept_tokens("", [tokenizer.eos_token_id]) + assert grammar.is_terminated() + # Giving any other token should also be accepted + assert grammar.accept_tokens("", dummy_wrong) + # Rollback is done from where state was terminated, so from '}' not EOS + grammar.rollback(len(prompt) - 1) + assert not grammar.is_terminated() + assert grammar.validate_tokens([tokenizer.eos_token_id]) == [] + assert grammar.validate_tokens(dummy_wrong) != dummy_wrong + assert grammar.accept_tokens("", prompt[1:]) + assert not grammar.is_terminated() + assert grammar.accept_tokens("", [tokenizer.eos_token_id]) + assert grammar.is_terminated() + # Rollback of <= 0 should not change the terminated state + grammar.rollback(0) + assert grammar.is_terminated() + grammar.rollback(-1) + assert grammar.is_terminated() + + +def test_grammar_bitmask_with_specdec(): + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) + prompt = tokenizer.encode('{"a": "b"}') + vllm_config = VllmConfig( + model_config=ModelConfig(tokenizer=TOKENIZER), + structured_outputs_config=StructuredOutputsConfig(backend="guidance"), + speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3), + ) + structured_output_manager = StructuredOutputManager(vllm_config) + + for i in range(1, 2): + sampling_params = SamplingParams( + structured_outputs=StructuredOutputsParams( + json='{"type": "object"}', + ), + ) + sampling_params.structured_outputs._backend = "guidance" + + my_req_id = f"my_req_id_{i}" + request = Request( + my_req_id, + prompt_token_ids=prompt[:i], + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=tokenizer.eos_token_id, + ) + + structured_output_manager.grammar_init(request) + + def grammar_bitmask(req: Request, tokens: list[int]) -> None: + structured_output_manager.grammar_bitmask( + requests={req.request_id: req}, + structured_output_request_ids={req.request_id: 0}, + scheduled_spec_decode_tokens={req.request_id: tokens}, + ) + # At this point, we rolled-back, so should not be terminated + assert not req.structured_output_request.grammar.is_terminated() + + # The grammar might not yet be 
compiled, so we wait for it + while not request.structured_output_request._check_grammar_completion(): + continue + + assert request.structured_output_request.grammar.accept_tokens( + request.request_id, prompt[:i] + ) + + grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) + grammar_bitmask( + request, prompt[i:] + [tokenizer.eos_token_id] + prompt + ) # EOS not the final token + grammar_bitmask(request, prompt[i:]) # EOS not present + grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 00a625e103bd..2962a439dcb3 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar): vocab_size: int printed_error: bool = False terminated: bool = False + rollback_lag: int = 0 def check_error(self): if not self.printed_error: @@ -127,6 +128,8 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: """ if self.ll_tokenizer.eos_token in tokens: + if self.ll_matcher.is_stopped() and not self.terminated: + self.rollback_lag = 1 self.terminated = True if self.ll_matcher.is_stopped(): @@ -163,8 +166,11 @@ def validate_tokens(self, tokens: list[int]) -> list[int]: return tokens[:num_tokens] def rollback(self, num_tokens: int) -> None: - self.ll_matcher.rollback(num_tokens) - self.check_error() + if num_tokens > 0: + self.ll_matcher.rollback(num_tokens - self.rollback_lag) + self.terminated = False + self.rollback_lag = 0 + self.check_error() def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: # this will automatically return [EOS] mask if the matcher is stopped From 9c84ca8293034cdf8a324f7bec3a60101e0e12c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20M=2E=20K=C3=BCbler?= <44084297+jmkuebler@users.noreply.github.com> Date: Mon, 10 Nov 2025 21:06:04 +0100 Subject: [PATCH 24/49] [FA/Chore] Bump FA version for FP8 two-level accumulation (#27889) Signed-off-by: Jonas Kuebler Co-authored-by: Lucas Wilkinson --- cmake/external_projects/vllm_flash_attn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 931090db50e9..29db9fa273a4 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG a893712401d70362fbb299cd9c4b3476e8e9ed54 + GIT_TAG 8e1b01d56210dc72030a2d0d41c2d8d266ba6309 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From 40d33264c680a8c725b93db6ccce608f99e5c7da Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 10 Nov 2025 12:39:19 -0800 Subject: [PATCH 25/49] [Bugfix][EPLB] Disabled shared expert overlap when EPLB is enabled (#28377) Signed-off-by: Sage Moore Signed-off-by: Sage Moore Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> --- .../layers/fused_moe/shared_fused_moe.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py 
b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 6b4a0b8cf073..3d0c5636d6c0 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -28,13 +28,18 @@ def __init__( super().__init__(**kwargs) self._shared_experts = shared_experts - # Disable shared expert overlap if we are not using - # flashinfer + DP since there is nothing to be gained in this case. - # Disabling the overlap optimization also prevents the shared experts - # from being hidden from torch.compile. + # Disable shared expert overlap if we are using eplb, because of + # correctness issues, or if using flashinfer with DP, since there + # is nothing to be gained in this case. Disabling the overlap + # optimization also prevents the shared experts from being hidden + # from torch.compile. self.use_overlapped = ( use_overlapped - and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + and not ( + # TODO(wentao): find the root cause and remove this condition + self.enable_eplb + or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + ) and self._shared_experts is not None ) From bf6a3d0ff5a69e0a30567f2ad417530c002eaa4e Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Mon, 10 Nov 2025 13:03:21 -0800 Subject: [PATCH 26/49] [Misc] Add more scoping for improved trace (#28329) Signed-off-by: Wei Wei --- vllm/v1/core/sched/scheduler.py | 116 ++++++++++++++-------------- vllm/v1/engine/core.py | 117 ++++++++++++++++++----------- vllm/v1/engine/llm_engine.py | 37 +++++---- vllm/v1/worker/gpu_model_runner.py | 70 +++++++++-------- 4 files changed, 192 insertions(+), 148 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c17b19b58c97..46dc1071b839 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -38,6 +38,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext logger = init_logger(__name__) @@ -259,49 +260,52 @@ def schedule(self) -> SchedulerOutput: continue # Schedule newly needed KV blocks for the request. - while True: - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens, - ) - - if new_blocks is not None: - # The request can be scheduled. - break - - # The request cannot be scheduled. - # Preempt the lowest-priority request. 
- if self.policy == SchedulingPolicy.PRIORITY: - preempted_req = max( - self.running, - key=lambda r: (r.priority, r.arrival_time), + with record_function_or_nullcontext("schedule: allocate_slots"): + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, ) - self.running.remove(preempted_req) - if preempted_req in scheduled_running_reqs: - scheduled_running_reqs.remove(preempted_req) - token_budget += num_scheduled_tokens[preempted_req.request_id] - req_to_new_blocks.pop(preempted_req.request_id) - num_scheduled_tokens.pop(preempted_req.request_id) - req_index -= 1 - else: - preempted_req = self.running.pop() - self.kv_cache_manager.free(preempted_req) - self.encoder_cache_manager.free(preempted_req) - preempted_req.status = RequestStatus.PREEMPTED - preempted_req.num_computed_tokens = 0 - preempted_req.num_preemptions += 1 - if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, scheduled_timestamp - ) + if new_blocks is not None: + # The request can be scheduled. + break - self.waiting.prepend_request(preempted_req) - preempted_reqs.append(preempted_req) - if preempted_req == request: - # No more request to preempt. Cannot schedule this request. - break + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[ + preempted_req.request_id + ] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop(preempted_req.request_id) + req_index -= 1 + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. Cannot schedule this request. + break if new_blocks is None: # Cannot schedule this request. @@ -599,13 +603,14 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) - if self.running: - any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id + ) ) - ) # Construct the scheduler output. 
new_reqs_data = [ @@ -614,13 +619,14 @@ def schedule(self) -> SchedulerOutput: ) for req in scheduled_new_reqs ] - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, - scheduled_resumed_reqs, - num_scheduled_tokens, - scheduled_spec_decode_tokens, - req_to_new_blocks, - ) + with record_function_or_nullcontext("schedule: make_cached_request_data"): + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) # Record the request ids that were scheduled in this step. self.prev_step_scheduled_req_ids.clear() @@ -649,8 +655,8 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta - - self._update_after_schedule(scheduler_output) + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) return scheduler_output def _update_after_schedule( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index fba018432e0a..c3efd52130cc 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -61,6 +61,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -315,17 +316,21 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): return {}, False - scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output, non_block=True) - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) - with self.log_error_detail(scheduler_output): - model_output = future.result() - if model_output is None: - model_output = self.model_executor.sample_tokens(grammar_output) - - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("core step: schedule"): + scheduler_output = self.scheduler.schedule() + + with record_function_or_nullcontext("core step: execute_model"): + future = self.model_executor.execute_model(scheduler_output, non_block=True) + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with self.log_error_detail(scheduler_output): + model_output = future.result() + if model_output is None: + model_output = self.model_executor.sample_tokens(grammar_output) + + with record_function_or_nullcontext("core step: update_from_output"): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 @@ -363,32 +368,49 @@ def step_with_batch_queue( model_executed = False deferred_scheduler_output = None if self.scheduler.has_requests(): - scheduler_output = self.scheduler.schedule() - exec_future = self.model_executor.execute_model( - scheduler_output, non_block=True - ) + with record_function_or_nullcontext("core step_with_batch_queue: schedule"): + scheduler_output = self.scheduler.schedule() + with record_function_or_nullcontext( + "core step_with_batch_queue: execute_model" + ): + exec_future = self.model_executor.execute_model( + scheduler_output, 
non_block=True + ) model_executed = scheduler_output.total_num_scheduled_tokens > 0 if scheduler_output.pending_structured_output_tokens: - # We need to defer sampling until we have processed the model output - # from the prior step. - deferred_scheduler_output = scheduler_output - # Block-wait for execute to return (continues running async on the GPU). - with self.log_error_detail(scheduler_output): - exec_result = exec_future.result() - assert exec_result is None + with record_function_or_nullcontext( + "core step_with_batch_queue: pending_structured_output_tokens" + ): + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + # Block-wait for execute to return + # (continues running async on the GPU). + with self.log_error_detail(scheduler_output): + exec_result = exec_future.result() + assert exec_result is None else: - # We aren't waiting for any tokens, get any grammar output immediately. - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with record_function_or_nullcontext( + "core step_with_batch_queue: get_grammar_bitmask" + ): + # We aren't waiting for any tokens, get any grammar + # output immediately. + grammar_output = self.scheduler.get_grammar_bitmask( + scheduler_output + ) # Block-wait for execute to return (continues running async on the GPU). with self.log_error_detail(scheduler_output): exec_result = exec_future.result() if exec_result is None: - # Call sample tokens. - future = self.model_executor.sample_tokens( - grammar_output, non_block=True - ) + with record_function_or_nullcontext( + "core step_with_batch_queue: sample_tokens" + ): + # Call sample tokens. + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) else: # No sampling required (e.g. all requests finished). future = cast(Future[ModelRunnerOutput], exec_future) @@ -408,27 +430,34 @@ def step_with_batch_queue( # only be called when the scheduler contains requests or the queue # is non-empty. return None, False - - # Block until the next result is available. - future, scheduler_output = batch_queue.pop() - with self.log_error_detail(scheduler_output): - model_output = future.result() - - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("core step_with_batch_queue: model_output"): + # Block until the next result is available. + future, scheduler_output = batch_queue.pop() + with self.log_error_detail(scheduler_output): + model_output = future.result() + with record_function_or_nullcontext( + "core step_with_batch_queue: update_from_output" + ): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) # NOTE(nick): We can either handle the deferred tasks here or save # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. if deferred_scheduler_output: - # We now have the tokens needed to compute the bitmask for the - # deferred request. Get the bitmask and call sample tokens. 
-            grammar_output = self.scheduler.get_grammar_bitmask(
-                deferred_scheduler_output
-            )
-            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
-            batch_queue.appendleft((future, deferred_scheduler_output))
+            with record_function_or_nullcontext(
+                "core step_with_batch_queue: deferred_scheduler_output"
+            ):
+                # We now have the tokens needed to compute the bitmask for the
+                # deferred request. Get the bitmask and call sample tokens.
+                grammar_output = self.scheduler.get_grammar_bitmask(
+                    deferred_scheduler_output
+                )
+                future = self.model_executor.sample_tokens(
+                    grammar_output, non_block=True
+                )
+                batch_queue.appendleft((future, deferred_scheduler_output))
 
         return engine_core_outputs, model_executed
 
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index e32c74aff313..d27d13840989 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -36,6 +36,7 @@ from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
 from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
 from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.utils import record_function_or_nullcontext
 from vllm.v1.worker.worker_base import WorkerBase
 
 logger = init_logger(__name__)
@@ -280,28 +281,32 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]:
             return []
 
         # 1) Get EngineCoreOutput from the EngineCore.
-        outputs = self.engine_core.get_output()
+        with record_function_or_nullcontext("llm_engine step: get_output"):
+            outputs = self.engine_core.get_output()
 
         # 2) Process EngineCoreOutputs.
-        iteration_stats = IterationStats() if self.log_stats else None
-        processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs,
-            engine_core_timestamp=outputs.timestamp,
-            iteration_stats=iteration_stats,
-        )
-        self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
+        with record_function_or_nullcontext("llm_engine step: process_outputs"):
+            iteration_stats = IterationStats() if self.log_stats else None
+            processed_outputs = self.output_processor.process_outputs(
+                outputs.outputs,
+                engine_core_timestamp=outputs.timestamp,
+                iteration_stats=iteration_stats,
+            )
+            self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
 
         # 3) Abort any reqs that finished due to stop strings.
-        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
+        with record_function_or_nullcontext("llm_engine step: abort_requests"):
+            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
         # 4) Record stats
-        if self.logger_manager is not None and outputs.scheduler_stats is not None:
-            self.logger_manager.record(
-                scheduler_stats=outputs.scheduler_stats,
-                iteration_stats=iteration_stats,
-                mm_cache_stats=self.processor.stat_mm_cache(),
-            )
-            self.do_log_stats_with_interval()
+        with record_function_or_nullcontext("llm_engine step: record_stats"):
+            if self.logger_manager is not None and outputs.scheduler_stats is not None:
+                self.logger_manager.record(
+                    scheduler_stats=outputs.scheduler_stats,
+                    iteration_stats=iteration_stats,
+                    mm_cache_stats=self.processor.stat_mm_cache(),
+                )
+                self.do_log_stats_with_interval()
 
         return processed_outputs.request_outputs
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 26007d29d61b..9403b5756e05 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2525,7 +2525,7 @@ def execute_model(
                 "after execute_model() returns None."
) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - with record_function_or_nullcontext("Preprocess"): + with record_function_or_nullcontext("gpu_model_runner: preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. self._update_states(scheduler_output) @@ -2648,7 +2648,7 @@ def execute_model( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ), - record_function_or_nullcontext("Forward"), + record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): model_output = self._model_forward( @@ -2659,7 +2659,7 @@ def execute_model( **model_kwargs, ) - with record_function_or_nullcontext("Postprocess"): + with record_function_or_nullcontext("gpu_model_runner: postprocess"): if self.use_aux_hidden_state_outputs: # True when EAGLE 3 is used. hidden_states, aux_hidden_states = model_output @@ -2756,12 +2756,12 @@ def sample_tokens( scheduler_output, grammar_output, self.input_batch, logits ) - with record_function_or_nullcontext("Sample"): + with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): + with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, sampled_token_ids, @@ -2799,7 +2799,7 @@ def propose_draft_token_ids(sampled_token_ids): # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - with record_function_or_nullcontext("Bookkeep"): + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, logprobs_lists, @@ -2826,37 +2826,41 @@ def propose_draft_token_ids(sampled_token_ids): # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) - with record_function_or_nullcontext("EPLB"): + with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() - - output = ModelRunnerOutput( - req_ids=req_ids_output_copy, - req_id_to_index=req_id_to_index_output_copy, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - num_nans_in_logits=num_nans_in_logits, - ) + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits, + ) if not self.use_async_scheduling: return output - - async_output = AsyncGPUModelRunnerOutput( - model_runner_output=output, - sampled_token_ids=sampler_output.sampled_token_ids, - logprobs_tensors=sampler_output.logprobs_tensors, - invalid_req_indices=invalid_req_indices, - async_output_copy_stream=self.async_output_copy_stream, - ) - - # Save ref of sampled_token_ids CPU tensor if the batch contains - # any requests with sampling params that that require output ids. 
- self.input_batch.set_async_sampled_token_ids( - async_output.sampled_token_ids_cpu, - async_output.async_copy_ready_event, - ) + with record_function_or_nullcontext( + "gpu_model_runner: AsyncGPUModelRunnerOutput" + ): + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + with record_function_or_nullcontext( + "gpu_model_runner: set_async_sampled_token_ids" + ): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) return async_output From 6dec9f61098786690b4ca2140682dbafb849f8d9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 10 Nov 2025 17:01:17 -0500 Subject: [PATCH 27/49] [BugFix] Fix DeepGEMM over-allocating workspace (#28254) Signed-off-by: Lucas Wilkinson --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 484b8aa9d107..86cdd25f2c87 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -215,7 +215,7 @@ def workspace_shapes( ) assert M_sum % block_m == 0 - workspace1 = (M_sum, max(N, K)) + workspace1 = (M_sum, N) workspace2 = (M_sum, max(N // 2, K)) output = (M, K) return (workspace1, workspace2, output) From 4b94ed8f928533b1f7c3a0293790ccb592b49f1a Mon Sep 17 00:00:00 2001 From: Andrew Xia Date: Mon, 10 Nov 2025 14:07:49 -0800 Subject: [PATCH 28/49] [Frontend][2/n] remove empty content from _parse_tool_calls_from_content (#28331) Signed-off-by: Andrew Xia Co-authored-by: Andrew Xia --- vllm/entrypoints/openai/serving_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8ce4ff574699..30b8499b08d5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1375,6 +1375,8 @@ def _parse_tool_calls_from_content( for tool_call in tool_call_info.tool_calls ) content = tool_call_info.content + if content and content.strip() == "": + content = None else: # No tool calls. 
return None, content From 30700b1cd7de51f191be718215a58f5a8ddcb8aa Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 10 Nov 2025 17:36:11 -0500 Subject: [PATCH 29/49] [CI] Fix Plugin Tests Tests (#28413) Signed-off-by: Robert Shaw --- vllm/config/vllm.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d4ee6f980e6e..0fca967d9083 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -608,17 +608,19 @@ def __post_init__(self): ) current_platform.check_and_update_config(self) - assert ( - self.parallel_config.dcp_kv_cache_interleave_size - <= self.cache_config.block_size - and self.cache_config.block_size - % self.parallel_config.dcp_kv_cache_interleave_size - == 0 - ), ( - f"Block_size({self.cache_config.block_size}) should be " - "greater than or equal to and divisible by dcp_kv_cache_interleave_size " - f"({self.parallel_config.dcp_kv_cache_interleave_size})." - ) + # If DCP, ensure the block size is right. + if self.parallel_config.decode_context_parallel_size > 1: + assert ( + self.parallel_config.dcp_kv_cache_interleave_size + <= self.cache_config.block_size + and self.cache_config.block_size + % self.parallel_config.dcp_kv_cache_interleave_size + == 0 + ), ( + f"Block_size({self.cache_config.block_size}) should be greater " + "than or equal to and divisible by dcp_kv_cache_interleave_size " + f"({self.parallel_config.dcp_kv_cache_interleave_size})." + ) assert ( self.parallel_config.dcp_kv_cache_interleave_size == 1 From 021143561fcffa9bee133d0b3bd311bc5cb3703c Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:13:36 -1000 Subject: [PATCH 30/49] [ROCm] Add missing gemm_a8w8_blockscale import (#28378) Signed-off-by: Yong Hoon Shin --- .../layers/quantization/utils/fp8_utils.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 63726c07b7d1..c63196b89357 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -316,38 +316,39 @@ def _run_aiter( assert self.act_quant_group_shape == GroupShape(1, 128) n, k = weight.shape - if input_scale is not None: - q_input = input_2d - # MI350 case uses triton kernel - if ( + use_triton = ( not current_platform.is_fp8_fnuz() and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) - ): + ) + + if use_triton: + gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale + else: + gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale + + if input_scale is not None: + q_input = input_2d + # MI350 case uses triton kernel + elif use_triton: q_input, input_scale = per_token_group_quant_fp8( input_2d, self.act_quant_group_shape.col, column_major_scales=False, use_ue8m0=False, ) - return rocm_aiter_ops.triton_gemm_a8w8_blockscale( - q_input, - weight, - input_scale, - weight_scale, - input_2d.dtype, - ) - # MI300 uses tuned AITER ASM/C++ kernel else: q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d) - return rocm_aiter_ops.gemm_w8a8_blockscale( - q_input, - weight, - input_scale, - weight_scale, - input_2d.dtype, - ) + + return gemm_a8w8_blockscale_op( + q_input, + weight, + input_scale, + weight_scale, + list(self.weight_group_shape), + output_dtype=input_2d.dtype, + ) 
def _run_triton( self, From d17ecc6b19b597615893be6c0eb61c9b4a9c9455 Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Tue, 11 Nov 2025 00:33:11 +0100 Subject: [PATCH 31/49] [PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds (#24248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič Signed-off-by: Luka Govedič Signed-off-by: ilmarkov Co-authored-by: Luka Govedič Co-authored-by: Luka Govedič Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 4 +- .../kernels/benchmark_fused_collective.py | 1129 +++++++++++++++++ tests/compile/test_fusions_e2e.py | 7 + vllm/compilation/collective_fusion.py | 132 +- vllm/config/compilation.py | 50 +- vllm/model_executor/layers/fused_moe/layer.py | 45 +- 6 files changed, 1284 insertions(+), 83 deletions(-) create mode 100644 benchmarks/kernels/benchmark_fused_collective.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b81c090fa471..3152cd6488f3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -463,8 +463,8 @@ steps: - pytest -v -s compile/test_multimodal_compile.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 22min - timeout_in_minutes: 35 +- label: PyTorch Fullgraph Test # 27min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py new file mode 100644 index 000000000000..38e7fdcf5542 --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -0,0 +1,1129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +2. 
Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark_fused_collective.py + +""" + +import argparse +import itertools +import os +import time + +import pandas as pd +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.distributed import ( + get_tp_group, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm # noqa +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa +from vllm.platforms import current_platform # noqa + +RMS_NORM_OP = torch.ops._C.rms_norm +FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm +RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant +FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant +) +SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant + +logger = init_logger(__name__) + +# Try to import FlashInfer +try: + import flashinfer.comm as flashinfer_comm # type: ignore + + if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = None + logger.warning( + "FlashInfer comm module found but missing trtllm_allreduce_fusion" + ) +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +FP8_DTYPE = current_platform.fp8_dtype() +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensor for FlashInfer +_FI_WORKSPACE_TENSOR = None + + +def setup_flashinfer_workspace( + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + use_fp32_lamport: bool = False, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global _FI_WORKSPACE_TENSOR + + if flashinfer_comm is None: + return None, None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", world_size) + return None, None + + try: + # Create IPC workspace + ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=world_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=get_tp_group().device_group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + _FI_WORKSPACE_TENSOR = workspace_tensor + return ipc_handles, workspace_tensor + except Exception as e: + logger.error("Failed to setup FlashInfer workspace: %s", e) + return None, None + + +def cleanup_flashinfer_workspace(ipc_handles): + """Cleanup FlashInfer workspace.""" + if flashinfer_comm is None or ipc_handles is None: + return + + try: + group = get_tp_group().device_group + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) + except Exception as e: + logger.error("Failed to cleanup FlashInfer workspace: %s", e) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def 
__init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.max_token_num = max_token_num + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + } + + +def flashinfer_fused_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + allreduce_params: "FlashInferFusedAllReduceParams", + use_oneshot: bool, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm operation.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + allreduce_out=None, + quant_out=None, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=None, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + use_oneshot: bool = True, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" + if 
flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=output_scale, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +class VllmFusedAllreduce: + def __init__(self, hidden_dim, dtype): + self.rms_eps = 1e-6 + self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype) + self.fp8_quant = QuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + ) + + def allreduce_rmsnorm( + self, input_tensor: torch.Tensor, residual: torch.Tensor | None + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + return self.rms_norm(allreduce_out, residual) + + def allreduce_rmsnorm_fp8_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + scale_factor: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out + else: + rms_out, residual_out = rms_out + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out, residual_out + + def allreduce_rmsnorm_fp4_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, output_scale + else: + rms_out, residual_out = rms_out + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, residual_out, output_scale + + +def create_test_tensors( + num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(num_tokens, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, 
trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with graph_capture(device=device), torch.cuda.graph(graph): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_residual: bool, + allreduce_params: FlashInferFusedAllReduceParams | None, + quant_modes: set[str], + no_oneshot: bool, +): + """Run all benchmarks for given configuration. + + Args: + quant_mode: "none", "fp8_only", "fp4_only", or "all" + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + use_oneshot_options = [False] if no_oneshot else [True, False] + + # Create RMSNorm and QuantFP8 layers once for native benchmarks + + if "none" in quant_modes: + # Standard AllReduce + RMSNorm + for custom_op in ["-rms_norm", "+rms_norm"]: + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op])) + ): + try: + suffix = ( + "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm" + ) + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm, + input_tensor, + residual=residual, + ) + results[f"standard_allreduce_{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results[f"standard_allreduce_{suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + standard_allreduce_rmsnorm_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=use_oneshot, + ) + 
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e) + results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float( + "inf" + ) + + if "fp8" in quant_modes: + # Standard AllReduce + RMSNorm + FP8 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]: + suffix += ( + "_custom_quant_fp8" + if "+" in quant_fp8_custom_op + else "_native_quant_fp8" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op, quant_fp8_custom_op] + ) + ) + ): + try: + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results[f"standard_allreduce{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results[f"standard_allreduce{suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=["-rms_norm", "-quant_fp8"] + ) + ) + ): + try: + standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = ( + time_ms + ) + except Exception as e: + logger.error( + "Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=use_oneshot, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( + float("inf") + ) + + if "fp4" in quant_modes and current_platform.has_device_capability(100): + # Standard AllReduce + RMSNorm + FP4 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op] + ) + ) + ): + try: + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results[f"standard_allreduce_{suffix}_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + 
results[f"standard_allreduce_{suffix}_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + quant_out=fp4_quant_out, + input_global_scale=scale_fp4, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = ( + time_ms + ) + except Exception as e: + logger.error( + "Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=use_oneshot, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( + float("inf") + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot + if flashinfer_comm is not None and allreduce_params is not None: + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( + "inf" + ) + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and 
results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + "operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results( + results_dict, + num_tokens, + hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, +): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print( + f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} " + f"(input size: {input_size_mb:.2f} MB)" + ) + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + f"quant_modes={','.join(sorted(list(quant_modes)))}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + lines: list[str] = [] + lines.append("# FlashInfer Fused Collective Operations Benchmark Results") + lines.append("") + lines.append(f"**World Size:** {world_size} ") + lines.append(f"**Hidden Dimension:** {args.hidden_dim} ") + lines.append(f"**Warmup Iterations:** {args.warmup} ") + lines.append(f"**Benchmark Trials:** {args.trials} ") + modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A" + lines.append(f"**Quantization Modes:** {modes} ") + lines.append("") + lines.append("---") + lines.append("") + + for entry in all_results: + num_tokens = entry["num_tokens"] + dtype = entry["dtype"] 
+ use_residual = entry["use_residual"] + results_dict = entry["results"] + input_size_mb = entry["input_size_mb"] + residual_str = "with residual" if use_residual else "no residual" + + lines.append( + f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}" + ) + lines.append(f"**Input Size:** {input_size_mb:.2f} MB") + lines.append("") + + prepared = prepare_results_with_speedups(results_dict) + # Build DataFrame for markdown export + rows = [ + { + "Operation": r["operation"].replace("_", " ").title(), + "Time (ms)": r["time_str"], + "Speedup": r["speedup_str"], + } + for r in prepared + ] + df = pd.DataFrame(rows) + if df.empty: + lines.append("No results.") + else: + lines.append(df.to_markdown(index=False)) + lines.append("") + + return "\n".join(lines) + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "a") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--num-tokens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Numbers of tokens to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual connection tests", + ) + + parser.add_argument( + "--quant-modes", + type=str, + default="none,fp8,fp4", + help=( + "Comma-separated quantization modes to run: none, fp8, fp4. " + "Default: none,fp8,fp4" + ), + ) + + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + parser.add_argument( + "--no-oneshot", + action="store_true", + help="Skip oneshot benchmarks", + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." 
+ ) + + # Parse quantization modes + valid_quant_modes = {"none", "fp8", "fp4"} + raw_modes = [ + m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip() + ] + quant_modes = set(raw_modes) if raw_modes else {"none", "fp8", "fp4"} + invalid = sorted(list(quant_modes - valid_quant_modes)) + if invalid: + raise ValueError( + f"Invalid --quant-modes entries: {','.join(invalid)}. " + f"Valid options are: {','.join(sorted(valid_quant_modes))}." + ) + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization modes: %s", ",".join(sorted(list(quant_modes)))) + if flashinfer_comm is not None: + logger.info( + "FlashInfer available - will benchmark fused operations", + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + + configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) + + # Setup FlashInfer workspace if available + ipc_handles = None + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * world_size * 2 + ) + + ipc_handles, workspace_tensor = setup_flashinfer_workspace( + world_size, rank, args.hidden_dim, max_num_token + ) + + if workspace_tensor is not None: + allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=world_size, + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for num_tokens, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s", + num_tokens, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + num_tokens, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + quant_modes=quant_modes, + no_oneshot=args.no_oneshot, + ) + + # Store results for markdown export + if rank == 0: + # Calculate input size in MB + input_size_mb = ( + num_tokens * args.hidden_dim * torch.finfo(dtype).bits + ) / (8 * 1024 * 1024) + all_results.append( + { + "num_tokens": num_tokens, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_modes": sorted(list(quant_modes)), + "input_size_mb": input_size_mb, + "results": results, + } + ) + + print_results( + results, + num_tokens, + args.hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + if ipc_handles is not None: + cleanup_flashinfer_workspace(ipc_handles) + + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 58026e7e7e78..4b910bc28579 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -71,6 +71,13 @@ class ModelBackendTestCase(NamedTuple): attention_fusions=0, allreduce_fusions=65, ), + ModelBackendTestCase( + model_name="Qwen/Qwen3-30B-A3B", + model_kwargs=dict(max_model_len=1024), + 
backend=_Backend.TRITON_ATTN, + attention_fusions=0, + allreduce_fusions=97, + ), ] elif current_platform.is_rocm(): diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7294ddce64ba..69d4606d73eb 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -9,7 +9,6 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -450,34 +449,41 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) +# Max size of the input tensor per world size per device capability +# to use flashinfer fused allreduce +FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = { + 90: { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 0.5, # 0.5MB + }, + 100: { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, +} + +# Max size of the input tensor per world size per device capability +# to use flashinfer one shot fused allreduce +# OneShot max size is at most 64MB / world size (FlashInfer restriction) +_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = { + 90: { + 2: 32, # 32MB + 4: 2, # 2MB + 8: 0.5, # 0.5MB + }, + 100: { + 2: 32, # 32MB + 4: 4, # 4MB + 8: 1, # 1MB + }, +} + + if flashinfer_comm is not None: _FI_WORKSPACE_TENSOR = None - MiB = 1024 * 1024 - # Max size of the input tensor per world size - # to use flashinfer fused allreduce - _FI_MAX_SIZES = { - 2: 64 * MiB, # 64MB - 4: MiB, # 1MB - 6: MiB // 2, # 512KB - 8: MiB // 2, # 512KB - } - - try: - _FI_MAX_SIZES.update( - { - int(k): int(float(v) * MiB) - for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() - } - ) - except Exception as e: - raise ValueError( - "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e) - ) from e - - # opt for a more conservative default value - # when world size is not in _FI_MAX_SIZES - _DEFAULT_FI_MAX_SIZE = MiB // 2 def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, @@ -491,7 +497,6 @@ def call_trtllm_fused_allreduce_norm( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -500,12 +505,20 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - max_fusion_size = max_token_num * hidden_size * element_size - use_flashinfer = current_tensor_size <= min( - _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), - max_fusion_size, - ) - if use_flashinfer: + + if num_tokens <= max_token_num: + device_capability = current_platform.get_device_capability().to_int() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( + device_capability, {} + ).get(world_size, None) + # Use one shot if no max size for one shot is specified + use_oneshot = ( + max_one_shot_size_mb is None + or current_tensor_size <= max_one_shot_size_mb * MiB + ) + assert _FI_WORKSPACE_TENSOR is not None, ( "Flashinfer must be enabled when using flashinfer" ) @@ -532,7 +545,7 @@ def call_trtllm_fused_allreduce_norm( 
hidden_dim=allreduce_in.shape[-1], workspace_ptrs=_FI_WORKSPACE_TENSOR, launch_with_pdl=launch_with_pdl, - use_oneshot=True, + use_oneshot=use_oneshot, trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, pattern_code=pattern_code, @@ -545,7 +558,7 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None and fuse_rms_quant: + if scale_factor is not None and scale_out is None: # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -568,15 +581,10 @@ def call_trtllm_fused_allreduce_norm( norm_out = allreduce_out else: torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None: - if scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - else: - torch.ops._C.static_scaled_fp8_quant( - quant_out, norm_out, scale_factor - ) + if scale_factor is not None and scale_out is not None: + torch.ops._C.scaled_fp4_quant( + quant_out, norm_out, scale_out, scale_factor + ) if scale_factor is None or norm_out is not None: # we need to return allreduce output # in cases of non quant fused AR + RMS norm @@ -595,7 +603,6 @@ def call_trtllm_fused_allreduce_norm_fake( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -629,7 +636,6 @@ def __init__( world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, - fuse_rms_quant: bool = False, ): self.rank = rank self.world_size = world_size @@ -637,9 +643,7 @@ def __init__( self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True - self.use_oneshot = False self.max_token_num = max_token_num - self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -649,7 +653,6 @@ def get_trtllm_fused_allreduce_kwargs(self): "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, - "fuse_rms_quant": self.fuse_rms_quant, } @@ -1119,23 +1122,35 @@ def __init__(self, config: VllmConfig): "skipping allreduce fusion pass" ) return - # Check if the world size is supported - if self.tp_size not in _FI_MAX_SIZES: + max_size = config.compilation_config.pass_config.flashinfer_max_size( + self.tp_size + ) + if max_size is None: + # Flashinfer doesn't support current world size logger.warning( "Flashinfer allreduce fusion is not supported for world size %s", self.tp_size, ) return - max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) - // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num, + element_size = 4 if use_fp32_lamport else 2 + self.max_token_num = max_size // (self.hidden_dim * element_size) + # take the min to save workspace size and we'll never use more + # than max_num_batched_tokens anyways + self.max_token_num = min( + self.max_token_num, config.scheduler_config.max_num_batched_tokens + ) + logger.debug_once( + f"Flashinfer max size: {max_size // (1024 * 1024)} MB," + "Maximal number of tokens used by " + f"Flashinfer Allreduce Fusion: {self.max_token_num}", + scope="global", ) + self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( 
tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_num_token, + max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1148,10 +1163,7 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_num_token, - # fuse rms norm static fp8 quant fused op - # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, + max_token_num=self.max_token_num, ) self.register_patterns() diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c84a060922e3..92cf16f259fe 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -111,11 +111,52 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 - """Max number of tokens to used in flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_size_mb: float | None = None + """The threshold of the communicated tensor sizes under which + vllm should use flashinfer fused allreduce. Specified as a + float in MB. + Unspecified will fallback to default values + which are compute capability and world size dependent. + FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { + 90: { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + 100: { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, + }, where key is the device capability""" # TODO(luka) better pass enabling system. + def flashinfer_max_size(self, world_size: int) -> int | None: + """ + Returns the max communication size in bytes for flashinfer + allreduce fusion for the given world size. Returns None if world size + is not supported by configs as it's not supported by flashinfer. + """ + + MiB = 1024 * 1024 + max_size_mb = self.fi_allreduce_fusion_max_size_mb + if max_size_mb is None: + max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) + + return int(max_size_mb * MiB) if max_size_mb is not None else None + + @staticmethod + def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: + from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + return {} + return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get( + current_platform.get_device_capability().to_int(), {} + ) + def uuid(self): """ Produces a hash unique to the pass configuration. @@ -136,6 +177,11 @@ def __post_init__(self) -> None: "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work" ) + if self.enable_fi_allreduce_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. 
" + "Allreduce + rms norm + quant (fp8) fusion might not work" + ) @config diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f86a93e30003..27ad9c8fd1c2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2356,6 +2356,16 @@ def forward_native( value=0.0, ) + def reduce_output(states: torch.Tensor) -> torch.Tensor: + if ( + not self.is_sequence_parallel + and not self.use_dp_chunking + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + return states + if self.shared_experts is None: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we @@ -2366,7 +2376,14 @@ def forward_native( fused_output = torch.ops.vllm.moe_forward( hidden_states, router_logits, self.layer_name ) - return fused_output[..., :og_hidden_states] + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(fused_output, tuple) + fused_output, zero_expert_result = fused_output + return (reduce_output(fused_output) + zero_expert_result)[ + ..., :og_hidden_states + ] + else: + return reduce_output(fused_output)[..., :og_hidden_states] else: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we @@ -2379,8 +2396,8 @@ def forward_native( hidden_states, router_logits, self.layer_name ) return ( - shared_output[..., :og_hidden_states], - fused_output[..., :og_hidden_states], + reduce_output(shared_output)[..., :og_hidden_states], + reduce_output(fused_output)[..., :og_hidden_states], ) def forward_cuda( @@ -2667,31 +2684,21 @@ def forward_impl( assert isinstance(final_hidden_states, tuple) final_hidden_states, zero_expert_result = final_hidden_states - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: + def combine_output(states: torch.Tensor) -> torch.Tensor: + if do_naive_dispatch_combine: states = get_ep_group().combine(states, self.is_sequence_parallel) - - if ( - not self.is_sequence_parallel - and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1) - ): - states = self.maybe_all_reduce_tensor_model_parallel(states) - return states if self.shared_experts is not None: return ( - reduce_output(final_hidden_states[0], do_combine=False), - reduce_output(final_hidden_states[1]), + final_hidden_states[0], + combine_output(final_hidden_states[1]), ) elif self.zero_expert_num is not None and self.zero_expert_num > 0: assert isinstance(final_hidden_states, torch.Tensor) - return reduce_output(final_hidden_states) + zero_expert_result + return (combine_output(final_hidden_states), zero_expert_result) else: - return reduce_output(final_hidden_states) + return combine_output(final_hidden_states) @classmethod def make_expert_params_mapping( From b30372cbd045aeac50833cd6fe6084d2edd5252c Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Mon, 10 Nov 2025 15:34:18 -0800 Subject: [PATCH 32/49] [Perf] Move gc.freeze logic from EngineCoreProc to EngineCore for better coverage (#27896) Signed-off-by: Jialin Ouyang --- vllm/benchmarks/serve.py | 5 ++--- vllm/distributed/parallel_state.py | 3 +++ vllm/entrypoints/openai/api_server.py | 6 ++---- vllm/utils/gc_utils.py | 15 +++++++++++++++ vllm/v1/engine/core.py | 15 ++++++++------- 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/vllm/benchmarks/serve.py 
b/vllm/benchmarks/serve.py index e58cf5911282..0e9b0fbe2c02 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -19,7 +19,6 @@ import argparse import asyncio import contextlib -import gc import importlib.util import json import os @@ -49,6 +48,7 @@ from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils.gc_utils import freeze_gc_heap MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -1414,8 +1414,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: percentile_metrics: str = args.percentile_metrics or default_percentile_metrics # Avoid GC processing "static" data - reduce pause times. - gc.collect() - gc.freeze() + freeze_gc_heap() benchmark_result = await benchmark( task_type=task_type, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a9b01e82562b..c78e6a32733c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1483,6 +1483,9 @@ def destroy_distributed_environment(): def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + # Ensure all objects are not freezed before cleanup + gc.unfreeze() + destroy_model_parallel() destroy_distributed_environment() if shutdown_ray: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c8c8d5c034d5..51191879e478 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import asyncio -import gc import hashlib import importlib import inspect @@ -118,6 +116,7 @@ from vllm.tasks import POOLING_TASKS from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.system_utils import decorate_logs, set_ulimit from vllm.v1.engine.exceptions import EngineDeadError @@ -153,8 +152,7 @@ async def _force_log(): # Mark the startup heap as static so that it's ignored by GC. # Reduces pause times of oldest generation collections. - gc.collect() - gc.freeze() + freeze_gc_heap() try: yield finally: diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py index 4dd85ef26f34..160ac9ac263a 100644 --- a/vllm/utils/gc_utils.py +++ b/vllm/utils/gc_utils.py @@ -89,6 +89,21 @@ def handle(self, phase: str, info: dict[str, int]) -> None: ) +def freeze_gc_heap() -> None: + """ + Freeze all objects tracked by the garbage collector. It should be invoked + after server init / warmup, to reduce GC overhead from static objects + during serving time. + """ + # Ensure all static objects are pushed down to the oldest generation for + # freeze + gc.collect(0) + gc.collect(1) + gc.collect(2) + # Freeze all GC tracked objects + gc.freeze() + + def maybe_attach_gc_debug_callback() -> None: """ Attached a callback for GC debug when VLLM_GC_DEBUG is enabled. 
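
The `freeze_gc_heap()` helper added above is intended to be called once after
server init/warmup, with `gc.unfreeze()` as its counterpart during cleanup (see
the `parallel_state.py` hunk earlier in this patch). A minimal standalone sketch
of that lifecycle using only the stdlib `gc` module follows; `serve_requests`
is an illustrative placeholder, not a vLLM API:

    import gc

    def freeze_gc_heap() -> None:
        # Push surviving startup objects into the oldest generation, then
        # freeze them so later collections skip this "static" heap.
        gc.collect(0)
        gc.collect(1)
        gc.collect(2)
        gc.freeze()

    def serve_requests() -> None:
        # Placeholder for the serving loop.
        pass

    # After init/warmup, before serving:
    freeze_gc_heap()
    serve_requests()
    # Before cleanup, unfreeze so the frozen objects can be collected again,
    # mirroring the gc.unfreeze() added to cleanup_dist_env_and_memory().
    gc.unfreeze()
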
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c3efd52130cc..ffb5232e770d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import gc import os import queue import signal @@ -27,7 +26,10 @@ from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.utils.gc_utils import maybe_attach_gc_debug_callback +from vllm.utils.gc_utils import ( + freeze_gc_heap, + maybe_attach_gc_debug_callback, +) from vllm.utils.hashing import get_hash_fn_by_name from vllm.utils.network_utils import make_zmq_socket from vllm.utils.system_utils import decorate_logs, set_process_title @@ -197,6 +199,10 @@ def __init__( self.step if self.batch_queue is None else self.step_with_batch_queue ) + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. + freeze_gc_heap() + def _initialize_kv_caches( self, vllm_config: VllmConfig ) -> tuple[int, int, KVCacheConfig]: @@ -651,11 +657,6 @@ def __init__( assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") - # Mark the startup heap as static so that it's ignored by GC. - # Reduces pause times of oldest generation collections. - gc.collect() - gc.freeze() - # If enable, attach GC debugger after static variable freeze. maybe_attach_gc_debug_callback() From a5a790eea6035760c71eae1861c1e5f369bc6d08 Mon Sep 17 00:00:00 2001 From: Adrian Abeyta Date: Mon, 10 Nov 2025 17:42:37 -0600 Subject: [PATCH 33/49] [Bugfix] Ensure calculated KV scales are applied in attention. 
(#27232) Signed-off-by: adabeyta --- .buildkite/test-pipeline.yaml | 7 +++++-- tests/compile/test_full_graph.py | 10 ++++++++-- vllm/attention/layer.py | 29 +++++++---------------------- vllm/v1/worker/gpu_model_runner.py | 19 +++++++++---------- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3152cd6488f3..a0d2076199b1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -471,8 +471,8 @@ steps: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph.py - # Limit to no custom ops to reduce running time + - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile' + # Limit to no custom ops to reduce running time # Wrap with quotes to escape yaml and avoid starting -k string with a - - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'" @@ -951,10 +951,13 @@ steps: - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py - tests/compile/test_fusions_e2e.py + - tests/compile/test_full_graph.py commands: - nvidia-smi # Run all e2e fusion tests - pytest -v -s tests/compile/test_fusions_e2e.py + # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) + - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 0ad8c17d8668..71f90f6d8d3e 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -183,8 +183,14 @@ def test_custom_compile_config( "compilation_mode", [CompilationMode.NONE, CompilationMode.VLLM_COMPILE], ) -def test_fp8_kv_scale_compile(compilation_mode: int): - model = "Qwen/Qwen2-0.5B" +@pytest.mark.parametrize( + "model", + [ + "Qwen/Qwen2-0.5B", # Standard attention model + "deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model + ], +) +def test_fp8_kv_scale_compile(compilation_mode: int, model: str): model_kwargs = { "quantization": "fp8", "kv_cache_dtype": "fp8_e4m3", diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17e025155a43..96272981692c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -745,6 +745,9 @@ def forward( k_pe: torch.Tensor, output_shape: torch.Size | None = None, ) -> torch.Tensor: + if self.calculate_kv_scales: + torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) + if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -752,12 +755,6 @@ def forward( attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # Mirror Attention.forward scale calculation path - if self.calculate_kv_scales and getattr( - attn_metadata, "enable_kv_scales_calculation", False - ): - self.calc_kv_scales(q, kv_c_normed, k_pe) - if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) self.impl.forward( @@ -786,14 +783,6 @@ def forward( ) return output else: - # We can still access forward context to check calculation flag - if self.calculate_kv_scales: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - if getattr(attn_metadata, "enable_kv_scales_calculation", False): - 
self.calc_kv_scales(q, kv_c_normed, k_pe) return torch.ops.vllm.unified_mla_attention( q, kv_c_normed, @@ -881,17 +870,13 @@ def maybe_calc_kv_scales( layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] - if attn_metadata is None or not getattr( - attn_metadata, "enable_kv_scales_calculation", False - ): + # Only calculate if the layer's calculate_kv_scales flag is True + # This flag gets set to False after the first forward pass + if not self.calculate_kv_scales: return - self = forward_context.no_compile_layers[layer_name] self.calc_kv_scales(query, key, value) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9403b5756e05..6fccf2ea2f47 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -279,6 +279,9 @@ def __init__( # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len + + # Always set to false after the first forward pass + self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.max_num_tokens = scheduler_config.max_num_batched_tokens @@ -2625,16 +2628,12 @@ def execute_model( ) # Set cudagraph mode to none if calc_kv_scales is true. - if attn_metadata is not None: - metadata_list = ( - attn_metadata.values() - if isinstance(attn_metadata, dict) - else [attn_metadata] - ) - if any( - getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list - ): - cudagraph_runtime_mode = CUDAGraphMode.NONE + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: + cudagraph_runtime_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False # Run the model. # Use persistent buffers for CUDA graphs. 
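
The `gpu_model_runner.py` hunk above gates KV-scale calculation on a single
runner-level flag: the first forward pass runs with CUDA graphs disabled while
the scales are computed, then the flag is cleared so subsequent passes can use
graph replay. A simplified, self-contained sketch of that control flow follows;
`RunnerSketch` and its helper methods are illustrative placeholders, not vLLM
APIs:

    class RunnerSketch:
        """Toy stand-in for the runner's one-shot KV-scale gating."""

        def __init__(self, calculate_kv_scales: bool) -> None:
            # Mirrors cache_config.calculate_kv_scales; set to False after
            # the first forward pass.
            self.calculate_kv_scales = calculate_kv_scales

        def _forward_eager(self, step: int) -> str:
            return f"eager pass {step} (KV scales computed here)"

        def _forward_cudagraph(self, step: int) -> str:
            return f"cudagraph pass {step}"

        def execute_model(self, step: int) -> str:
            if self.calculate_kv_scales:
                # Dynamic scale computation is incompatible with CUDA graph
                # capture, so run this one pass eagerly and never recompute.
                self.calculate_kv_scales = False
                return self._forward_eager(step)
            return self._forward_cudagraph(step)

    runner = RunnerSketch(calculate_kv_scales=True)
    print(runner.execute_model(0))  # eager pass 0 (KV scales computed here)
    print(runner.execute_model(1))  # cudagraph pass 1
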
From 0bf29fadf5f8b28817fbccb037fb70adaef3f7f1 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 10 Nov 2025 17:57:41 -0600 Subject: [PATCH 34/49] [Test] Remove old non-varlen FA2 test (#28420) Signed-off-by: Matthew Bonanni --- tests/kernels/attention/test_flash_attn.py | 119 --------------------- 1 file changed, 119 deletions(-) diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index 18995545552e..6e5468969bf2 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -9,7 +9,6 @@ from vllm.vllm_flash_attn import ( fa_version_unsupported_reason, flash_attn_varlen_func, - flash_attn_with_kvcache, is_fa_version_supported, ) @@ -83,124 +82,6 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -@pytest.mark.parametrize("use_out", [True, False]) -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) -@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) -@pytest.mark.parametrize("fa_version", [2, 3]) -@pytest.mark.parametrize("q_dtype", QDTYPES) -@torch.inference_mode() -def test_flash_attn_with_paged_kv( - use_out: bool, - kv_lens: list[int], - num_heads: tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: float | None, - num_blocks: int, - sliding_window: int | None, - fa_version: int, - q_dtype: torch.dtype | None, -) -> None: - torch.set_default_device("cuda") - if not is_fa_version_supported(fa_version): - pytest.skip( - f"Flash attention version {fa_version} not supported due " - f'to: "{fa_version_unsupported_reason(fa_version)}"' - ) - if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip( - "Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type" - ) - - current_platform.seed_everything(0) - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn( - num_blocks, block_size, num_kv_heads, head_size, dtype=dtype - ) - value_cache = torch.randn_like(key_cache) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint( - 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 - ) - - q = query.unsqueeze(1) - out = torch.empty_like(q) if use_out else None - - maybe_quantized_query = q - maybe_quantized_key_cache = key_cache - maybe_quantized_value_cache = value_cache - q_descale = None - k_descale = None - v_descale = None - if q_dtype is not None: - # QKV are drawn from N(0, 1): no need for a fp8 scaling factor - maybe_quantized_query = q.to(q_dtype) - maybe_quantized_key_cache = key_cache.to(q_dtype) - maybe_quantized_value_cache = value_cache.to(q_dtype) - - scale_shape = (num_seqs, num_kv_heads) - q_descale = torch.ones(scale_shape, dtype=torch.float32) - k_descale = torch.ones(scale_shape, dtype=torch.float32) - v_descale = 
torch.ones(scale_shape, dtype=torch.float32) - - output = flash_attn_with_kvcache( - q=maybe_quantized_query, - k_cache=maybe_quantized_key_cache, - v_cache=maybe_quantized_value_cache, - out=out, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - softcap=soft_cap if soft_cap is not None else 0, - window_size=window_size, - fa_version=fa_version, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - ) - output = output if not use_out else out - output = output.squeeze(1) - - atol, rtol = 1.5e-2, 1e-2 - if q_dtype is not None: - atol, rtol = 1.5e-1, 1.5e-1 - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window, - ) - ( - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), - f"{torch.max(torch.abs(output - ref_output))}", - ) - - @pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize( "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] From 35d801f13fa5bd79ae74707388b1fa4e1caf9ba5 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 10 Nov 2025 19:08:40 -0500 Subject: [PATCH 35/49] [Feature] Refactor batch invariant fp8 DeepGEMM (#27606) Signed-off-by: yewentao256 --- .../model_executor/layers/quantization/fp8.py | 98 +++---------------- 1 file changed, 11 insertions(+), 87 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f5fc750baaea..c7d5b251cf4e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -43,7 +43,6 @@ QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -95,11 +94,9 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import ( - fp8_gemm_nt, get_col_major_tma_aligned_tensor, is_deep_gemm_e8m0_used, is_deep_gemm_supported, - should_use_deepgemm_for_fp8_linear, ) from vllm.utils.flashinfer import has_flashinfer_moe from vllm.utils.import_utils import has_deep_gemm @@ -554,83 +551,19 @@ def apply( # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. 
if vllm_is_batch_invariant(): - # Call is_deep_gemm_supported() ahead of time for torch.compile - # dynamo has trouble tracing through - if self.block_quant and should_use_deepgemm_for_fp8_linear( - torch.bfloat16, layer.weight, self.use_deep_gemm - ): - # use group quant consistent with block size across K - assert self.act_q_group_shape is not None - q_input, input_scale = QuantFP8( - False, - self.act_q_group_shape, - column_major_scales=True, - )(x) - - output_2d = torch.empty( - (q_input.shape[0], layer.weight.shape[0]), - dtype=torch.bfloat16, - device=q_input.device, - ) - fp8_gemm_nt( - (q_input, input_scale), - (layer.weight, layer.weight_scale), - output_2d, - ) - if bias is not None: - output_2d = output_2d + bias - return output_2d - - # Dequantize FP8 weights to BF16 - weight_fp8 = layer.weight.to(torch.bfloat16) - weight_scale = layer.weight_scale.to(torch.bfloat16) - - # Handle different quantization granularities if self.block_quant: - # Block-wise quantization: - # - Weight is NOT transposed, shape is [N, K] (output_size, input_size) - # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!) assert self.weight_block_size is not None - block_n, block_k = self.weight_block_size # Note: order is [N, K] - - N, K = weight_fp8.shape - - # determine expected number of blocks along N and K - num_blocks_n = (N + block_n - 1) // block_n - num_blocks_k = (K + block_k - 1) // block_k - - # scale layout may be [num_blocks_n, num_blocks_k] - # or [num_blocks_k, num_blocks_n] depending on backend - if weight_scale.dim() != 2: - raise RuntimeError( - f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}" - ) - - scale_rows, scale_cols = weight_scale.shape - if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n): - if num_blocks_n == num_blocks_k: - # ambiguous square case, warn and skip transpose - logger.warning( - "Batch-invariant FP8: square block-scale %dx%d; " - "skipping transpose to avoid misorientation.", - scale_rows, - scale_cols, - ) - else: - # clear KN -> transpose to NK - weight_scale = weight_scale.t() - - # Expand scale to match weight dimensions - # scale_expanded should have shape [N, K] - scale_expanded = weight_scale.repeat_interleave( - block_n, dim=0 - ).repeat_interleave(block_k, dim=1) - # Trim to exact weight size (in case of padding) - scale_expanded = scale_expanded[:N, :K] - weight_bf16 = weight_fp8 * scale_expanded + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) else: - # Per-tensor quantization: weight IS transposed to [K, N] - # scale should be scalar or [1] or per-output-channel [N] + # per-tensor/channel: dequant to BF16 and run GEMM + weight_fp8 = layer.weight.to(torch.bfloat16) + weight_scale = layer.weight_scale.to(torch.bfloat16) if weight_scale.numel() == 1: # Per-tensor: simple scalar multiplication weight_bf16 = weight_fp8 * weight_scale @@ -649,16 +582,7 @@ def apply( else: # Fallback weight_bf16 = weight_fp8 * weight_scale - - # For block quant, weight is [N, K], for per-tensor it's [K, N] - # F.linear expects weight to be [N, K], so: - if self.block_quant: - # Already in correct shape [N, K] - output = torch.nn.functional.linear(x, weight_bf16, bias) - else: - # Need to transpose back: [K, N] -> [N, K] - output = torch.nn.functional.linear(x, weight_bf16.t(), bias) - return output + return torch.nn.functional.linear(x, weight_bf16.t(), bias) if self.use_marlin: return apply_fp8_marlin_linear( From 
39029d519276fddbe0c36440e0eefcdda069b969 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 10 Nov 2025 20:36:29 -0500 Subject: [PATCH 36/49] [CI/Test Fix] Fix CP tests on Blackwell (#28404) Signed-off-by: Lucas Wilkinson Signed-off-by: Lucas Wilkinson Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/distributed/test_context_parallel.py | 12 ++++++++++++ vllm/attention/ops/common.py | 1 - 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 7f8e77a75621..3576efca591c 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -14,6 +14,7 @@ from typing import Literal, NamedTuple import pytest +import torch from vllm.config.model import RunnerOption from vllm.logger import init_logger @@ -254,6 +255,17 @@ def test_cp_generation( test_options: CPTestOptions, num_gpus_available, ): + if ( + model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat" + and torch.cuda.get_device_capability() < (9, 0) + ): + pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher") + if ( + model_id == "bigcode/gpt_bigcode-santacoder" + and torch.cuda.get_device_capability() != (9, 0) + ): + pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0") + _compare_cp_with_tp( model_id, parallel_setup, diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 75fdcb8f48b2..2cbb5c91cc3b 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -195,7 +195,6 @@ def cp_lse_ag_out_rs( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) - assert out.is_contiguous() out = cp_group.reduce_scatter(out, dim=1) if return_lse: From de540c0354b9ecfa979c917a4599f8030d4105be Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 10 Nov 2025 21:29:48 -0500 Subject: [PATCH 37/49] [Feature] Add env var `VLLM_MOE_USE_DEEP_GEMM` (#28422) Signed-off-by: yewentao256 --- vllm/envs.py | 6 ++++++ .../compressed_tensors/compressed_tensors_moe.py | 10 +++++++++- vllm/model_executor/layers/quantization/fp8.py | 2 +- vllm/model_executor/warmup/deep_gemm_warmup.py | 3 +++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 30c62e90e9fb..9421488051e5 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -147,6 +147,7 @@ VLLM_TPU_MOST_MODEL_LEN: int | None = None VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = True + VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", @@ -1116,6 +1117,10 @@ def get_vllm_port() -> int | None: ), # Allow use of DeepGemm kernels for fused moe ops. "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))), + # Allow use of DeepGemm specifically for MoE fused ops (overrides only MoE). + "VLLM_MOE_USE_DEEP_GEMM": lambda: bool( + int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1")) + ), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. 
"VLLM_USE_DEEP_GEMM_E8M0": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1")) @@ -1569,6 +1574,7 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", + "VLLM_MOE_USE_DEEP_GEMM", "VLLM_USE_DEEP_GEMM_E8M0", "VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FLASHINFER_MOE_FP16", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d32ae6674ee6..59567f2ca13c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -966,10 +966,18 @@ def select_gemm_impl( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, + allow_deep_gemm=( + envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + ), ) else: logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) - return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True) + return TritonOrDeepGemmExperts( + self.moe_quant_config, + allow_deep_gemm=( + envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + ), + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c7d5b251cf4e..83d136600b77 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -158,7 +158,7 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: return Fp8MoeBackend.MARLIN # deepGEMM on supported platforms with block-quantized weights - if envs.VLLM_USE_DEEP_GEMM and block_quant: + if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant: if not has_deep_gemm(): logger.warning_once("DeepGEMM backend requested but not available.") elif is_deep_gemm_supported(): diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index bdcebd498ef0..e0c584df8760 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -148,6 +148,9 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: + if not (envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM): + return False + if not isinstance(module, FusedMoE): return False From f2d9ad0620d9aa71481527dcfafdb8357da00470 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 10 Nov 2025 19:53:24 -0700 Subject: [PATCH 38/49] Only register rocm_aiter_ops if aiter is found (#28428) Signed-off-by: mgoin --- vllm/_aiter_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 9a4b5f3399be..8d35aa65738b 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -938,4 +938,5 @@ def shuffle_weights( return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) -rocm_aiter_ops.register_ops_once() +if IS_AITER_FOUND: + rocm_aiter_ops.register_ops_once() From 57201a6a4c53bbd6adb9a4b702c95d5f480161d5 Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Mon, 10 Nov 2025 18:57:12 -0800 Subject: [PATCH 39/49] Fix rotary embedding benchmark script (#28323) Signed-off-by: Xin Yang --- benchmarks/kernels/benchmark_rope.py 
| 154 +++++++++++---------------- 1 file changed, 64 insertions(+), 90 deletions(-) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 29ef6409bb16..074b7a440b61 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,97 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate +import itertools -import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope -from vllm.platforms import current_platform +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.triton_utils import triton from vllm.utils.argparse_utils import FlexibleArgumentParser +batch_size_range = [2**i for i in range(0, 8, 2)] +seq_len_range = [2**i for i in range(6, 10, 1)] +num_heads_range = [32, 48] +configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range)) -def benchmark_rope_kernels_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: int | None, - dtype: torch.dtype, - seed: int, - device: str, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - # silulating serving 4 LoRAs - scaling_factors = [1, 2, 4, 8] - # batched RoPE can take multiple scaling factors - batched_rope = get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": tuple(scaling_factors)}, + +def get_benchmark(head_size, rotary_dim, is_neox_style, device): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "num_heads"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["torch", "flashinfer", "vllm"], + line_names=["PyTorch", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}", + args={}, + ) ) - # non-batched RoPE takes only one scaling factor, we create multiple - # instances to simulate the same behavior - non_batched_ropes: list[RotaryEmbedding] = [] - for scaling_factor in scaling_factors: - non_batched_ropes.append( - get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": (scaling_factor,)}, - ) + def benchmark(batch_size, seq_len, num_heads, provider): + dtype = torch.bfloat16 + max_position = 8192 + base = 10000 + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) + rope = rope.to(dtype=dtype, device=device) + cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) + + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + query = torch.randn( + (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device ) + key = torch.randn_like(query) - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) - key = torch.randn_like(query) + quantiles = [0.5, 0.2, 0.8] - # create query offsets for batched RoPE, we concat multiple kv cache - # together and each query needs to find the right kv cache of its type - offset_map = torch.tensor( - list( - accumulate( - [0] - + [ - max_position * scaling_factor * 2 - for scaling_factor in 
scaling_factors[:-1] - ] + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_native(positions, query.clone(), key.clone()), + quantiles=quantiles, ) - ) - ) - query_types = torch.randint( - 0, len(scaling_factors), (batch_size, seq_len), device=device - ) - # map query types to offsets - query_offsets = offset_map[query_types] - # the kernel takes flattened offsets - flatten_offsets = query_offsets.flatten() + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.ops.vllm.flashinfer_rotary_embedding( + positions, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + is_neox_style, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_cuda(positions, query.clone(), key.clone()), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms - # batched queries of the same type together for non-batched RoPE - queries = [query[query_types == i] for i in range(len(scaling_factors))] - keys = [key[query_types == i] for i in range(len(scaling_factors))] - packed_qkr = zip(queries, keys, non_batched_ropes) - # synchronize before start timing - torch.cuda.synchronize() - with nvtx.annotate("non-batched", color="yellow"): - for q, k, r in packed_qkr: - r.forward(positions, q, k) - torch.cuda.synchronize() - with nvtx.annotate("batched", color="green"): - batched_rope.forward(positions, query, key, flatten_offsets) - torch.cuda.synchronize() + return benchmark if __name__ == "__main__": @@ -116,17 +95,12 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument( "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" ) + parser.add_argument("--save-path", type=str, default="./configs/rope/") args = parser.parse_args() - print(args) - benchmark_rope_kernels_multi_lora( - is_neox_style=args.is_neox_style, - batch_size=args.batch_size, - seq_len=args.seq_len, - num_heads=args.num_heads, - head_size=args.head_size, - rotary_dim=args.rotary_dim, - dtype=getattr(torch, args.dtype), - seed=args.seed, - device=args.device, + # Get the benchmark function + benchmark = get_benchmark( + args.head_size, args.rotary_dim, args.is_neox_style, args.device ) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) From 8d706cca903a008169e7ac8f1dc1f65c8ffd85c0 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 10 Nov 2025 19:41:23 -0800 Subject: [PATCH 40/49] [Misc] FlattenLogprobs -> FlatLogprobs (#28335) --- tests/samplers/test_logprobs.py | 16 +++++-------- tests/test_logprobs.py | 40 ++++++++++++++++----------------- vllm/envs.py | 8 +++---- vllm/logprobs.py | 26 ++++++++++----------- 4 files changed, 43 insertions(+), 47 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 87f5d40ac1da..c9d227599cde 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -4,7 +4,7 @@ import pytest from vllm import SamplingParams -from vllm.logprobs import FlattenLogprobs +from vllm.logprobs import FlatLogprobs MODELS = ["distilbert/distilgpt2"] MAX_TOKENS = 5 @@ -16,17 +16,17 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("greedy", [True, False]) -@pytest.mark.parametrize("flatten_logprobs", [True, False]) +@pytest.mark.parametrize("flat_logprobs", [True, False]) def test_ranks( vllm_runner, model, dtype, greedy, - flatten_logprobs, + flat_logprobs, example_prompts, 
monkeypatch: pytest.MonkeyPatch, ): - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0") with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model: tokenizer = vllm_model.llm.get_tokenizer() example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts] @@ -44,12 +44,8 @@ def test_ranks( decode_tokens, _, decode_logprobs, prompt_logprobs = result # Ensure the return type of logprobs is accurate - assert isinstance( - prompt_logprobs, FlattenLogprobs if flatten_logprobs else list - ) - assert isinstance( - decode_logprobs, FlattenLogprobs if flatten_logprobs else list - ) + assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list) + assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list) ######################## # Check prompt logprobs diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py index 1799d3638178..d26a460d2bca 100644 --- a/tests/test_logprobs.py +++ b/tests/test_logprobs.py @@ -5,7 +5,7 @@ import pytest from vllm.logprobs import ( - FlattenLogprobs, + FlatLogprobs, Logprob, LogprobsOnePosition, append_logprobs_for_next_position, @@ -14,8 +14,8 @@ ) -def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0") +def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0") prompt_logprobs = create_prompt_logprobs() assert isinstance(prompt_logprobs, list) @@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert len(sample_logprobs) == 0 -def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") +def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1") prompt_logprobs = create_prompt_logprobs() - assert isinstance(prompt_logprobs, FlattenLogprobs) + assert isinstance(prompt_logprobs, FlatLogprobs) assert prompt_logprobs.start_indices == [0] assert prompt_logprobs.end_indices == [0] assert len(prompt_logprobs.token_ids) == 0 @@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert prompt_logprobs[0] == dict() sample_logprobs = create_sample_logprobs() - assert isinstance(sample_logprobs, FlattenLogprobs) + assert isinstance(sample_logprobs, FlatLogprobs) assert len(sample_logprobs.start_indices) == 0 assert len(sample_logprobs.end_indices) == 0 assert len(sample_logprobs.token_ids) == 0 @@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert len(sample_logprobs) == 0 -def test_append_logprobs_for_next_position_none_flatten( +def test_append_logprobs_for_next_position_none_flat( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0") logprobs = create_sample_logprobs() append_logprobs_for_next_position( logprobs, @@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten( ] -def test_append_logprobs_for_next_position_flatten( +def test_append_logprobs_for_next_position_flat( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1") logprobs = create_sample_logprobs() append_logprobs_for_next_position( logprobs, @@ -106,7 
+106,7 @@ def test_append_logprobs_for_next_position_flatten( rank=11, num_logprobs=-1, ) - assert isinstance(logprobs, FlattenLogprobs) + assert isinstance(logprobs, FlatLogprobs) assert logprobs.start_indices == [0, 1] assert logprobs.end_indices == [1, 3] assert logprobs.token_ids == [1, 2, 3] @@ -129,8 +129,8 @@ def test_append_logprobs_for_next_position_flatten( } -def test_flatten_logprobs_append() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_append() -> None: + logprobs = FlatLogprobs() logprobs.append(LOGPROBS_ONE_POSITION_0) logprobs.append(LOGPROBS_ONE_POSITION_1) assert logprobs.start_indices == [0, 1] @@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None: assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"] -def test_flatten_logprobs_extend() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_extend() -> None: + logprobs = FlatLogprobs() # Extend with list[LogprobsOnePosition] logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]) assert logprobs.start_indices == [0, 3] @@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None: assert logprobs.ranks == [40, 50, 60, 10] assert logprobs.decoded_tokens == ["40", "50", "60", "10"] - other_logprobs = FlattenLogprobs() + other_logprobs = FlatLogprobs() other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0]) - # Extend with another FlattenLogprobs + # Extend with another FlatLogprobs logprobs.extend(other_logprobs) assert logprobs.start_indices == [0, 3, 4, 6] assert logprobs.end_indices == [3, 4, 6, 7] @@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None: assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"] -def test_flatten_logprobs_access() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_access() -> None: + logprobs = FlatLogprobs() logprobs.extend( [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0] ) diff --git a/vllm/envs.py b/vllm/envs.py index 9421488051e5..52178e5f5250 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -223,7 +223,7 @@ VLLM_GC_DEBUG: str = "" VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" - VLLM_FLATTEN_LOGPROBS: bool = False + VLLM_FLAT_LOGPROBS: bool = False def get_default_cache_root(): @@ -1481,11 +1481,11 @@ def get_vllm_port() -> int | None: "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices( "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"] ), - # Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than + # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than # the original list[dict[int, Logprob]] approach. # After enabled, PromptLogprobs and SampleLogprobs would populated as - # FlattenLogprobs. - "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))), + # FlatLogprobs. + "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))), } # --8<-- [end:env-vars-definition] diff --git a/vllm/logprobs.py b/vllm/logprobs.py index bf66e5f75c79..a34398db2c96 100644 --- a/vllm/logprobs.py +++ b/vllm/logprobs.py @@ -30,16 +30,16 @@ class Logprob: @dataclass -class FlattenLogprobs(MutableSequence[LogprobsOnePosition]): +class FlatLogprobs(MutableSequence[LogprobsOnePosition]): """ - Flatten logprobs of a request into multiple primitive type lists. + Flat logprobs of a request into multiple primitive type lists. 
Compared to list[dict[int, Logprob]], this data structure reduced GC overhead significantly. As it flattened logprob information for all positions and ranks in to multiple primitive type lists (i.e. logprobs, token_ids, ranks per token_ids, decoded_tokens). So regardless of the sequence length and top_logprobs setup, - FlattenLogprobs would only introduce a constant amount of objects. + FlatLogprobs would only introduce a constant amount of objects. As each position might contains different amount of ranks, start_indices_per_position would be used to access the logprob ranges @@ -107,7 +107,7 @@ def __len__(self) -> int: def __getitem__(self, position: int) -> LogprobsOnePosition: ... @overload - def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ... + def __getitem__(self, s: slice, /) -> "FlatLogprobs": ... def __getitem__(self, index: int | slice): """Extracts logprobs of a given position or slice""" @@ -123,7 +123,7 @@ def __getitem__(self, index: int | slice): elif isinstance(index, slice): min_index = self.start_indices[index][0] max_index = self.end_indices[index][-1] - return FlattenLogprobs( + return FlatLogprobs( # Shift updated start_indices and end_indices to # be 0-indexed start_indices=[i - min_index for i in self.start_indices[index]], @@ -137,13 +137,13 @@ def __getitem__(self, index: int | slice): raise TypeError(f"Invalid index type: {type(index)}") def __setitem__(self, item, value) -> None: - raise TypeError("Cannot set logprobs in FlattenLogprobs") + raise TypeError("Cannot set logprobs in FlatLogprobs") def __delitem__(self, item) -> None: - raise TypeError("Cannot delete logprobs from FlattenLogprobs") + raise TypeError("Cannot delete logprobs from FlatLogprobs") def insert(self, item) -> None: - raise TypeError("Cannot insert logprobs to FlattenLogprobs") + raise TypeError("Cannot insert logprobs to FlatLogprobs") def __iter__(self) -> Iterator[LogprobsOnePosition]: """ @@ -156,14 +156,14 @@ def __iter__(self) -> Iterator[LogprobsOnePosition]: # {token_id -> logprob} per each sequence group. None if the corresponding # sequence group doesn't require prompt logprob. -PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None] +PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None] # {token_id -> logprob} for each sequence group. -SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition] +SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition] def create_prompt_logprobs() -> PromptLogprobs: """Creates a container to store prompt logprobs for a request""" - logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] + logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else [] # NOTE: logprob of first prompt token is None. 
logprobs.append(None) return logprobs @@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs: def create_sample_logprobs() -> SampleLogprobs: """Creates a container to store decode logprobs for a request""" - return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] + return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else [] def append_logprobs_for_next_position( @@ -191,7 +191,7 @@ def append_logprobs_for_next_position( topk_ranks = range(1, num_logprobs + 1) ranks = itertools.chain((rank,), topk_ranks) - if isinstance(request_logprobs, FlattenLogprobs): + if isinstance(request_logprobs, FlatLogprobs): request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens) else: request_logprobs.append( From bca74e32b7ef03515cda508ba88151e2e547bdc9 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 10 Nov 2025 20:57:01 -0800 Subject: [PATCH 41/49] [Frontend] Add sagemaker_standards dynamic lora adapter and stateful session management decorators to vLLM OpenAI API server (#27892) Signed-off-by: Zuyi Zhao Signed-off-by: Shen Teng Co-authored-by: Shen Teng Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> --- requirements/common.txt | 1 + tests/entrypoints/sagemaker/__init__.py | 0 tests/entrypoints/sagemaker/conftest.py | 58 ++ .../test_sagemaker_handler_overrides.py | 734 ++++++++++++++++++ .../sagemaker/test_sagemaker_lora_adapters.py | 171 ++++ .../test_sagemaker_middleware_integration.py | 346 +++++++++ .../test_sagemaker_stateful_sessions.py | 153 ++++ vllm/entrypoints/dynamic_lora.py | 57 ++ vllm/entrypoints/openai/api_server.py | 100 +-- vllm/entrypoints/sagemaker/__init__.py | 4 + vllm/entrypoints/sagemaker/routes.py | 72 ++ 11 files changed, 1613 insertions(+), 83 deletions(-) create mode 100644 tests/entrypoints/sagemaker/__init__.py create mode 100644 tests/entrypoints/sagemaker/conftest.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py create mode 100644 vllm/entrypoints/dynamic_lora.py create mode 100644 vllm/entrypoints/sagemaker/__init__.py create mode 100644 vllm/entrypoints/sagemaker/routes.py diff --git a/requirements/common.txt b/requirements/common.txt index 8009581f62a4..90efb79a845d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -49,3 +49,4 @@ cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss anthropic == 0.71.0 +model-hosting-container-standards < 1.0.0 \ No newline at end of file diff --git a/tests/entrypoints/sagemaker/__init__.py b/tests/entrypoints/sagemaker/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/entrypoints/sagemaker/conftest.py b/tests/entrypoints/sagemaker/conftest.py new file mode 100644 index 000000000000..4c859c2527d2 --- /dev/null +++ b/tests/entrypoints/sagemaker/conftest.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared fixtures and utilities for SageMaker tests.""" + +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# Model name constants used across tests +MODEL_NAME_ZEPHYR = 
"HuggingFaceH4/zephyr-7b-beta" +MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct" +LORA_ADAPTER_NAME_SMOLLM = "jekunz/smollm-135m-lora-fineweb-faroese" + +# SageMaker header constants +HEADER_SAGEMAKER_CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id" +HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id" +HEADER_SAGEMAKER_NEW_SESSION_ID = "X-Amzn-SageMaker-New-Session-Id" + + +@pytest.fixture(scope="session") +def smollm2_lora_files(): + """Download LoRA files once per test session.""" + from huggingface_hub import snapshot_download + + return snapshot_download(repo_id=LORA_ADAPTER_NAME_SMOLLM) + + +@pytest.fixture(scope="module") +def basic_server_with_lora(smollm2_lora_files): + """Basic server fixture with standard configuration.""" + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--max-lora-rank", + "256", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "64", + ] + + envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"} + with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=envs) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def async_client(basic_server_with_lora: RemoteOpenAIServer): + """Async OpenAI client fixture for use with basic_server.""" + async with basic_server_with_lora.get_async_client() as async_client: + yield async_client diff --git a/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py b/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py new file mode 100644 index 000000000000..0d4f8e885824 --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py @@ -0,0 +1,734 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration tests for handler override functionality. + +Tests real customer usage scenarios: +- Using @custom_ping_handler and @custom_invocation_handler decorators + to override handlers +- Setting environment variables for handler specifications +- Writing customer scripts with custom_sagemaker_ping_handler() and + custom_sagemaker_invocation_handler() functions +- Priority: env vars > decorators > customer script files > framework + defaults + +Note: These tests focus on validating server responses rather than directly calling +get_ping_handler() and get_invoke_handler() to ensure full integration testing. +""" + +import os +import tempfile + +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + MODEL_NAME_SMOLLM, +) + + +class TestHandlerOverrideIntegration: + """Integration tests simulating real customer usage scenarios. + + Each test simulates a fresh server startup where customers: + - Use @custom_ping_handler and @custom_invocation_handler decorators + - Set environment variables (CUSTOM_FASTAPI_PING_HANDLER, etc.) 
+ - Write customer scripts with custom_sagemaker_ping_handler() and + custom_sagemaker_invocation_handler() functions + """ + + def setup_method(self): + """Setup for each test - simulate fresh server startup.""" + self._clear_caches() + self._clear_env_vars() + + def teardown_method(self): + """Cleanup after each test.""" + self._clear_env_vars() + + def _clear_caches(self): + """Clear handler registry and function loader cache.""" + try: + from model_hosting_container_standards.common.handler import ( + handler_registry, + ) + from model_hosting_container_standards.sagemaker.sagemaker_loader import ( + SageMakerFunctionLoader, + ) + + handler_registry.clear() + SageMakerFunctionLoader._default_function_loader = None + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + def _clear_env_vars(self): + """Clear SageMaker environment variables.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + + # Clear SageMaker env vars + for var in [ + SageMakerEnvVars.SAGEMAKER_MODEL_PATH, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME, + ]: + os.environ.pop(var, None) + + # Clear FastAPI env vars + for var in [ + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER, + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER, + ]: + os.environ.pop(var, None) + except ImportError: + pass + + @pytest.mark.asyncio + async def test_customer_script_functions_auto_loaded(self): + """Test customer scenario: script functions automatically override + framework defaults.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script file with ping() and invoke() functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "customer_override", + "message": "Custom ping from customer script" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Custom response from customer script"], + "source": "customer_override" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Customer sets SageMaker environment variables to point to their script + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Customer tests their server and sees their overrides work + # automatically + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Customer sees their functions are used + assert ping_data["source"] == "customer_override" + assert ping_data["message"] == "Custom ping from customer script" + 
assert invoke_data["source"] == "customer_override" + assert invoke_data["predictions"] == [ + "Custom response from customer script" + ] + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_customer_decorator_usage(self): + """Test customer scenario: using @custom_ping_handler and + @custom_invocation_handler decorators.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script file with decorators + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request + +@sagemaker_standards.custom_ping_handler +async def my_ping(): + return { + "type": "ping", + "source": "customer_decorator" + } + +@sagemaker_standards.custom_invocation_handler +async def my_invoke(request: Request): + return { + "type": "invoke", + "source": "customer_decorator" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Customer sees their handlers are used by the server + assert ping_data["source"] == "customer_decorator" + assert invoke_data["source"] == "customer_decorator" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_handler_priority_order(self): + """Test priority: @custom_ping_handler/@custom_invocation_handler + decorators vs script functions.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script with both decorator and regular functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request + +# Customer uses @custom_ping_handler decorator (higher priority than script functions) +@sagemaker_standards.custom_ping_handler +async def decorated_ping(): + return { + "status": "healthy", + "source": "ping_decorator_in_script", + "priority": "decorator" + } + +# Customer also has a regular function (lower priority than +# @custom_ping_handler decorator) +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script_function", + "priority": "function" + } + +# Customer has a regular invoke function +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script_invoke_function", + "priority": "function" + } +""" + ) + 
script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # @custom_ping_handler decorator has higher priority than + # script function + assert ping_data["source"] == "ping_decorator_in_script" + assert ping_data["priority"] == "decorator" + + # Script function is used for invoke + assert invoke_data["source"] == "script_invoke_function" + assert invoke_data["priority"] == "function" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_environment_variable_script_loading(self): + """Test that environment variables correctly specify script location + and loading.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script in a specific directory + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "env_loaded_script", + "method": "environment_variable_loading" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Loaded via environment variables"], + "source": "env_loaded_script", + "method": "environment_variable_loading" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Test environment variable script loading + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Verify that the script was loaded via environment variables + assert ping_data["source"] == "env_loaded_script" + assert ping_data["method"] == "environment_variable_loading" + assert invoke_data["source"] == "env_loaded_script" + assert invoke_data["method"] == "environment_variable_loading" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_framework_default_handlers(self): + """Test that framework default handlers work when no 
customer + overrides exist.""" + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + # Explicitly pass empty env_dict to ensure no SageMaker env vars are set + # This prevents pollution from previous tests + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + + env_dict = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: "", + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: "", + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: "", + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: "", + } + except ImportError: + env_dict = {} + + with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=env_dict) as server: + # Test that default ping works + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + + # Test that default invocations work + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + + @pytest.mark.asyncio + async def test_handler_env_var_override(self): + """Test CUSTOM_FASTAPI_PING_HANDLER and CUSTOM_FASTAPI_INVOCATION_HANDLER + environment variable overrides.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with both env var handlers and script functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request, Response +import json + +async def env_var_ping_handler(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "env_var_ping", + "method": "environment_variable" + }), + media_type="application/json" + ) + +async def env_var_invoke_handler(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Environment variable response"], + "source": "env_var_invoke", + "method": "environment_variable" + }), + media_type="application/json" + ) + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script_ping", + "method": "script_function" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script_invoke", + "method": "script_function" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to override both handlers + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: ( + f"{script_name}:env_var_ping_handler" + ), + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: ( + f"{script_name}:env_var_invoke_handler" + ), + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping handler override + ping_response = requests.get(server.url_for("ping")) + 
assert ping_response.status_code == 200 + ping_data = ping_response.json() + + # Environment variable should override script function + assert ping_data["method"] == "environment_variable" + assert ping_data["source"] == "env_var_ping" + + # Test invocation handler override + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Environment variable should override script function + assert invoke_data["method"] == "environment_variable" + assert invoke_data["source"] == "env_var_invoke" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_env_var_priority_over_decorator_and_script(self): + """Test that environment variables have highest priority over decorators + and script functions for both ping and invocation handlers.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with all three handler types for both ping and invocation + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request, Response +import json + +# Environment variable handlers (highest priority) +async def env_priority_ping(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "env_var", + "priority": "environment_variable" + }), + media_type="application/json" + ) + +async def env_priority_invoke(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Environment variable response"], + "source": "env_var", + "priority": "environment_variable" + }), + media_type="application/json" + ) + +# Decorator handlers (medium priority) +@sagemaker_standards.custom_ping_handler +async def decorator_ping(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "decorator", + "priority": "decorator" + }), + media_type="application/json" + ) + +@sagemaker_standards.custom_invocation_handler +async def decorator_invoke(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Decorator response"], + "source": "decorator", + "priority": "decorator" + }), + media_type="application/json" + ) + +# Script functions (lowest priority) +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script", + "priority": "script_function" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script", + "priority": "script_function" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to specify highest priority handlers + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: ( + f"{script_name}:env_priority_ping" + ), + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: ( + 
f"{script_name}:env_priority_invoke" + ), + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping handler priority + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + # Environment variable has highest priority and should be used + assert ping_data["priority"] == "environment_variable" + assert ping_data["source"] == "env_var" + + # Test invocation handler priority + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Environment variable has highest priority and should be used + assert invoke_data["priority"] == "environment_variable" + assert invoke_data["source"] == "env_var" + + finally: + os.unlink(script_path) diff --git a/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py b/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py new file mode 100644 index 000000000000..a2867efdc584 --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import openai # use the official async_client for correctness check +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import MODEL_NAME_SMOLLM + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_happy_path( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # The SageMaker standards library creates a POST /adapters endpoint + # that maps to the load_lora_adapter handler with request shape: + # {"lora_name": "body.name", "lora_path": "body.src"} + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "smollm2-lora-sagemaker", "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + models = await async_client.models.list() + models = models.data + dynamic_lora_model = models[-1] + assert dynamic_lora_model.root == smollm2_lora_files + assert dynamic_lora_model.parent == MODEL_NAME_SMOLLM + assert dynamic_lora_model.id == "smollm2-lora-sagemaker" + + +@pytest.mark.asyncio +async def test_sagemaker_unload_adapter_happy_path( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # First, load an adapter + adapter_name = "smollm2-lora-sagemaker-unload" + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Verify it's in the models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + assert adapter_name in adapter_ids + + # Now unload it using DELETE /adapters/{adapter_name} + # The SageMaker standards maps this to unload_lora_adapter with: + # {"lora_name": "path_params.adapter_name"} + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", adapter_name), + ) + unload_response.raise_for_status() + + # Verify it's no longer in the models list + models = await async_client.models.list() + adapter_ids = 
[model.id for model in models.data] + assert adapter_name not in adapter_ids + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_not_found( + basic_server_with_lora: RemoteOpenAIServer, +): + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "nonexistent-adapter", "src": "/path/does/not/exist"}, + ) + assert load_response.status_code == 404 + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_invalid_files( + basic_server_with_lora: RemoteOpenAIServer, + tmp_path, +): + invalid_files = tmp_path / "invalid_adapter" + invalid_files.mkdir() + (invalid_files / "adapter_config.json").write_text("not valid json") + + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "invalid-adapter", "src": str(invalid_files)}, + ) + assert load_response.status_code == 400 + + +@pytest.mark.asyncio +async def test_sagemaker_unload_nonexistent_adapter( + basic_server_with_lora: RemoteOpenAIServer, +): + # Attempt to unload an adapter that doesn't exist + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", "nonexistent-adapter-name"), + ) + assert unload_response.status_code in (400, 404) + + +@pytest.mark.asyncio +async def test_sagemaker_invocations_with_adapter( + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # First, load an adapter via SageMaker endpoint + adapter_name = "smollm2-lora-invoke-test" + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Now test the /invocations endpoint with the adapter + invocation_response = requests.post( + basic_server_with_lora.url_for("invocations"), + headers={ + "X-Amzn-SageMaker-Adapter-Identifier": adapter_name, + }, + json={ + "prompt": "Hello, how are you?", + "max_tokens": 10, + }, + ) + invocation_response.raise_for_status() + invocation_output = invocation_response.json() + + # Verify we got a valid completion response + assert "choices" in invocation_output + assert len(invocation_output["choices"]) > 0 + assert "text" in invocation_output["choices"][0] + + +@pytest.mark.asyncio +async def test_sagemaker_multiple_adapters_load_unload( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + adapter_names = [f"sagemaker-adapter-{i}" for i in range(5)] + + # Load all adapters + for adapter_name in adapter_names: + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Verify all are in the models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + for adapter_name in adapter_names: + assert adapter_name in adapter_ids + + # Unload all adapters + for adapter_name in adapter_names: + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", adapter_name), + ) + unload_response.raise_for_status() + + # Verify all are removed from models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + for adapter_name in adapter_names: + assert adapter_name not in adapter_ids diff --git a/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py b/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py new file mode 100644 index 000000000000..f1ed0c7e2897 --- /dev/null +++ 
b/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration test for middleware loader functionality. + +Tests that customer middlewares get called correctly with a vLLM server. +""" + +import os +import tempfile + +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + MODEL_NAME_SMOLLM, +) + + +class TestMiddlewareIntegration: + """Integration test for middleware with vLLM server.""" + + def setup_method(self): + """Setup for each test - simulate fresh server startup.""" + self._clear_caches() + + def _clear_caches(self): + """Clear middleware registry and function loader cache.""" + try: + from model_hosting_container_standards.common.fastapi.middleware import ( + middleware_registry, + ) + from model_hosting_container_standards.common.fastapi.middleware.source.decorator_loader import ( # noqa: E501 + decorator_loader, + ) + from model_hosting_container_standards.sagemaker.sagemaker_loader import ( + SageMakerFunctionLoader, + ) + + middleware_registry.clear_middlewares() + decorator_loader.clear() + SageMakerFunctionLoader._default_function_loader = None + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + @pytest.mark.asyncio + async def test_customer_middleware_with_vllm_server(self): + """Test that customer middlewares work with actual vLLM server. + + Tests decorator-based middlewares (@custom_middleware, @input_formatter, + @output_formatter) + on multiple endpoints (chat/completions, invocations). + """ + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a middleware script with multiple decorators + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from model_hosting_container_standards.common.fastapi.middleware import ( + custom_middleware, input_formatter, output_formatter +) + +# Global flag to track if input formatter was called +_input_formatter_called = False + +@input_formatter +async def customer_input_formatter(request): + # Process input - mark that input formatter was called + global _input_formatter_called + _input_formatter_called = True + return request + +@custom_middleware("throttle") +async def customer_throttle_middleware(request, call_next): + response = await call_next(request) + response.headers["X-Customer-Throttle"] = "applied" + order = response.headers.get("X-Middleware-Order", "") + response.headers["X-Middleware-Order"] = order + "throttle," + return response + +@output_formatter +async def customer_output_formatter(response): + global _input_formatter_called + response.headers["X-Customer-Processed"] = "true" + # Since input_formatter and output_formatter are combined into + # pre_post_process middleware, + # if output_formatter is called, input_formatter should have been called too + if _input_formatter_called: + response.headers["X-Input-Formatter-Called"] = "true" + order = response.headers.get("X-Middleware-Order", "") + response.headers["X-Middleware-Order"] = order + "output_formatter," + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to point to customer script + env_vars = { + 
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test 1: Middlewares applied to chat/completions endpoint + chat_response = requests.post( + server.url_for("v1/chat/completions"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "temperature": 0.0, + }, + ) + + assert chat_response.status_code == 200 + + # Verify all middlewares were executed + assert "X-Customer-Throttle" in chat_response.headers + assert chat_response.headers["X-Customer-Throttle"] == "applied" + assert "X-Customer-Processed" in chat_response.headers + assert chat_response.headers["X-Customer-Processed"] == "true" + + # Verify input formatter was called + assert "X-Input-Formatter-Called" in chat_response.headers + assert chat_response.headers["X-Input-Formatter-Called"] == "true" + + # Verify middleware execution order + execution_order = chat_response.headers.get( + "X-Middleware-Order", "" + ).rstrip(",") + order_parts = execution_order.split(",") if execution_order else [] + assert "throttle" in order_parts + assert "output_formatter" in order_parts + + # Test 2: Middlewares applied to invocations endpoint + invocations_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "temperature": 0.0, + }, + ) + + assert invocations_response.status_code == 200 + + # Verify all middlewares were executed + assert "X-Customer-Throttle" in invocations_response.headers + assert invocations_response.headers["X-Customer-Throttle"] == "applied" + assert "X-Customer-Processed" in invocations_response.headers + assert invocations_response.headers["X-Customer-Processed"] == "true" + + # Verify input formatter was called + assert "X-Input-Formatter-Called" in invocations_response.headers + assert ( + invocations_response.headers["X-Input-Formatter-Called"] == "true" + ) + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_middleware_with_ping_endpoint(self): + """Test that middlewares work with SageMaker ping endpoint.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a middleware script + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from model_hosting_container_standards.common.fastapi.middleware import ( + custom_middleware +) + +@custom_middleware("pre_post_process") +async def ping_tracking_middleware(request, call_next): + response = await call_next(request) + if request.url.path == "/ping": + response.headers["X-Ping-Tracked"] = "true" + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping endpoint with 
middleware + response = requests.get(server.url_for("ping")) + + assert response.status_code == 200 + assert "X-Ping-Tracked" in response.headers + assert response.headers["X-Ping-Tracked"] == "true" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_middleware_env_var_override(self): + """Test middleware environment variable overrides.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with middleware functions specified via env vars + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +# Global flag to track if pre_process was called +_pre_process_called = False + +async def env_throttle_middleware(request, call_next): + response = await call_next(request) + response.headers["X-Env-Throttle"] = "applied" + return response + +async def env_pre_process(request: Request) -> Request: + # Mark that pre_process was called + global _pre_process_called + _pre_process_called = True + return request + +async def env_post_process(response): + global _pre_process_called + if hasattr(response, 'headers'): + response.headers["X-Env-Post-Process"] = "applied" + # Since pre_process and post_process are combined into + # pre_post_process middleware, + # if post_process is called, pre_process should have been called too + if _pre_process_called: + response.headers["X-Pre-Process-Called"] = "true" + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables for middleware + # Use script_name with .py extension as per plugin example + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_MIDDLEWARE_THROTTLE: ( + f"{script_name}:env_throttle_middleware" + ), + FastAPIEnvVars.CUSTOM_PRE_PROCESS: f"{script_name}:env_pre_process", + FastAPIEnvVars.CUSTOM_POST_PROCESS: f"{script_name}:env_post_process", + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + response = requests.get(server.url_for("ping")) + assert response.status_code == 200 + + # Check if environment variable middleware was applied + headers = response.headers + + # Verify that env var middlewares were applied + assert "X-Env-Throttle" in headers, ( + "Throttle middleware should be applied via env var" + ) + assert headers["X-Env-Throttle"] == "applied" + + assert "X-Env-Post-Process" in headers, ( + "Post-process middleware should be applied via env var" + ) + assert headers["X-Env-Post-Process"] == "applied" + + # Verify that pre_process was called + assert "X-Pre-Process-Called" in headers, ( + "Pre-process should be called via env var" + ) + assert headers["X-Pre-Process-Called"] == "true" + + finally: + os.unlink(script_path) diff --git a/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py b/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py new file mode 100644 index 000000000000..6206000385bd --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py @@ -0,0 
+1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import openai # use the official client for correctness check +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + HEADER_SAGEMAKER_CLOSED_SESSION_ID, + HEADER_SAGEMAKER_NEW_SESSION_ID, + HEADER_SAGEMAKER_SESSION_ID, + MODEL_NAME_SMOLLM, +) + +CLOSE_BADREQUEST_CASES = [ + ( + "nonexistent_session_id", + {"session_id": "nonexistent-session-id"}, + {}, + "session not found", + ), + ("malformed_close_request", {}, {"extra-field": "extra-field-data"}, None), +] + + +@pytest.mark.asyncio +async def test_create_session_badrequest(basic_server_with_lora: RemoteOpenAIServer): + bad_response = requests.post( + basic_server_with_lora.url_for("invocations"), + json={"requestType": "NEW_SESSION", "extra-field": "extra-field-data"}, + ) + + assert bad_response.status_code == 400 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_name,session_id_change,request_body_change,expected_error", + CLOSE_BADREQUEST_CASES, +) +async def test_close_session_badrequest( + basic_server_with_lora: RemoteOpenAIServer, + test_name: str, + session_id_change: dict[str, str], + request_body_change: dict[str, str], + expected_error: str | None, +): + # first attempt to create a session + url = basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + close_request_json = {"requestType": "CLOSE"} + if request_body_change: + close_request_json.update(request_body_change) + bad_session_id = session_id_change.get("session_id") + bad_close_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: bad_session_id or valid_session_id}, + json=close_request_json, + ) + + # clean up created session, should succeed + clean_up_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + clean_up_response.raise_for_status() + + assert bad_close_response.status_code == 400 + if expected_error: + assert expected_error in bad_close_response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_close_session_invalidrequest( + basic_server_with_lora: RemoteOpenAIServer, async_client: openai.AsyncOpenAI +): + # first attempt to create a session + url = basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + close_request_json = {"requestType": "CLOSE"} + invalid_close_response = requests.post( + url, + # no headers to specify session_id + json=close_request_json, + ) + + # clean up created session, should succeed + clean_up_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + clean_up_response.raise_for_status() + + assert invalid_close_response.status_code == 424 + assert "invalid session_id" in invalid_close_response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_session(basic_server_with_lora: RemoteOpenAIServer): + # first attempt to create a session + url = 
basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + # test invocation with session id + + request_args = { + "model": MODEL_NAME_SMOLLM, + "prompt": "what is 1+1?", + "max_completion_tokens": 5, + "temperature": 0.0, + "logprobs": False, + } + + invocation_response = requests.post( + basic_server_with_lora.url_for("invocations"), + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json=request_args, + ) + invocation_response.raise_for_status() + + # close created session, should succeed + close_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + close_response.raise_for_status() + + assert ( + close_response.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID) + == valid_session_id + ) diff --git a/vllm/entrypoints/dynamic_lora.py b/vllm/entrypoints/dynamic_lora.py new file mode 100644 index 000000000000..cc0f437e5c77 --- /dev/null +++ b/vllm/entrypoints/dynamic_lora.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import models, validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + LoadLoRAAdapterRequest, + UnloadLoRAAdapterRequest, +) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def register_dynamic_lora_routes(router: APIRouter): + @sagemaker_standards.register_load_adapter_handler( + request_shape={ + "lora_name": "body.name", + "lora_path": "body.src", + }, + ) + @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) + async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): + handler: OpenAIServingModels = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + + return Response(status_code=200, content=response) + + @sagemaker_standards.register_unload_adapter_handler( + request_shape={ + "lora_name": "path_params.adapter_name", + } + ) + @router.post( + "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] + ) + async def unload_lora_adapter( + request: UnloadLoRAAdapterRequest, raw_request: Request + ): + handler: OpenAIServingModels = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + + return Response(status_code=200, content=response) + + return router diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 51191879e478..fbb2d32a229d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -19,6 +19,7 @@ from http import HTTPStatus from typing import Annotated, Any, Literal +import model_hosting_container_standards.sagemaker as sagemaker_standards import 
prometheus_client import pydantic import regex as re @@ -65,7 +66,6 @@ ErrorInfo, ErrorResponse, IOProcessorResponse, - LoadLoRAAdapterRequest, PoolingBytesResponse, PoolingRequest, PoolingResponse, @@ -82,7 +82,6 @@ TranscriptionResponse, TranslationRequest, TranslationResponse, - UnloadLoRAAdapterRequest, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_classification import ServingClassification @@ -387,13 +386,6 @@ async def get_server_load_metrics(request: Request): return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) -@router.get("/ping", response_class=Response) -@router.post("/ping", response_class=Response) -async def ping(raw_request: Request) -> Response: - """Ping check. Endpoint required for SageMaker""" - return await health(raw_request) - - @router.post( "/tokenize", dependencies=[Depends(validate_json_request)], @@ -1236,47 +1228,6 @@ async def is_scaling_elastic_ep(raw_request: Request): ] -@router.post( - "/invocations", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def invocations(raw_request: Request): - """For SageMaker, routes requests based on the request type.""" - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}" - ) from e - - valid_endpoints = [ - (validator, endpoint) - for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS - if get_handler(raw_request) is not None - ] - - for request_validator, endpoint in valid_endpoints: - try: - request = request_validator.validate_python(body) - except pydantic.ValidationError: - continue - - return await endpoint(request, raw_request) - - type_names = [ - t.__name__ if isinstance(t := validator._type, type) else str(t) - for validator, _ in valid_endpoints - ] - msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" - res = base(raw_request).create_error_response(message=msg) - return JSONResponse(content=res.model_dump(), status_code=res.error.code) - - if envs.VLLM_TORCH_PROFILER_DIR: logger.warning_once( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -1304,39 +1255,6 @@ async def stop_profile(raw_request: Request): return Response(status_code=200) -if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - logger.warning( - "LoRA dynamic loading & unloading is enabled in the API server. " - "This should ONLY be used for local development!" 
- ) - - @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) - async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): - handler = models(raw_request) - response = await handler.load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse( - content=response.model_dump(), status_code=response.error.code - ) - - return Response(status_code=200, content=response) - - @router.post( - "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] - ) - async def unload_lora_adapter( - request: UnloadLoRAAdapterRequest, raw_request: Request - ): - handler = models(raw_request) - response = await handler.unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse( - content=response.model_dump(), status_code=response.error.code - ) - - return Response(status_code=200, content=response) - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None @@ -1606,6 +1524,20 @@ def build_app(args: Namespace) -> FastAPI: ) else: app = FastAPI(lifespan=lifespan) + + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!" + ) + from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes + + register_dynamic_lora_routes(router) + + from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes + + register_sagemaker_routes(router) + app.include_router(router) app.root_path = args.root_path @@ -1696,6 +1628,8 @@ async def log_response(request: Request, call_next): f"Invalid middleware {middleware}. Must be a function or a class." ) + app = sagemaker_standards.bootstrap(app) + return app diff --git a/vllm/entrypoints/sagemaker/__init__.py b/vllm/entrypoints/sagemaker/__init__.py new file mode 100644 index 000000000000..c1767137e4ea --- /dev/null +++ b/vllm/entrypoints/sagemaker/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""SageMaker-specific integration for vLLM.""" diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py new file mode 100644 index 000000000000..498b7294f0d8 --- /dev/null +++ b/vllm/entrypoints/sagemaker/routes.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from http import HTTPStatus + +import model_hosting_container_standards.sagemaker as sagemaker_standards +import pydantic +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import ( + INVOCATION_VALIDATORS, + base, + health, + validate_json_request, +) +from vllm.entrypoints.openai.protocol import ErrorResponse + + +def register_sagemaker_routes(router: APIRouter): + @router.post("/ping", response_class=Response) + @router.get("/ping", response_class=Response) + @sagemaker_standards.register_ping_handler + async def ping(raw_request: Request) -> Response: + """Ping check. 
Endpoint required for SageMaker""" + return await health(raw_request) + + @router.post( + "/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, + ) + @sagemaker_standards.register_invocation_handler + @sagemaker_standards.stateful_session_manager() + @sagemaker_standards.inject_adapter_id(adapter_path="model") + async def invocations(raw_request: Request): + """For SageMaker, routes requests based on the request type.""" + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + + valid_endpoints = [ + (validator, endpoint) + for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None + ] + + for request_validator, endpoint in valid_endpoints: + try: + request = request_validator.validate_python(body) + except pydantic.ValidationError: + continue + + return await endpoint(request, raw_request) + + type_names = [ + t.__name__ if isinstance(t := validator._type, type) else str(t) + for validator, _ in valid_endpoints + ] + msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" + res = base(raw_request).create_error_response(message=msg) + return JSONResponse(content=res.model_dump(), status_code=res.error.code) + + return router From e605e8e3233f895340f46665f93ab37b307491aa Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Tue, 11 Nov 2025 00:59:08 -0500 Subject: [PATCH 42/49] [Bugfix] Fix Stream Sync for Shared Expert Overlap (#28430) Signed-off-by: Vadim Gimpelson Signed-off-by: Robert Shaw Co-authored-by: Vadim Gimpelson --- .../gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml | 3 -- vllm/model_executor/layers/fused_moe/layer.py | 45 +++++++------------ 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml index ea9c95158405..9297bf6ddf2d 100644 --- a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml +++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml @@ -3,6 +3,3 @@ accuracy_threshold: 0.45 num_questions: 1319 num_fewshot: 5 max_model_len: 4096 -# Duo stream incompatabilbe with this model: https://github.com/vllm-project/vllm/issues/28220 -env: - VLLM_DISABLE_SHARED_EXPERTS_STREAM: "1" diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 27ad9c8fd1c2..39547cc83c7b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2456,28 +2456,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) - # If there are shared experts but we are not using a modular kernel, - # the shared experts must be called here - if has_separate_shared_experts: - assert self.shared_experts is not None - - if self.shared_experts_stream is not None: - # For chunked, we start the shared experts stream here - # (Note that no concurrency with the router/gate) - self.shared_experts_stream.wait_stream(current_stream()) - - with torch.cuda.stream(self.shared_experts_stream): - # 
Note that staged_hidden_states clone() is necessary - # here to avoid conflict with the main stream - shared_output = self.shared_experts( - staged_hidden_states.clone() - ) - else: - shared_output = self.shared_experts(staged_hidden_states) - - else: - shared_output = None - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -2506,11 +2484,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): if has_separate_shared_experts: assert not isinstance(final_hidden_states, tuple) assert self.shared_experts is not None - - # Here we finish the shared experts stream - if self.shared_experts_stream is not None: - current_stream().wait_stream(self.shared_experts_stream) - + shared_output = self.shared_experts(staged_hidden_states) final_hidden_states = ( shared_output, final_hidden_states, @@ -2619,11 +2593,22 @@ def forward_impl( assert self.shared_experts is not None if self.shared_experts_stream is not None: + # Clone BEFORE switching streams to avoid race condition + # where routed_expert kernel may mutate hidden_states. + hidden_states_clone = hidden_states.clone() + self.shared_experts_stream.wait_stream(current_stream()) + # Run shared experts in parallel on a separate stream with torch.cuda.stream(self.shared_experts_stream): - # Note that hidden_states clone() is necessary here to avoid - # conflict with the main stream - shared_output = self.shared_experts(hidden_states.clone()) + shared_output = self.shared_experts(hidden_states_clone) + + # Record that the clone will be used by shared_experts_stream + # to avoid gc issue from deallocation of hidden_states_clone + # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501 + # NOTE: we dont need shared_output.record_stream(current_stream()) + # because we synch the streams before using shared_output. + hidden_states_clone.record_stream(self.shared_experts_stream) + else: shared_output = self.shared_experts(hidden_states) else: From a7adbc6c6b4bcdef5cfffdcd06edf86fcbfb7c69 Mon Sep 17 00:00:00 2001 From: iAmir97 <71513472+iAmir97@users.noreply.github.com> Date: Tue, 11 Nov 2025 13:44:35 +0700 Subject: [PATCH 43/49] [Doc] Sleep mode documentation (#28357) Signed-off-by: Amir Balwel Signed-off-by: iAmir97 <71513472+iAmir97@users.noreply.github.com> Co-authored-by: Amir Balwel Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/features/sleep_mode.md | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md index e7dd9fee12d3..edcbaa716447 100644 --- a/docs/features/sleep_mode.md +++ b/docs/features/sleep_mode.md @@ -13,6 +13,9 @@ Key benefits: !!! note This feature is only supported on CUDA platform. +!!! note + For more information, see this [Blog Post](https://blog.vllm.ai/2025/10/26/sleep-mode.html). + ## Sleep levels Level 1 sleep will offload the model weights and discard the KV cache. The content of KV cache is forgotten. Level 1 sleep is good for sleeping and waking up the engine to run the same model again. The model weights are backed up in CPU memory. Please make sure there's enough CPU memory to store the model weights. Level 2 sleep will discard both the model weights and the KV cache (while the model's buffers are kept in CPU, like rope scaling tensors). The content of both the model weights and KV cache is forgotten. 
Level 2 sleep is good for sleeping and waking up the engine to run a different model or update the model, where previous model weights are not needed, e.g. RLHF weight update. @@ -31,6 +34,7 @@ llm = LLM("Qwen/Qwen3-0.6B", enable_sleep_mode=True) #### Python API ```python +# Sleep level 1 # Put the engine to sleep (level=1: offload weights to CPU RAM, discard KV cache) llm.sleep(level=1) @@ -38,6 +42,21 @@ llm.sleep(level=1) llm.wake_up() ``` +```python +# Sleep level 2 +# Put the engine to sleep (level=2: discard both weights and KV cache) +llm.sleep(level=2) + +# Reallocate weights memory only +llm.wake_up(tags=["weights"]) + +# Load weights in-place +llm.collective_rpc("reload_weights") + +# Reallocate KV cache +llm.wake_up(tags=["kv_cache"]) +``` + #### RLHF weight updates During RLHF training, vLLM allows you to selectively wake up only the model weights or the KV cache using the tags argument in wake_up(). This fine-grained control is especially useful when updating model weights: by waking up just the weights (e.g., llm.wake_up(tags=["weights"])), you avoid allocating memory for the KV cache until after the weight update is complete. This approach helps prevent GPU out-of-memory (OOM) errors, particularly with large models, by minimizing peak memory usage during weight synchronization and update operations. @@ -69,10 +88,30 @@ VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \ --port 8000 ``` +Below is an example of how to sleep and wake up a model in level 1. + +```bash +curl -X POST 'http://localhost:8000/sleep?level=1' +curl -X POST 'http://localhost:8000/wake_up' +``` + +And this is an example of how to sleep and wake up a model in level 2. + +```bash +curl -X POST 'http://localhost:8000/sleep?level=2' +# Reallocate weights memory only +curl -X POST 'http://localhost:8000/wake_up?tags=weights' +# Load weights in-place +curl -X POST 'http://localhost:8000/collective_rpc' -H 'Content-Type: application/json' -d '{"method":"reload_weights"}' +# Reallocate KV cache +curl -X POST 'http://localhost:8000/wake_up?tags=kv_cache' +``` + #### HTTP endpoints - `POST /sleep?level=1` — Put the model to sleep (`level=1`). - `POST /wake_up` — Wake up the model. Supports optional `tags` query parameters for partial wake-up (e.g., `?tags=weights`). +- `POST /collective_rpc` — Perform a collective remote procedure call (RPC). - `GET /is_sleeping` — Check if the model is sleeping. !!! 
note From cc079763c59adb8c03305663a5b8857ba85deb1b Mon Sep 17 00:00:00 2001 From: David Ben-David Date: Tue, 11 Nov 2025 09:39:36 +0200 Subject: [PATCH 44/49] [BugFix] Avoid calling KV connector layer APIs when metadata is unset (#28253) Signed-off-by: David Ben-David Co-authored-by: David Ben-David Co-authored-by: Mark McLoughlin --- vllm/attention/layer.py | 4 ++++ vllm/distributed/kv_transfer/kv_connector/v1/base.py | 9 ++++++++- .../kv_transfer/kv_connector/v1/multi_connector.py | 6 ++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 96272981692c..acab0529f352 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -837,6 +837,8 @@ def wait_for_kv_layer_from_connector(layer_name: str): return connector = get_kv_transfer_group() + if not connector.has_connector_metadata(): + return forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -854,6 +856,8 @@ def maybe_save_kv_layer_to_connector( return connector = get_kv_transfer_group() + if not connector.has_connector_metadata(): + return forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 354aa9a87183..f85eb414b222 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -204,11 +204,18 @@ def _get_connector_metadata(self) -> KVConnectorMetadata: Returns: ConnectorMetadata: the connector metadata. """ - # Should only be called while set to valid metadata. assert self._connector_metadata is not None return self._connector_metadata + def has_connector_metadata(self) -> bool: + """Check whether the connector metadata is currently set. + + Returns: + bool: True if connector metadata exists, False otherwise. + """ + return self._connector_metadata is not None + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ Initialize with the KV caches. Useful for pre-registering the diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d7bbf02c8367..c9d08e9b78ed 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -171,16 +171,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. + # + # Note: Call the base class method to ensure metadata is also set on the + # MultiConnector instance itself; otherwise, `has_connector_metadata()` will + # always return False. 
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: assert isinstance(connector_metadata, MultiKVConnectorMetadata) if connector_metadata.extra_async_saves: self._extra_async_saves.update(connector_metadata.extra_async_saves) for c, cm in zip(self._connectors, connector_metadata.metadata): c.bind_connector_metadata(cm) + super().bind_connector_metadata(connector_metadata) def clear_connector_metadata(self) -> None: for c in self._connectors: c.clear_connector_metadata() + super().clear_connector_metadata() def shutdown(self): exception: Exception | None = None From 4fd4b743a23cc6ccbd832f11be12317a8c2f0fbc Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 11 Nov 2025 00:07:24 -0800 Subject: [PATCH 45/49] [Bugfix] Fix max image size for PaddleOCR-VL (#28442) Signed-off-by: Roger Wang --- vllm/model_executor/models/paddleocr_vl.py | 36 +++++++++++++--------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 631475c964c0..12ae15699e7d 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -198,23 +198,18 @@ def get_num_image_tokens( if image_processor is None: image_processor = self.get_image_processor() - do_resize = True hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size - - if do_resize: - resized_height, resized_width = smart_resize( - height=image_height, - width=image_width, - factor=patch_size * merge_size, - min_pixels=image_processor.min_pixels, - max_pixels=image_processor.max_pixels, - ) - preprocessed_size = ImageSize(width=resized_width, height=resized_height) - else: - preprocessed_size = ImageSize(width=image_width, height=image_height) + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) grid_t = 1 grid_h = preprocessed_size.height // patch_size @@ -227,8 +222,19 @@ def get_num_image_tokens( def get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() - image_size = hf_config.vision_config.image_size - return ImageSize(height=image_size, width=image_size) + + # See `smart_resize` for the calculation of the image size. + merge_size = hf_config.vision_config.spatial_merge_size + patch_size = hf_config.vision_config.patch_size + factor = merge_size * patch_size + max_num_tokens = self.get_image_processor().max_pixels // (factor**2) + # Find factors of max_num_tokens close to its square root + # to create a dummy image with a reasonable aspect ratio. 
+ h_patches = int(math.sqrt(max_num_tokens)) + while max_num_tokens % h_patches != 0: + h_patches -= 1 + w_patches = max_num_tokens // h_patches + return ImageSize(height=h_patches * factor, width=w_patches * factor) class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]): From 798c7bebca5e3ea48b947af4cc7904a4507ba873 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 11 Nov 2025 00:19:51 -0800 Subject: [PATCH 46/49] [EPLB] Refactor balance_packing to use numpy and optimize GPU-CPU transfers in EPLB (#28369) Signed-off-by: Sage Moore --- vllm/distributed/eplb/rebalance_algo.py | 40 +++++++++++++++------- vllm/distributed/eplb/rebalance_execute.py | 14 +++++--- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index c9d30d6481ab..e6645e524cc3 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -12,6 +12,7 @@ on how the EPLB algorithm works. """ +import numpy as np import torch @@ -34,29 +35,44 @@ def balanced_packing( assert num_groups % num_packs == 0 groups_per_pack = num_groups // num_packs + device = weight.device + if groups_per_pack == 1: pack_index = torch.arange( - weight.size(-1), dtype=torch.int64, device=weight.device + weight.size(-1), dtype=torch.int64, device=device ).expand(weight.shape) - rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device) return pack_index, rank_in_pack - indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") - rank_in_pack = torch.full_like(pack_index, fill_value=-1) + weight_np = weight.cpu().numpy() + + # Sort and get indices in decending order + indices_np = np.argsort(-weight_np, axis=-1) + + pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + + # Run the packing algorithm for i in range(num_layers): - pack_weights = [0] * num_packs + pack_weights = [0.0] * num_packs pack_items = [0] * num_packs - for group in indices[i]: + + for group in indices_np[i]: + # Find a pack with capacity that has the lowest weight pack = min( - (i for i in range(num_packs) if pack_items[i] < groups_per_pack), + (j for j in range(num_packs) if pack_items[j] < groups_per_pack), key=pack_weights.__getitem__, ) + assert pack_items[pack] < groups_per_pack - pack_index[i, group] = pack - rank_in_pack[i, group] = pack_items[pack] - pack_weights[pack] += weight[i, group] + pack_index_np[i, group] = pack + rank_in_pack_np[i, group] = pack_items[pack] + pack_weights[pack] += weight_np[i, group] pack_items[pack] += 1 + + pack_index = torch.from_numpy(pack_index_np).to(device) + rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device) + return pack_index, rank_in_pack @@ -212,7 +228,7 @@ def rebalance_experts( replicas for each logical expert """ num_layers, num_logical_experts = weight.shape - weight = weight.float().cpu() + weight = weight.float() if num_groups % num_nodes == 0: # use hierarchical load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index f8ec3e956401..5c1efbaf03ba 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -321,15 +321,19 @@ def 
rearrange_expert_weights_inplace( ) return + old_global_expert_indices_cpu = old_global_expert_indices.cpu() + new_global_expert_indices_cpu = new_global_expert_indices.cpu() + + # NOTE(bowen): We need this synchronize to run, but I don't know why. + # If you figure out the reason, please let me know -- thank you! + torch.cuda.synchronize() + for layer in range(num_moe_layers): - # NOTE(bowen): We need this synchronize to run, but I don't know why. - # If you figure out the reason, please let me know -- thank you! - torch.cuda.synchronize() shuffle_layer( num_local_physical_experts, ep_rank, - old_global_expert_indices[layer].tolist(), - new_global_expert_indices[layer].tolist(), + old_global_expert_indices_cpu[layer].tolist(), + new_global_expert_indices_cpu[layer].tolist(), expert_weights[layer], expert_weights_buffer, ep_group, From f0359fffa434a4fce981389f9dff93a2a4c2b13e Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Tue, 11 Nov 2025 16:24:28 +0800 Subject: [PATCH 47/49] [Bugfix] fix qwen3-next crash (#28202) Signed-off-by: zjy0516 --- vllm/model_executor/models/qwen3_next.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index aa7de5aa5f29..ddb8693c16e2 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -587,7 +587,7 @@ def _forward_core( self.conv1d.bias, self.activation, conv_state_indices=non_spec_state_indices_tensor[ - : attn_metadata.num_decodes + : attn_metadata.num_actual_tokens ], validate_data=True, ) From c7991269dd8fe86096a3eee5040e855801ae9665 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 11 Nov 2025 16:45:38 +0800 Subject: [PATCH 48/49] [BugFix] 'DeepseekV2Config' object has no attribute 'use_mla'` (#28387) Signed-off-by: Lin, Fanli --- vllm/model_executor/models/kimi_vl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index b54f53931d71..b79bdf8595ca 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -456,7 +456,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - if not config.use_mla: + use_mha = ( + config.model_type == "deepseek" + or config.qk_nope_head_dim + config.qk_rope_head_dim == 0 + ) + if use_mha: stacked_params_mapping += [ (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), From 319abd5ee9c50b25a7929ba1e3e6588d44fc9d6d Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 12 Nov 2025 18:25:16 +0000 Subject: [PATCH 49/49] Remove dynamic shape Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 41 ++++++--- vllm/compilation/backends.py | 107 +++++++---------------- vllm/compilation/collective_fusion.py | 10 +-- vllm/compilation/compiler_interface.py | 36 ++++---- vllm/compilation/inductor_pass.py | 8 +- vllm/compilation/pass_manager.py | 7 +- vllm/compilation/piecewise_backend.py | 2 +- vllm/compilation/sequence_parallelism.py | 8 +- vllm/config/compilation.py | 10 +-- vllm/config/utils.py | 6 +- vllm/config/vllm.py | 3 +- 11 files changed, 105 insertions(+), 133 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index bacaa48ae477..b15f90395c6a 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -42,9 +42,9 @@ def forward(self, x: 
torch.Tensor) -> torch.Tensor: @torch.inference_mode def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): with set_forward_context({}, vllm_config=vllm_config): - model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + model(torch.randn(BATCH_SIZE, MLP_SIZE)) for batch_size in batch_sizes: - model(torch.randn(batch_size, MLP_SIZE).cuda()) + model(torch.randn(batch_size, MLP_SIZE)) class PostGradPassManagerCheckRanges(InductorPass): @@ -70,11 +70,14 @@ def uuid(self) -> str: def test_compile_ranges(): post_grad_pass_manager = PostGradPassManagerCheckRanges( [ - Range(start=1, end=8), - Range(start=8, end=32), - Range(start=32, end=8193), + Range(start=1, end=9), + Range(start=16, end=16), + Range(start=9, end=33), + Range(start=64, end=64), + Range(start=33, end=8193), ] ) + torch.set_default_device("cuda") vllm_config = VllmConfig( scheduler_config=SchedulerConfig( max_num_batched_tokens=8192, @@ -82,6 +85,7 @@ def test_compile_ranges(): compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8, 32], + compile_sizes=[16, 64, 128], inductor_compile_config={ "post_grad_custom_post_pass": post_grad_pass_manager, # Disable inductor cache to get the number of passes correctly @@ -91,14 +95,31 @@ def test_compile_ranges(): ) with set_current_vllm_config(vllm_config): - model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() - batch_sizes = [1, 4, 16, 24, 48, 64] + model = TestModel(vllm_config=vllm_config, prefix="").eval() + # Number of compilations: 3 for each compile range + 2 compile sizes + batch_sizes = [1, 4, 16, 24, 48, 64, 8192] # A has support_torch_compile with compilation_counter.expect( num_graphs_seen=1, num_piecewise_graphs_seen=1, - num_backend_compilations=3, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_backend_compilations=5, ): run_model(vllm_config, model, batch_sizes) - assert post_grad_pass_manager.num_calls == 3 + assert post_grad_pass_manager.num_calls == 5 + + +def test_compile_config_get_compile_ranges(): + compilation_config = CompilationConfig( + compile_ranges_split_points=[8, 32], + ) + VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), + compilation_config=compilation_config, + ) + assert compilation_config.get_compile_ranges() == [ + Range(start=1, end=9), + Range(start=9, end=33), + Range(start=33, end=8193), + ] diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index efd68a71c7e4..b1fe58d08265 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -84,7 +84,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[Range | None, int, str], Any] = dict() + self.cache: dict[tuple[Range, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -93,7 +93,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, compile_range: Range | None = None): + def compile_context(self, compile_range: Range): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. 
partition rules, pass context).""" @@ -153,7 +153,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: Range | None = None, + compile_range: Range, ) -> Callable | None: if (compile_range, graph_index, self.compiler.name) not in self.cache: return None @@ -161,23 +161,13 @@ def load( compiled_graph = self.compiler.load( handle, graph, example_inputs, graph_index, compile_range ) - if compile_range is None: - logger.debug( - "Directly load the %s-th graph for dynamic compile range" - "from %s via handle %s", - graph_index, - self.compiler.name, - handle, - ) - else: - logger.debug( - "Directly load the %s-th graph for compile range %s" - "from %s via handle %s", - graph_index, - str(compile_range), - self.compiler.name, - handle, - ) + logger.debug( + "Directly load the %s-th graph for compile range %sfrom %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) return compiled_graph def compile( @@ -186,9 +176,9 @@ def compile( example_inputs, additional_inductor_config, compilation_config: CompilationConfig, + compile_range: Range, graph_index: int = 0, num_graphs: int = 1, - compile_range: Range | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -208,19 +198,12 @@ def compile( now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if compile_range is None: - logger.info( - "Directly load the compiled graph(s) for dynamic shape " - "from the cache, took %.3f s", - elapsed, - ) - else: - logger.info( - "Directly load the compiled graph(s) for compile range %s " - "from the cache, took %.3f s", - str(compile_range), - elapsed, - ) + logger.info( + "Directly load the compiled graph(s) for compile range %s " + "from the cache, took %.3f s", + str(compile_range), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -230,10 +213,7 @@ def compile( maybe_key = None else: maybe_key = "artifact_compile_range_" - if compile_range is None: - maybe_key += "dynamic_shape" - else: - maybe_key += f"{compile_range.start}_{compile_range.end}" + maybe_key += f"{compile_range.start}_{compile_range.end}" maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( @@ -253,50 +233,29 @@ def compile( self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if compile_range is None: - logger.info_once( - "Cache the graph for dynamic shape for later use", scope="local" - ) - else: - logger.info_once( - "Cache the graph of compile range %s for later use", - str(compile_range), - ) - if compile_range is None: - logger.debug( - "Store the %s-th graph for dynamic compile range" - "from %s via handle %s", - graph_index, - self.compiler.name, - handle, - ) - else: - logger.debug( - "Store the %s-th graph for compile range%s from %s via handle %s", - graph_index, + logger.info_once( + "Cache the graph of compile range %s for later use", str(compile_range), - self.compiler.name, - handle, ) + logger.debug( + "Store the %s-th graph for compile range%s from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if compile_range is None: - logger.info_once( - 
"Compiling a graph for dynamic compile range takes %.2f s", - elapsed, - scope="local", - ) - else: - logger.info_once( - "Compiling a graph for compile range %s takes %.2f s", - str(compile_range), - elapsed, - scope="local", - ) + logger.info_once( + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), + elapsed, + scope="local", + ) return compiled_graph diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 32d1f1531f4c..bef8925661cd 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -432,7 +432,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: Range | None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -442,9 +442,7 @@ def is_applicable_for_range(self, compile_range: Range | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return compile_range is not None and ( - compile_range.is_single_size() and compile_range.end % tp_size == 0 - ) + return compile_range.is_single_size() and compile_range.end % tp_size == 0 @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -1189,9 +1187,7 @@ def register_patterns(self): self.disabled = False - def is_applicable_for_range(self, compile_range: Range | None) -> bool: - if compile_range is None: - return False + def is_applicable_for_range(self, compile_range: Range) -> bool: return compile_range.end - 1 <= self.max_token_num @VllmInductorPass.time_and_log diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index b95067aba191..3bafba2e1642 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -64,16 +64,15 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a range. If the `compile_range` is None, it means - the `example_inputs` have a dynamic shape. Otherwise, the - `compile_range` specifies the range of the inputs, - it could be concrete size, e.g. (4, 4). - Right now we only support one variable range of shapes for all inputs, + with a range. The `compile_range` specifies the range of the inputs, + it could be concrete size (if compile_sizes is provided), e.g. [4, 4) + or a range [4, 5). + Right now we only support one variable in ranges for all inputs, which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -100,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: Range | None = None, + compile_range: Range, ) -> Callable: """ Load the compiled function from the handle. 
@@ -214,7 +213,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -224,13 +223,10 @@ def compile( set_inductor_config(current_config, compile_range) set_functorch_config() - if compile_range is not None: - if compile_range.is_single_size(): - dynamic_shapes = "from_example_inputs" - else: - dynamic_shapes = "from_graph" + if compile_range.is_single_size(): + dynamic_shapes = "from_example_inputs" else: - dynamic_shapes = "from_tracing_context" + dynamic_shapes = "from_graph" from torch._inductor import standalone_compile @@ -255,7 +251,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: Range | None = None, + compile_range: Range, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -319,7 +315,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -516,7 +512,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: Range | None = None, + compile_range: Range, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -612,8 +608,8 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() -def set_inductor_config(config, compile_range): - if compile_range is not None and compile_range.is_single_size(): +def set_inductor_config(config, compile_range: Range): + if compile_range.is_single_size(): # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE @@ -634,7 +630,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 008eba4629a3..8159b817f637 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -29,8 +29,8 @@ class PassContext: - def __init__(self, compile_range: Range | None): - self.compile_range: Range | None = compile_range + def __init__(self, compile_range: Range): + self.compile_range: Range = compile_range def get_pass_context() -> PassContext: @@ -40,7 +40,7 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(compile_range: Range | None): +def pass_context(compile_range: Range): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. 
""" @@ -97,7 +97,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: Range | None): + def is_applicable_for_range(self, compile_range: Range): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 820fa9b007e3..399c998d87f8 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -128,9 +128,8 @@ def uuid(self): state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) compile_range = get_pass_context().compile_range - if compile_range is not None: - # Include the compile range in the uuid to ensure that inductor - # recompiles the graph for the new dynamic compile range. - state["compile_range"] = str(compile_range) + # Include the compile range in the uuid to ensure that inductor + # recompiles the graph for the new dynamic compile range. + state["compile_range"] = str(compile_range) return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 8f34aa818a80..b59cc50f70bc 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -133,9 +133,9 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: args, self.compilation_config.inductor_compile_config, self.compilation_config, + compile_range=range_entry.compile_range, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - compile_range=range_entry.compile_range, ) # finished compilations for all required shapes diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 6a5ee5a0efb7..84484756e7ef 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -483,7 +483,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: Range | None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -503,11 +503,7 @@ def is_applicable_for_range(self, compile_range: Range | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return ( - compile_range is not None - and (compile_range.is_single_size()) - and (compile_range.end % tp_size == 0) - ) + return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 36bbd2b9abff..85118544117d 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -349,11 +349,11 @@ class CompilationConfig: compile_ranges_split_points: list[int] | None = None """Split points that represent compile ranges for inductor. The compile ranges are - [1, split_points[0]), - [split_points[0], split_points[1]), ..., - [split_points[-1], max_num_batched_tokens + 1). - Compile sizes are also used single element ranges: - [compile_sizes[i], compile_sizes[i] + 1). 
+    [1, split_points[0] + 1),
+    [split_points[0] + 1, split_points[1] + 1), ...,
+    [split_points[-1] + 1, max_num_batched_tokens + 1).
+    Compile sizes are also used as single-element ranges;
+    each is represented as the range [compile_sizes[i], compile_sizes[i] + 1).
     """
 
     inductor_compile_config: dict = field(default_factory=dict)
diff --git a/vllm/config/utils.py b/vllm/config/utils.py
index ea97ddf125f7..20304696ffcc 100644
--- a/vllm/config/utils.py
+++ b/vllm/config/utils.py
@@ -206,7 +206,11 @@ def __hash__(self) -> int:
         return hash((self.start, self.end))
 
     def __str__(self) -> str:
-        return f"(start={self.start}, end={self.end})"
+        return (
+            f"[{self.start}, {self.end + 1})"
+            if self.is_single_size()
+            else f"[{self.start}, {self.end})"
+        )
 
     def __repr__(self) -> str:
         return self.__str__()
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 4557e59a5cf8..2d71bec7c517 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -965,12 +965,13 @@ def _set_compile_ranges(self):
         for x in compilation_config.compile_ranges_split_points:
             assert isinstance(x, int)
             assert x > 0, f"Invalid compile range split point: {x}"
+            # Split points need to be inclusive of the end, so we add 1.
             if (
                 max_num_batched_tokens is not None
                 and x < max_num_batched_tokens
                 and x > 1
            ):
-                computed_compile_ranges_split_points.append(x)
+                computed_compile_ranges_split_points.append(x + 1)
         compilation_config.compile_ranges_split_points = sorted(
             computed_compile_ranges_split_points
         )  # type: ignore
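
For reference, a minimal sketch (not part of the patch series) of how the end-inclusive split points and the updated `Range.__str__` behave after this last patch. The values mirror the expectations in `tests/compile/test_compile_ranges.py`; the `vllm.config.utils` import path and the `start == end` convention for single-size ranges are assumptions taken from the diffs above.

```python
from vllm.config.utils import Range

# Split points [8, 32] with max_num_batched_tokens=8192 yield three ranges;
# each split point is shifted by +1 so the point itself falls in the range below it.
ranges = [Range(start=1, end=9), Range(start=9, end=33), Range(start=33, end=8193)]
print([str(r) for r in ranges])   # ['[1, 9)', '[9, 33)', '[33, 8193)']

# A compile size such as 16 is tracked as a single-size range (start == end in the
# test expectations) and printed with an exclusive end bound.
print(Range(start=16, end=16))    # [16, 17)
```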