diff --git a/benchmarks/compile/benchmark_inductor_compiled_artifacts.py b/benchmarks/compile/benchmark_inductor_compiled_artifacts.py new file mode 100644 index 000000000000..8130a7cd8333 --- /dev/null +++ b/benchmarks/compile/benchmark_inductor_compiled_artifacts.py @@ -0,0 +1,659 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark script to measure inductor cache performance across multiple models. + +This script benchmarks the vLLM inductor cache feature by running models twice: +1. Cold Start: First run compiles from scratch and saves artifacts to cache +2. Warm Start: Second run loads precompiled artifacts from cache + +The script measures load time and inference performance for both runs and computes +the speedup achieved by using cached artifacts. + +Usage: + # Run all models with default settings + python benchmark_inductor_compiled_artifacts.py + + # Run specific models by index + python benchmark_inductor_compiled_artifacts.py --models 0 3 4 + + # Run specific models by name (partial match) + python benchmark_inductor_compiled_artifacts.py --models Qwen Llama + + # List available models + python benchmark_inductor_compiled_artifacts.py --list-models + + # Customize cache directory and output + python benchmark_inductor_compiled_artifacts.py \ + --cache-dir /path/to/cache --output results.json + + # Adjust generation parameters + python benchmark_inductor_compiled_artifacts.py --max-tokens 256 --batch-size 4 + +Requirements: + - PyTorch 2.10.0+ (for standalone_compile with serialization support) + - vLLM with VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS feature enabled + - Sufficient GPU memory (8x H100 recommended for largest models) + +Environment Variables: + VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS=1 (automatically set by script) + VLLM_USE_AOT_COMPILE=1 (automatically set by script) + VLLM_USE_STANDALONE_COMPILE=1 (automatically set by script) + VLLM_CACHE_ROOT (set via --cache-dir argument) + +NOTE: FLASHINFER_WORKSPACE_BUFFER_SIZE needs to be set to 512MB for mistral... + https://github.com/vllm-project/vllm/issues/25342 should fix. +""" + +import argparse +import json +import os +import shutil +import time +from dataclasses import dataclass + + +@dataclass +class ModelConfig: + """Configuration for a model to benchmark.""" + + name: str + tensor_parallel_size: int + + def __str__(self): + return f"{self.name} (TP={self.tensor_parallel_size})" + + +# Model configurations with appropriate TP sizes +DEFAULT_MODELS = [ + ModelConfig("Qwen/Qwen3-32B", 1), + ModelConfig("deepseek-ai/DeepSeek-R1-0528", 8), # Very large model, needs 8 GPUs + ModelConfig("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", 1), + ModelConfig("meta-llama/Llama-3.3-70B-Instruct", 2), + ModelConfig("nvidia/Llama-3.3-70B-Instruct-FP8", 2), + ModelConfig("mistralai/Mistral-Large-Instruct-2411", 2), +] + + +def run_vllm_inference( + model: str, + tensor_parallel_size: int, + prompt: str, + max_tokens: int, + batch_size: int, + enable_compile: bool, + cache_dir: str, + use_cached: bool = False, + use_baseline: bool = False, +) -> dict[str, float]: + """Run vLLM inference and measure timing. 
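+    All compile-related environment variables are set before vLLM is imported,
+    and they are kept identical between the cold-start and warm-start runs so
+    that the compilation config hash (and therefore the cache key) stays the
+    same across runs.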
+ + Args: + model: Model name or path + tensor_parallel_size: Number of GPUs for tensor parallelism + prompt: Input prompt + max_tokens: Maximum tokens to generate + batch_size: Batch size + enable_compile: Whether to enable torch compile + cache_dir: Cache directory for compiled artifacts + use_cached: Whether to use cached artifacts + (only applies if enable_compile=True) + use_baseline: If True, use standalone compile WITHOUT inductor cache + (baseline comparison) + + Returns: + Dictionary with timing information + """ + # Set environment variables FIRST, before any imports + # IMPORTANT: Keep all environment variables IDENTICAL between cold and + # warm start runs. Changes to env vars can affect the config hash and + # prevent warm starts + + # Clear any existing VLLM_FORCE_AOT_LOAD that might prevent compilation + os.environ.pop("VLLM_FORCE_AOT_LOAD", None) + + if enable_compile: + os.environ["VLLM_CACHE_ROOT"] = cache_dir + + if use_baseline: + # Baseline: Use standalone compile WITHOUT inductor cache + # This is the current/old approach for comparison + os.environ.pop("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", None) + os.environ["VLLM_USE_AOT_COMPILE"] = "1" + os.environ["VLLM_USE_STANDALONE_COMPILE"] = "1" + print( + "Environment: Using baseline " + "(standalone compile WITHOUT inductor cache)" + ) + else: + # New approach: Use inductor cache + # Always set all compilation flags the same way for both cold + # and warm starts. Cold/warm start is determined by whether + # artifacts exist, not by flags. Keeping flags identical ensures + # the config hash remains the same + os.environ["VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS"] = "1" + os.environ["VLLM_USE_AOT_COMPILE"] = "1" + os.environ["VLLM_USE_STANDALONE_COMPILE"] = "1" + print("Environment: Using inductor cache backend") + else: + # Eager mode: clear all compile-related flags + os.environ.pop("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", None) + os.environ.pop("VLLM_USE_AOT_COMPILE", None) + os.environ.pop("VLLM_USE_STANDALONE_COMPILE", None) + + # Now import after environment variables are set + import torch + + from vllm import LLM, SamplingParams + + # Create sampling params + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=max_tokens, + ) + + # Prepare prompts + prompts = [prompt] * batch_size + + # Measure model loading time + load_start = time.perf_counter() + + if enable_compile: + # Enable VLLM_COMPILE mode (level 3) which enables piecewise compilation + # with splitting at attention ops for multiple submodules + llm = LLM( + model=model, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=False, + compilation_config={ + "level": 3, # CompilationMode.VLLM_COMPILE enables splitting + "backend": "inductor", + }, + # Increase from default 0.9 to use more GPU memory + gpu_memory_utilization=0.95, + # Limit context length to reduce memory usage + max_model_len=2048, + ) + else: + llm = LLM( + model=model, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=True, + gpu_memory_utilization=0.95, + max_model_len=2048, + ) + load_time = time.perf_counter() - load_start + + # Warmup run (not timed) + print("Running warmup...") + _ = llm.generate(prompts[:1], sampling_params) + + # Clear GPU memory before timed run + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # Timed inference run + print("Running timed inference...") + inference_start = time.perf_counter() + outputs = llm.generate(prompts, sampling_params) + torch.cuda.synchronize() + inference_time = 
time.perf_counter() - inference_start + + # Calculate tokens generated + total_tokens = sum(len(output.outputs[0].token_ids) for output in outputs) + tokens_per_second = total_tokens / inference_time + + return { + "load_time": load_time, + "inference_time": inference_time, + "total_tokens": total_tokens, + "tokens_per_second": tokens_per_second, + } + + +def format_time(seconds: float) -> str: + """Format time in seconds to a human-readable string.""" + if seconds < 1: + return f"{seconds * 1000:.2f}ms" + elif seconds < 60: + return f"{seconds:.2f}s" + else: + minutes = int(seconds // 60) + secs = seconds % 60 + return f"{minutes}m {secs:.2f}s" + + +def print_model_results(model_name: str, results: dict[str, float], run_type: str = ""): + """Print results for a single model.""" + prefix = f" [{run_type}] " if run_type else " " + print(f"\n{prefix}Load Time: {format_time(results['load_time'])}") + print(f"{prefix}Inference Time: {format_time(results['inference_time'])}") + print(f"{prefix}Total Tokens: {results['total_tokens']:.0f}") + print(f"{prefix}Tokens/Second: {results['tokens_per_second']:.2f}") + total_time = format_time(results["load_time"] + results["inference_time"]) + print(f"{prefix}Total Time: {total_time}") + + +def print_summary(all_results: dict[str, dict[str, any]]): + """Print summary table of all results.""" + print("\n" + "=" * 180) + print("INDUCTOR CACHE BENCHMARK - Comparing Load Time Performance") + print("=" * 180) + print("\nComparison: Old Approach (Baseline AOT) vs New Approach (Inductor Cache)") + print("=" * 180) + + # Table header - make it crystal clear what we're comparing + header = ( + f"\n{'Model':<45} {'OLD: Baseline':<15} {'NEW: Inductor':<15} " + f"{'BENEFIT':<12} {'Time Saved':<12} {'Miss Penalty':<15}" + ) + print(header) + subheader = ( + f"{'(Approach being compared)':<45} {'AOT Cache':<15} " + f"{'Cache Hit':<15} {'(Speedup)':<12} {'(Seconds)':<12} " + f"{'vs Baseline':<15}" + ) + print(subheader) + print("-" * 180) + + # Sort by model name for consistent output + for model_name in sorted(all_results.keys()): + results = all_results[model_name] + + # Check if we have all results + if "cold_start" in results and "warm_start" in results: + miss = results["cold_start"] + hit = results["warm_start"] + + hit_load = format_time(hit["load_time"]) + + # Check if baseline_hit exists (for apples-to-apples comparison) + if "baseline_hit" in results: + base_hit = results["baseline_hit"] + base_load = format_time(base_hit["load_time"]) + + # PRIMARY METRIC: Inductor cache hit vs baseline hit + speedup = base_hit["load_time"] / hit["load_time"] + speedup_str = f"{speedup:.2f}x" + + # Time saved in seconds + time_saved = base_hit["load_time"] - hit["load_time"] + if time_saved >= 0: + time_saved_str = f"-{format_time(time_saved)}" + else: + time_saved_str = f"+{format_time(abs(time_saved))}" + + # Miss penalty: compare cache miss overhead vs baseline miss + if "baseline_miss" in results: + base_miss = results["baseline_miss"] + miss_penalty_pct = ( + (miss["load_time"] - base_miss["load_time"]) + / base_miss["load_time"] + ) * 100 + miss_penalty_str = f"{miss_penalty_pct:+.1f}%" + else: + miss_penalty_str = "N/A" + else: + base_load = "N/A" + speedup_str = "N/A" + time_saved_str = "N/A" + miss_penalty_str = "N/A" + + # Truncate model name if too long + model_display = ( + model_name if len(model_name) <= 45 else model_name[:42] + "..." 
+ ) + + row = ( + f"{model_display:<45} {base_load:<15} {hit_load:<15} " + f"{speedup_str:<12} {time_saved_str:<12} {miss_penalty_str:<15}" + ) + print(row) + else: + # Only have partial results + model_display = ( + model_name if len(model_name) <= 45 else model_name[:42] + "..." + ) + n_a_row = ( + f"{model_display:<45} {'N/A':<15} {'N/A':<15} " + f"{'N/A':<12} {'N/A':<12} {'N/A':<15}" + ) + print(n_a_row) + + print("=" * 180) + print("\nKey Findings:") + print( + " 1. BENEFIT (Speedup): NEW approach (inductor cache) vs " + "OLD approach (baseline AOT)" + ) + print(" - Shows how much faster the new inductor cache is at loading") + print(" - Example: 1.34x means inductor cache is 34% faster") + print("\n 2. Time Saved: Absolute time reduction from using inductor cache") + print(" - Direct time savings per model load") + print( + "\n 3. Miss Penalty: Overhead of first compile with inductor " + "cache serialization" + ) + print( + " - Negative % means even first compile is faster than " + "baseline (excellent!)" + ) + print( + " - Positive % means first compile is slower (overhead from serialization)" + ) + print("=" * 180) + + +def clean_all_caches(cache_dir: str): + """Clean all compilation caches before running benchmarks. + + Removes: + 1. torch_compile_cache: {cache_dir}/torch_compile_cache/ (Dynamo cache) + 2. torch_aot_compile cache: {cache_dir}/torch_aot_compile/ (AOT artifacts) + 3. vllm compile cache: /tmp/vllm_compile_cache_* (temp cache) + """ + import glob + + print("\n" + "=" * 80) + print("CLEANING ALL CACHES") + print("=" * 80) + + # Clean torch_compile_cache (Dynamo cache - most important!) + torch_compile_path = os.path.join(cache_dir, "torch_compile_cache") + if os.path.exists(torch_compile_path): + print(f"Removing: {torch_compile_path}") + shutil.rmtree(torch_compile_path) + else: + print(f"Not found: {torch_compile_path}") + + # Clean torch_aot_compile cache + torch_aot_path = os.path.join(cache_dir, "torch_aot_compile") + if os.path.exists(torch_aot_path): + print(f"Removing: {torch_aot_path}") + shutil.rmtree(torch_aot_path) + else: + print(f"Not found: {torch_aot_path}") + + # Clean vllm_compile_cache_* in /tmp + vllm_cache_pattern = "/tmp/vllm_compile_cache_*" + matching_dirs = glob.glob(vllm_cache_pattern) + if matching_dirs: + for cache_path in matching_dirs: + print(f"Removing: {cache_path}") + shutil.rmtree(cache_path) + else: + print(f"No matching directories found for pattern: {vllm_cache_pattern}") + + print("=" * 80) + print("CACHE CLEANING COMPLETE") + print("=" * 80 + "\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark inductor precompile across multiple models" + ) + parser.add_argument( + "--models", + type=str, + nargs="+", + help=( + "Specific models to benchmark (by name or index). " + "If not specified, runs all models." 
+ ), + ) + parser.add_argument( + "--list-models", + action="store_true", + help="List available models and exit", + ) + parser.add_argument( + "--prompt", + type=str, + default="Tell me a story about a robot learning to paint.", + help="Input prompt for generation", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=128, + help="Maximum tokens to generate (default: 128)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Batch size for inference (default: 1)", + ) + parser.add_argument( + "--cache-dir", + type=str, + default=f"/data/users/{os.getenv('USER', 'unknown')}/vllm_cache", + help="Cache directory for compiled artifacts", + ) + parser.add_argument( + "--output", + type=str, + help="Output JSON file for results", + ) + parser.add_argument( + "--include-baseline", + action="store_true", + help=( + "Include baseline run with standalone compile " + "(without inductor cache) for comparison" + ), + ) + + args = parser.parse_args() + + # List models if requested + if args.list_models: + print("Available models:") + for i, model_config in enumerate(DEFAULT_MODELS): + print(f" {i}: {model_config}") + return + + # Determine which models to run + models_to_run = [] + if args.models: + for model_spec in args.models: + # Check if it's an index + try: + idx = int(model_spec) + if 0 <= idx < len(DEFAULT_MODELS): + models_to_run.append(DEFAULT_MODELS[idx]) + else: + max_idx = len(DEFAULT_MODELS) - 1 + print(f"ERROR: Model index {idx} out of range (0-{max_idx})") + return + except ValueError: + # It's a model name - find matching config + found = False + for config in DEFAULT_MODELS: + if model_spec in config.name: + models_to_run.append(config) + found = True + break + if not found: + # Create custom config with TP=1 + models_to_run.append(ModelConfig(model_spec, 1)) + else: + models_to_run = DEFAULT_MODELS + + # Setup cache directory + cache_dir = args.cache_dir + os.makedirs(cache_dir, exist_ok=True) + print(f"Using cache directory: {cache_dir}") + + all_results = {} + + # Run benchmark for each model + for model_config in models_to_run: + print("\n" + "=" * 100) + print(f"Benchmarking: {model_config}") + print("=" * 100) + + model_results = {} + + try: + # Clean ALL caches before cold start run + clean_all_caches(cache_dir) + + # Run 1: Cold Start (compile from scratch) + print("\n--- RUN 1: COLD START (compiling from scratch) ---") + + cold_start_results = run_vllm_inference( + model=model_config.name, + tensor_parallel_size=model_config.tensor_parallel_size, + prompt=args.prompt, + max_tokens=args.max_tokens, + batch_size=args.batch_size, + enable_compile=True, + cache_dir=cache_dir, + use_cached=False, + ) + model_results["cold_start"] = cold_start_results + print_model_results(model_config.name, cold_start_results, "Cold Start") + + # Clean up GPU memory between runs + import gc + + import torch + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("\nCleaned up GPU memory before cache hit run") + + # Run 2: Warm Start (load from cache) + print("\n--- RUN 2: WARM START (loading from cache) ---") + warm_start_results = run_vllm_inference( + model=model_config.name, + tensor_parallel_size=model_config.tensor_parallel_size, + prompt=args.prompt, + max_tokens=args.max_tokens, + batch_size=args.batch_size, + enable_compile=True, + cache_dir=cache_dir, + use_cached=True, + ) + model_results["warm_start"] = warm_start_results + print_model_results(model_config.name, warm_start_results, "Warm Start") + + # Runs 3 & 4 (optional): 
Baseline with standalone AOT compile + # (no inductor cache) + if args.include_baseline: + # Clean ALL caches to ensure baseline starts fresh + clean_all_caches(cache_dir) + + # Clean up GPU memory + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("Cleaned up GPU memory before baseline runs") + + # Baseline Run 1: Cold Start + print("\n--- RUN 3: BASELINE COLD START (first AOT compile) ---") + note = ( + "Note: This uses AOT compile but WITHOUT the inductor " + "cache backend feature" + ) + print(note) + baseline_miss_results = run_vllm_inference( + model=model_config.name, + tensor_parallel_size=model_config.tensor_parallel_size, + prompt=args.prompt, + max_tokens=args.max_tokens, + batch_size=args.batch_size, + enable_compile=True, + cache_dir=cache_dir, + use_cached=False, + use_baseline=True, + ) + model_results["baseline_miss"] = baseline_miss_results + print_model_results( + model_config.name, baseline_miss_results, "Baseline Miss" + ) + + # Clean up GPU memory between baseline runs + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("\nCleaned up GPU memory before baseline warm start run") + + # Baseline Run 2: Warm Start + print("\n--- RUN 4: BASELINE WARM START (loading from AOT cache) ---") + baseline_hit_results = run_vllm_inference( + model=model_config.name, + tensor_parallel_size=model_config.tensor_parallel_size, + prompt=args.prompt, + max_tokens=args.max_tokens, + batch_size=args.batch_size, + enable_compile=True, + cache_dir=cache_dir, + use_cached=True, # Use cached AOT artifacts + use_baseline=True, + ) + model_results["baseline_hit"] = baseline_hit_results + print_model_results( + model_config.name, baseline_hit_results, "Baseline Hit" + ) + + # Calculate and display speedup + load_speedup = ( + cold_start_results["load_time"] / warm_start_results["load_time"] + ) + total_speedup = ( + cold_start_results["load_time"] + cold_start_results["inference_time"] + ) / (warm_start_results["load_time"] + warm_start_results["inference_time"]) + print("\n Speedup (Warm Start vs Cold Start):") + print(f" Load Time: {load_speedup:.2f}x faster with warm start") + print(f" Total Time: {total_speedup:.2f}x faster with warm start") + + if args.include_baseline and "baseline_hit" in model_results: + baseline_hit_to_hit_speedup = ( + baseline_hit_results["load_time"] / warm_start_results["load_time"] + ) + print("\n Speedup (Inductor Warm Start vs Baseline Hit):") + speedup_msg = ( + f" Load Time: {baseline_hit_to_hit_speedup:.2f}x - " + "inductor cache vs baseline AOT cache" + ) + print(speedup_msg) + + all_results[model_config.name] = model_results + + except Exception as e: + print(f"ERROR: Failed to benchmark {model_config.name}: {e}") + import traceback + + traceback.print_exc() + continue + finally: + # Always clean up GPU memory after each model + import gc + + import torch + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("Cleaned up GPU memory") + + # Print summary + if all_results: + print_summary(all_results) + + # Save to JSON if requested + if args.output: + with open(args.output, "w") as f: + json.dump(all_results, f, indent=2) + print(f"\nResults saved to {args.output}") + else: + print("\nNo successful benchmarks to report.") + + +if __name__ == "__main__": + main() diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index c65e5a25934d..faa4c3ea8a45 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -1,13 +1,17 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import hashlib +import pickle import tempfile from contextlib import contextmanager +from unittest.mock import Mock, patch import pytest import torch -from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.caching import VllmSerializableFunction +from vllm.compilation.decorators import save_compile_cache, support_torch_compile from vllm.config import ( CompilationConfig, CompilationMode, @@ -39,6 +43,7 @@ def make_vllm_config() -> VllmConfig: return VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, + backend="inductor", ) ) @@ -59,6 +64,8 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch): expected = reference_fn(*args) with use_vllm_config(vllm_config): m.setenv("VLLM_USE_AOT_COMPILE", "0") + m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1") + m.setenv("VLLM_USE_STANDALONE_COMPILE", "1") with ( pytest.raises(RuntimeError, match="Detected recompile"), torch.compiler.set_stance("fail_on_recompile"), @@ -79,6 +86,8 @@ def test_force_aot_load(monkeypatch: pytest.MonkeyPatch): with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m: args = (torch.randn(10, 10),) m.setenv("VLLM_USE_AOT_COMPILE", "1") + m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1") + m.setenv("VLLM_USE_STANDALONE_COMPILE", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_CACHE_ROOT", tmpdirname) vllm_config = make_vllm_config() @@ -96,9 +105,13 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch): with tempfile.TemporaryDirectory() as tmpdirname: m.setenv("VLLM_CACHE_ROOT", tmpdirname) m.setenv("VLLM_USE_AOT_COMPILE", "1") + m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1") + m.setenv("VLLM_USE_STANDALONE_COMPILE", "1") vllm_config = make_vllm_config() with use_vllm_config(vllm_config): - expected = CompiledMod(vllm_config=vllm_config)(*args) + compiled_mod = CompiledMod(vllm_config=vllm_config) + expected = compiled_mod(*args) + save_compile_cache(compiled_mod) m.setenv("VLLM_FORCE_AOT_LOAD", "1") vllm_config = make_vllm_config() @@ -121,6 +134,8 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): with tempfile.TemporaryDirectory() as tmpdirname: m.setenv("VLLM_CACHE_ROOT", tmpdirname) m.setenv("VLLM_USE_AOT_COMPILE", "1") + m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1") + m.setenv("VLLM_USE_STANDALONE_COMPILE", "1") vllm_config = make_vllm_config() with use_vllm_config(vllm_config): compiled_mod = CompiledMod(vllm_config=vllm_config) @@ -128,6 +143,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): artifacts = compiled_mod.aot_compiled_fn._artifacts guards_string = artifacts.compiled_fn.shape_env.format_guards() assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" + save_compile_cache(compiled_mod) m.setenv("VLLM_FORCE_AOT_LOAD", "1") vllm_config = make_vllm_config() @@ -137,3 +153,227 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): artifacts = compiled_mod.aot_compiled_fn._artifacts guards_string = artifacts.compiled_fn.shape_env.format_guards() assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" + + +class TestInductorCompiledArtifacts: + def test_init(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + assert cache.submodule_bytes == {} + assert cache.submodule_bytes_store == {} + assert cache.loaded_submodule_store == {} + + def test_insert_new_artifact(self): + cache = 
VllmSerializableFunction.InductorCompiledArtifacts() + test_data = b"test_artifact_data" + submod_name = "test_submod" + shape = "s1" + + hasher = hashlib.sha256() + hasher.update(test_data) + expected_hash = hasher.hexdigest() + + cache.insert(submod_name, shape, test_data) + + assert f"{submod_name}_{shape}" in cache.submodule_bytes + assert cache.submodule_bytes[f"{submod_name}_{shape}"] == expected_hash + assert expected_hash in cache.submodule_bytes_store + assert cache.submodule_bytes_store[expected_hash] == test_data + + def test_insert_duplicate_artifact(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + test_data = b"duplicate_test_data" + submod_name1 = "submod1" + submod_name2 = "submod2" + shape = "s2" + + cache.insert(submod_name1, shape, test_data) + cache.insert(submod_name2, shape, test_data) + + hash1 = cache.submodule_bytes[f"{submod_name1}_{shape}"] + hash2 = cache.submodule_bytes[f"{submod_name2}_{shape}"] + assert hash1 == hash2 + + assert len(cache.submodule_bytes_store) == 1 + assert len(cache.submodule_bytes) == 2 + + def test_get_artifact(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + test_data = b"retrievable_data" + submod_name = "mod1" + shape = "shape16" + + cache.insert(submod_name, shape, test_data) + retrieved_data = cache.get(submod_name, shape) + + assert retrieved_data == test_data + + def test_get_nonexistent_artifact(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + with pytest.raises(KeyError): + cache.get("nonexistent", "shape") + + def test_size_bytes(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + assert cache.size_bytes() == 0 + + data1 = b"x" * 100 + data2 = b"y" * 200 + cache.insert("mod1", "shape1", data1) + cache.insert("mod2", "shape2", data2) + + assert cache.size_bytes() == 300 + + def test_num_artifacts_and_entries(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + assert cache.num_artifacts() == 0 + assert cache.num_entries() == 0 + + cache.insert("mod1", "shape1", b"data1") + cache.insert("mod2", "shape2", b"data2") + assert cache.num_artifacts() == 2 + assert cache.num_entries() == 2 + + cache.insert("mod3", "shape3", b"data1") + assert cache.num_artifacts() == 2 + assert cache.num_entries() == 3 + + @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize") + def test_load_all_success(self, mock_deserialize): + """Test successful loading of all artifacts""" + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + mock_artifact1 = Mock() + mock_artifact2 = Mock() + mock_deserialize.side_effect = [mock_artifact1, mock_artifact2] + + cache.insert("mod1", "shape1", pickle.dumps(b"data1")) + cache.insert("mod2", "shape2", pickle.dumps(b"data2")) + + cache.load_all() + + assert len(cache.loaded_submodule_store) == 2 + assert mock_deserialize.call_count == 2 + + @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize") + def test_load_all_already_loaded(self, mock_deserialize): + """Test that load_all skips if already loaded""" + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + mock_artifact = Mock() + cache.submodule_bytes_store["hash1"] = pickle.dumps(b"data1") + cache.loaded_submodule_store["hash1"] = mock_artifact + + cache.load_all() + + mock_deserialize.assert_not_called() + + @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize") + def test_get_loaded_artifact(self, mock_deserialize): + """Test retrieving loaded artifacts""" + cache = 
VllmSerializableFunction.InductorCompiledArtifacts() + + mock_artifact = Mock() + mock_deserialize.return_value = mock_artifact + + submod_name = "test_mod" + shape = "test_shape" + cache.insert(submod_name, shape, pickle.dumps(b"test_data")) + cache.load_all() + + retrieved_artifact = cache.get_loaded(submod_name, shape) + assert retrieved_artifact == mock_artifact + + def test_getstate_setstate(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + cache.insert("mod1", "shape1", b"data1") + cache.insert("mod2", "shape2", b"data2") + + cache.loaded_submodule_store["hash1"] = Mock() + + state = cache.__getstate__() + + assert "submodule_bytes" in state + assert "submodule_bytes_store" in state + assert "loaded_submodule_store" not in state + + new_cache = VllmSerializableFunction.InductorCompiledArtifacts() + new_cache.__setstate__(state) + + assert new_cache.submodule_bytes == cache.submodule_bytes + assert new_cache.submodule_bytes_store == cache.submodule_bytes_store + assert new_cache.loaded_submodule_store == {} + + def test_pickle_roundtrip(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + test_data1 = b"pickle_test_data_1" + test_data2 = b"pickle_test_data_2" + cache.insert("mod1", "shape1", test_data1) + cache.insert("mod2", "shape2", test_data2) + + pickled_data = pickle.dumps(cache) + restored_cache = pickle.loads(pickled_data) + + assert restored_cache.get("mod1", "shape1") == test_data1 + assert restored_cache.get("mod2", "shape2") == test_data2 + assert restored_cache.num_artifacts() == cache.num_artifacts() + assert restored_cache.num_entries() == cache.num_entries() + assert restored_cache.size_bytes() == cache.size_bytes() + + assert len(restored_cache.loaded_submodule_store) == 0 + + +class TestInductorCompiledArtifactsIntegration: + def test_add_pickle_unpickle(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + artifacts = { + ("mod1", "shape1"): b"m1s1_artifact", + ("mod1", "shape2"): b"m1s2_artifact", + ("mod2", "shape1"): b"m2s1_artifact", + ("mod2", "shape2"): b"m2s2_artifact", + } + + for (submod, shape), data in artifacts.items(): + cache.insert(submod, shape, data) + + assert cache.num_entries() == 4 + assert cache.num_artifacts() == 4 + + for (submod, shape), expected_data in artifacts.items(): + retrieved_data = cache.get(submod, shape) + assert retrieved_data == expected_data + + pickled = pickle.dumps(cache) + restored_cache = pickle.loads(pickled) + + for (submod, shape), expected_data in artifacts.items(): + retrieved_data = restored_cache.get(submod, shape) + assert retrieved_data == expected_data + + def test_deduplication(self): + cache = VllmSerializableFunction.InductorCompiledArtifacts() + + shared_data = b"shared_artifact_data" * 1000 + + cache.insert("mod1", "shape1", shared_data) + cache.insert("mod2", "shape1", shared_data) + cache.insert("mod1", "shape2", shared_data) + cache.insert("mod3", "shape3", shared_data) + + assert cache.num_entries() == 4 + assert cache.num_artifacts() == 1 + assert cache.size_bytes() == len(shared_data) + + for submod, shape in [ + ("mod1", "shape1"), + ("mod2", "shape1"), + ("mod1", "shape2"), + ("mod3", "shape3"), + ]: + assert cache.get(submod, shape) == shared_data diff --git a/tests/compile/test_backend_with_cache.py b/tests/compile/test_backend_with_cache.py new file mode 100644 index 000000000000..86c89d6df6b9 --- /dev/null +++ b/tests/compile/test_backend_with_cache.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Tests for the VllmBackendWithCache class. +""" + +import pytest +import torch + +from vllm.compilation.backends import VllmBackendWithCache +from vllm.compilation.caching import VllmSerializableFunction +from vllm.config import CompilationConfig, CompilationMode, VllmConfig +from vllm.config.vllm import get_current_vllm_config + +InductorCompiledArtifacts = VllmSerializableFunction.InductorCompiledArtifacts + + +def make_vllm_config() -> VllmConfig: + """Create a test VllmConfig.""" + return VllmConfig( + compilation_config=CompilationConfig( + level=CompilationMode.VLLM_COMPILE, + ) + ) + + +class TestVllmBackendWithCache: + """Test the VllmBackendWithCache class.""" + + def test_init(self): + """Test initialization of VllmBackendWithCache.""" + inductor_compiled_artifacts = ( + VllmSerializableFunction.InductorCompiledArtifacts() + ) + vllm_config = make_vllm_config() + + backend = VllmBackendWithCache( + inductor_compiled_artifacts=inductor_compiled_artifacts, + vllm_config=vllm_config, + prefix="test", + submod_names=["submod_0", "submod_1"], + ) + + assert backend.inductor_compiled_artifacts == inductor_compiled_artifacts + assert backend.vllm_config == vllm_config + assert backend.prefix == "test" + assert backend.submod_names == ["submod_0", "submod_1"] + assert backend.compiled_callables == {"submod_0": {}, "submod_1": {}} + + def test_build_dispatch_callable_empty_cache(self): + """Test building dispatch callable with an empty cache.""" + inductor_compiled_artifacts = ( + VllmSerializableFunction.InductorCompiledArtifacts() + ) + vllm_config = make_vllm_config() + + backend = VllmBackendWithCache( + inductor_compiled_artifacts=inductor_compiled_artifacts, + vllm_config=vllm_config, + prefix="test", + submod_names=["submod_0"], + ) + + # With empty cache, compiled_callables should be empty + assert backend.compiled_callables == {"submod_0": {}} + + def test_create_piecewise_backend_from_cache_no_general_shape(self): + """Test creating piecewise backend without a general shape function.""" + inductor_compiled_artifacts = ( + VllmSerializableFunction.InductorCompiledArtifacts() + ) + vllm_config = make_vllm_config() + + backend = VllmBackendWithCache( + inductor_compiled_artifacts=inductor_compiled_artifacts, + vllm_config=vllm_config, + prefix="test", + submod_names=["submod_0"], + ) + + # Should raise ValueError when no general shape function is available + with pytest.raises( + ValueError, match="No general shape compiled function found" + ): + backend.create_piecewise_backend_from_cache("submod_0", 0) + + def test_call_empty_submodules(self): + """Test creating split_gm with empty submod_names.""" + backend = VllmBackendWithCache( + inductor_compiled_artifacts=InductorCompiledArtifacts(), + vllm_config=get_current_vllm_config(), + prefix="test", + submod_names=[], + sym_shape_indices_map={}, + returns_tuple_map={}, + ) + + # Create a simple split_gm + import torch.fx as fx + + gm = fx.GraphModule(torch.nn.Module(), fx.Graph()) + + # Should work with empty submodules (no replacements needed) + result_gm = backend.create_split_gm_from_cache(gm) + assert result_gm is gm + + +class TestVllmBackendWithCacheFlag: + """Test the VllmBackendWithCache flag integration.""" + + def test_flag_parsing(self): + """Test that the flag is properly parsed.""" + import os + + # Test default value + if "VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS" in os.environ: + del 
os.environ["VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS"] + import vllm.envs as envs + + assert not envs.use_backend_with_cache() + + # Test enabling the flag + os.environ["VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS"] = "1" + # Need to reload the function to get the new value + assert os.environ["VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS"] == "1" + + # Clean up + del os.environ["VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS"] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +class TestVllmBackendWithCacheIntegration: + """Integration tests for VllmBackendWithCache.""" + + def test_full_workflow_with_mock_cache(self): + """Test the full workflow with a mocked inductor cache.""" + # This is a placeholder for a more comprehensive integration test + # that would actually populate the cache and test execution + pass + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index b96a6701333d..d7b9b665d278 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -21,6 +21,7 @@ "vllm/transformers_utils/config.py", "vllm/model_executor/models/registry.py", "vllm/compilation/caching.py", + "vllm/compilation/piecewise_backend.py", "vllm/distributed/utils.py", "vllm/distributed/parallel_state.py", "vllm/distributed/device_communicators/all_reduce_utils.py", @@ -29,6 +30,7 @@ "vllm/utils/hashing.py", "tests/utils_/test_hashing.py", "tests/tokenization/test_cached_tokenizer.py", + "tests/compile/test_aot_compile.py", "benchmarks/kernels/graph_machete_bench.py", "benchmarks/kernels/benchmark_lora.py", "benchmarks/kernels/benchmark_machete.py", diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 53fd5e74dc0a..649f3edb34df 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -27,7 +27,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.torch_utils import is_torch_equal_or_newer -from .caching import VllmSerializableFunction from .compiler_interface import ( CompilerInterface, EagerAdaptor, @@ -42,24 +41,52 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: + # Check if inductor should be used by looking at the backend field if compilation_config.backend == "inductor": - # Use standalone compile only if requested, version is new enough, - # and the symbol actually exists in this PyTorch build. - if ( + # Use standalone compile if: + # 1. Explicitly requested via VLLM_USE_STANDALONE_COMPILE, OR + # 2. AOT compile is enabled (which requires serialization), OR + # 3. 
Backend with inductor cache is enabled (requires serialization) + # AND the PyTorch version is new enough and has standalone_compile + should_use_standalone = ( envs.VLLM_USE_STANDALONE_COMPILE + or envs.VLLM_USE_AOT_COMPILE + or envs.VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS + ) + if ( + should_use_standalone and is_torch_equal_or_newer("2.8.0.dev") and hasattr(torch._inductor, "standalone_compile") ): + if ( + envs.VLLM_USE_AOT_COMPILE + or envs.VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS + ) and not envs.VLLM_USE_STANDALONE_COMPILE: + msg = ( + "AOT compile or inductor cache backend is enabled, " + "automatically using InductorStandaloneAdaptor for " + "serialization support" + ) + logger.info(msg) logger.debug("Using InductorStandaloneAdaptor") return InductorStandaloneAdaptor() else: + if ( + envs.VLLM_USE_AOT_COMPILE + or envs.VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS + ): + msg = ( + "VLLM_USE_AOT_COMPILE or " + "VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS is " + "set but standalone compile is not available. These " + "features require PyTorch 2.8.0+ with standalone_compile " + "support. Falling back to InductorAdaptor without " + "serialization support." + ) + logger.warning(msg) logger.debug("Using InductorAdaptor") return InductorAdaptor() else: - assert compilation_config.backend == "eager", ( - "Custom backends not supported with CompilationMode.VLLM_COMPILE" - ) - logger.debug("Using EagerAdaptor") return EagerAdaptor() @@ -202,7 +229,6 @@ def compile( # there can be multiple graphs due to piecewise compilation. now = time.time() elapsed = now - compilation_start_time - compilation_config.compilation_time += elapsed if runtime_shape is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " @@ -354,6 +380,51 @@ def split_graph( compilation_start_time = 0.0 +def wrap_with_cudagraph_if_needed( + piecewise_backend: Any, + vllm_config: VllmConfig, + compilation_config: CompilationConfig, + is_first_graph: bool, + is_last_graph: bool, +) -> Any: + """ + Wrap a piecewise backend with CUDA graph wrapper if needed. + This function is shared between VllmBackend and VllmBackendWithCache. + + Args: + piecewise_backend: The backend to wrap + vllm_config: The vLLM configuration + compilation_config: The compilation configuration + is_first_graph: Whether this is the first graph in the sequence + is_last_graph: Whether this is the last graph in the sequence + + Returns: + The wrapped backend if CUDA graphs are enabled, otherwise the original backend + """ + if ( + compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not compilation_config.use_inductor_graph_partition + ): + from .cuda_graph import CUDAGraphOptions + + static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls() + ) + + return static_graph_wrapper_class( + runnable=piecewise_backend, + vllm_config=vllm_config, + runtime_mode=CUDAGraphMode.PIECEWISE, + cudagraph_options=CUDAGraphOptions( + debug_log_enable=is_first_graph, + gc_disable=not is_first_graph, + weak_ref_output=is_last_graph, + ), + ) + else: + return piecewise_backend + + class PiecewiseCompileInterpreter(torch.fx.Interpreter): """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. 
It runs the given graph with fake inputs, and compile some @@ -409,17 +480,19 @@ def call_module( ] global compilation_start_time - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None, + with torch._functorch.config.patch("bundled_autograd_cache", True): + compiled_graph_for_dynamic_shape = ( + self.vllm_backend.compiler_manager.compile( + submod, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + runtime_shape=None, + ) ) - ) + # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -433,36 +506,14 @@ def call_module( self.vllm_backend, ) - if ( - self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() - and not self.compilation_config.use_inductor_graph_partition - ): - # We're using Dynamo-based piecewise splitting, so we wrap - # the whole subgraph with a static graph wrapper. - from .cuda_graph import CUDAGraphOptions - - # resolve the static graph wrapper class (e.g. CUDAGraphWrapper - # class) as platform dependent. - static_graph_wrapper_class = resolve_obj_by_qualname( - current_platform.get_static_graph_wrapper_cls() - ) - - # Always assign PIECEWISE runtime mode to the - # CUDAGraphWrapper for piecewise_backend, to distinguish - # it from the FULL cudagraph runtime mode, no matter it - # is wrapped on a full or piecewise fx graph. - self.module.__dict__[target] = static_graph_wrapper_class( - runnable=piecewise_backend, - vllm_config=self.vllm_config, - runtime_mode=CUDAGraphMode.PIECEWISE, - cudagraph_options=CUDAGraphOptions( - debug_log_enable=piecewise_backend.is_first_graph, - gc_disable=not piecewise_backend.is_first_graph, - weak_ref_output=piecewise_backend.is_last_graph, - ), - ) - else: - self.module.__dict__[target] = piecewise_backend + # Use the shared cudagraph wrapper function + self.module.__dict__[target] = wrap_with_cudagraph_if_needed( + piecewise_backend, + self.vllm_config, + self.compilation_config, + piecewise_backend.is_first_graph, + piecewise_backend.is_last_graph, + ) compilation_counter.num_piecewise_capturable_graphs_seen += 1 @@ -489,9 +540,193 @@ def set_model_tag(tag: str): model_tag = old_tag +class VllmBackendWithCache: + """A backend that reconstructs the compiled model from cached inductor + artifacts. + + This backend takes the inductor cache directly and constructs a split_gm + object from scratch, without relying on VllmBackend's existing logic. + This avoids the overhead of saving the dynamo graph module and + re-splitting the graph module. + + The workflow is: + 1. Take the inductor cache with all compiled graph pieces + 2. Construct a callable that dispatches to the right compiled graphs + 3. Wrap with cudagraph if needed + """ + + def __init__( + self, + inductor_compiled_artifacts: Any, + vllm_config: VllmConfig, + prefix: str = "", + submod_names: list[str] | None = None, + sym_shape_indices_map: dict[str, list[int]] | None = None, + returns_tuple_map: dict[str, bool] | None = None, + ): + """ + Initialize the backend with an inductor cache. 
+ + Args: + inductor_compiled_artifacts: The inductor cache containing + compiled artifacts + vllm_config: The vLLM configuration + prefix: The prefix for this backend (e.g., model_tag) + submod_names: List of submodule names in compilation order + sym_shape_indices_map: Mapping from submod_name to sym_shape_indices + returns_tuple_map: Mapping from submod_name to returns_tuple + """ + self.inductor_compiled_artifacts = inductor_compiled_artifacts + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.prefix = prefix + self.submod_names = submod_names or [] + self.sym_shape_indices_map = sym_shape_indices_map or {} + self.returns_tuple_map = returns_tuple_map or {} + + # Create a VllmBackend instance for PiecewiseBackend to use + # This is needed for compiler_manager access + self.vllm_backend = VllmBackend(vllm_config, prefix) + + # Initialize the compiler manager with cache disabled since + # we're loading from cache. We don't need to compile anything + # new, just need the save_to_file method to work. + # Use a dummy cache directory since we won't actually write. + dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache") + os.makedirs(dummy_cache_dir, exist_ok=True) + self.vllm_backend.compiler_manager.initialize_cache( + cache_dir=dummy_cache_dir, + disable_cache=True, + prefix=prefix, + ) + + # Load all artifacts from cache + self.inductor_compiled_artifacts.load_all() + + # Build the dispatch callable + self.build_dispatch_callable() + + def build_dispatch_callable(self): + """Build a callable that dispatches to the right compiled graphs.""" + # Store compiled callables for each submodule and shape + self.compiled_callables: dict[str, dict[str, Callable]] = {} + + for submod_name in self.submod_names: + self.compiled_callables[submod_name] = {} + + # For each submodule, we need to load the compiled graph + # for each shape. The cache stores entries as + # "{submod_name}_{shape}". We need to extract the general + # shape (None) and any specific shapes + for cache_key in self.inductor_compiled_artifacts.submodule_bytes: + if cache_key.startswith(f"{submod_name}_"): + shape_str = cache_key[len(submod_name) + 1 :] + compiled_fn = self.inductor_compiled_artifacts.get_loaded( + submod_name, shape_str + ) + self.compiled_callables[submod_name][shape_str] = compiled_fn + + def create_piecewise_backend_from_cache( + self, + submod_name: str, + index: int, + ): + """Create a piecewise backend from cached artifacts for a + specific submodule.""" + from .piecewise_backend import PiecewiseBackend + + # Get the compiled callable for the general shape + general_shape_fn = self.compiled_callables[submod_name].get("None") + if general_shape_fn is None: + raise ValueError( + f"No general shape compiled function found for {submod_name}" + ) + + # Create a lightweight piecewise backend that uses the cached artifacts + # We need to create a minimal graph module as a placeholder + # since PiecewiseBackend expects one + dummy_graph = fx.GraphModule({}, fx.Graph()) + + # Determine which shapes are available for this submodule + available_shapes = [ + shape_str + for shape_str in self.compiled_callables[submod_name] + if shape_str != "None" + ] + + def get_compiled_graph_for_size(shape_str: str): + """Get the compiled graph for a specific shape from cache.""" + return self.compiled_callables[submod_name].get(shape_str) + + # Get sym_shape_indices from the map. 
Default to [0] (batch dimension) + # if not found, which is the typical case in vLLM where the first + # argument represents the batch size (a symbolic shape). + sym_shape_indices = self.sym_shape_indices_map.get(submod_name, [0]) + + # Get returns_tuple from the map + returns_tuple = self.returns_tuple_map.get(submod_name) + + piecewise_backend = PiecewiseBackend( + graph=dummy_graph, + vllm_config=self.vllm_config, + piecewise_compile_index=index, + total_piecewise_compiles=len(self.submod_names), + sym_shape_indices=sym_shape_indices, + compiled_graph_for_general_shape=general_shape_fn, + vllm_backend=self.vllm_backend, + get_compiled_graph_for_size=( + get_compiled_graph_for_size if available_shapes else None + ), + returns_tuple=returns_tuple, + ) + + return piecewise_backend + + def create_split_gm_from_cache(self, split_gm: fx.GraphModule) -> fx.GraphModule: + """Replace the submodules in split_gm with piecewise backends + loaded from cache. + + This allows us to reuse the graph structure from split_gm while + loading the compiled artifacts from cache. + + Args: + split_gm: The split graph module from deserialization. This + contains the structure of how submodules are chained, but + the submodules themselves need to be replaced with piecewise + backends loaded from cache. + + Returns: + The modified split_gm with submodules replaced by piecewise + backends from cache. + """ + for i, submod_name in enumerate(self.submod_names): + # Create piecewise backend from cache + piecewise_backend = self.create_piecewise_backend_from_cache(submod_name, i) + + # Wrap with cudagraph if needed + is_first = i == 0 + is_last = i == len(self.submod_names) - 1 + wrapped_backend = wrap_with_cudagraph_if_needed( + piecewise_backend, + self.vllm_config, + self.compilation_config, + is_first, + is_last, + ) + + # Replace the submodule in split_gm + setattr(split_gm, submod_name, wrapped_backend) + logger.debug( + "Replaced submodule %s with piecewise backend from cache", + submod_name, + ) + + return split_gm + + class VllmBackend: """The compilation backend for `torch.compile` with vLLM. - It is used for compilation mode of `CompilationMode.VLLM_COMPILE`, + It is used for compilation level of `CompilationLevel.PIECEWISE`, where we customize the compilation. The major work of this backend is to split the graph into @@ -545,6 +780,120 @@ def __init__( # `torch.compile` is JIT compiled, so we don't need to # do anything here + def _collect_inductor_compiled_artifacts( + self, submod_names: list[str] + ) -> tuple[Any, dict[str, list[int]] | None, dict[str, bool] | None]: + """Collect inductor cache artifacts from all piecewise backends. 
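+        Besides the serialized inductor artifacts, this also records each
+        submodule's ``sym_shape_indices`` and whether its graph returns a
+        tuple, since both are needed to rebuild ``PiecewiseBackend`` objects
+        on the cache-hit path.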
+ + Returns: + tuple: (inductor_compiled_artifacts, sym_shape_indices_map, + returns_tuple_map) + - inductor_compiled_artifacts: InductorCompiledArtifacts + with compiled artifacts + - sym_shape_indices_map: dict mapping submod_name to + sym_shape_indices + - returns_tuple_map: dict mapping submod_name to + returns_tuple + """ + + if not envs.VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS: + return None, None, None + + from torch._inductor.compile_fx import graph_returns_tuple + + from .caching import VllmSerializableFunction + + inductor_compiled_artifacts = ( + VllmSerializableFunction.InductorCompiledArtifacts() + ) + sym_shape_indices_map = {} + returns_tuple_map = {} + + for submod_name in submod_names: + # Get the piecewise backend from the split_gm + if not hasattr(self.split_gm, submod_name): + logger.warning( + "Submodule %s not found in split_gm, skipping cache collection", + submod_name, + ) + continue + + piecewise_backend = getattr(self.split_gm, submod_name) + + # If it's wrapped in a CUDA graph wrapper, unwrap it + if hasattr(piecewise_backend, "runnable"): + piecewise_backend = piecewise_backend.runnable + + # Collect sym_shape_indices from the piecewise backend + if hasattr(piecewise_backend, "sym_shape_indices"): + sym_shape_indices_map[submod_name] = piecewise_backend.sym_shape_indices + logger.debug( + "Collected sym_shape_indices for %s: %s", + submod_name, + piecewise_backend.sym_shape_indices, + ) + + # Collect returns_tuple information + if hasattr(piecewise_backend, "graph"): + returns_tuple = graph_returns_tuple(piecewise_backend.graph) + returns_tuple_map[submod_name] = returns_tuple + logger.debug( + "Collected returns_tuple for %s: %s", + submod_name, + returns_tuple, + ) + + has_serialize = hasattr( + piecewise_backend.compiled_graph_for_general_shape, "serialize" + ) + logger.debug( + "Piecewise backend for %s: has serialize=%s", + submod_name, + has_serialize, + ) + + if has_serialize: + bytes_dict = piecewise_backend.to_bytes() + if bytes_dict: + for shape_str, bytes_data in bytes_dict.items(): + inductor_compiled_artifacts.insert( + submod_name, shape_str, bytes_data + ) + logger.debug( + "Collected inductor cache for %s with shape %s (%d bytes)", + submod_name, + shape_str, + len(bytes_data), + ) + else: + logger.warning( + "Piecewise backend for %s returned empty to_bytes() - " + "bundled_autograd_cache may not have been enabled", + submod_name, + ) + else: + logger.debug( + "Compiled graph for %s does not support serialization " + "(missing 'serialize' method). aot " + "was not enabled during compilation. 
Check that " + "VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS=1", + submod_name, + ) + + logger.info( + "Collected inductor cache: %d entries, %d artifacts, %d bytes total", + inductor_compiled_artifacts.num_entries(), + inductor_compiled_artifacts.num_artifacts(), + inductor_compiled_artifacts.size_bytes(), + ) + + logger.info( + "Inductor cache keys: %s", + list(inductor_compiled_artifacts.submodule_bytes.keys()), + ) + + return inductor_compiled_artifacts, sym_shape_indices_map, returns_tuple_map + def configure_post_pass(self): config = self.compilation_config self.post_grad_pass_manager.configure(self.vllm_config) @@ -566,10 +915,12 @@ def configure_post_pass(self): self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) inductor_config[PASS_KEY] = self.post_grad_pass_manager - def __call__( - self, graph: fx.GraphModule, example_inputs - ) -> VllmSerializableFunction: - from .caching import _compute_code_hash, compilation_config_hash_factors + def __call__(self, graph: fx.GraphModule, example_inputs): + from .caching import ( + VllmSerializableFunction, + _compute_code_hash, + compilation_config_hash_factors, + ) vllm_config = self.vllm_config if not self.compilation_config.cache_dir: @@ -582,7 +933,6 @@ def __call__( # 2. factors come from the code files that are traced by Dynamo ( # it mainly summarizes how the model is used in forward pass) code_hash = _compute_code_hash(self.compilation_config.traced_files) - self.compilation_config.traced_files.clear() factors.append(code_hash) # 3. compiler hash @@ -643,14 +993,16 @@ def __call__( self.graph = graph self.configure_post_pass() - if self.compilation_config.use_inductor_graph_partition: - # Let Inductor decide partitioning; avoid FX-level pre-splitting. - fx_split_ops: list[str] = [] - else: - fx_split_ops = self.compilation_config.splitting_ops or [] + # Resolve splitting ops from strings to OpOverload objects + from vllm.compilation.partition_rules import resolve_defined_ops - resolved_split_ops = resolve_defined_ops(fx_split_ops) - self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops) + resolved_splitting_ops = resolve_defined_ops( + self.compilation_config.splitting_ops or [] + ) + + self.split_gm, self.piecewise_graphs = split_graph( + graph, resolved_splitting_ops + ) from torch._dynamo.utils import lazy_format_graph_code @@ -674,8 +1026,7 @@ def __call__( graph_path = os.path.join(local_cache_dir, "computation_graph.py") if not os.path.exists(graph_path): - # code adapted from - # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 + # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa # use `print_readable` because it can include submodules src = ( "from __future__ import annotations\nimport torch\n" @@ -691,12 +1042,41 @@ def __call__( self._called = True + # Extract shape_env from the graph module's fake mode if available + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode() + shape_env = fake_mode.shape_env if fake_mode else None + + # Extract submod_names from piecewise_graphs for serialization + submod_names = [ + item.submod_name + for item in self.piecewise_graphs + if not item.is_splitting_graph + ] + + # Collect inductor cache and sym_shape_indices from all + # piecewise backends + ( + inductor_compiled_artifacts, + sym_shape_indices_map, + returns_tuple_map, + ) = 
self._collect_inductor_compiled_artifacts(submod_names) + if ( self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or not self.compilation_config.cudagraph_copy_inputs ): return VllmSerializableFunction( - graph, example_inputs, self.prefix, self.split_gm + graph, + example_inputs, + self.prefix, + self.split_gm, + shape_env, + submod_names, + inductor_compiled_artifacts, + sym_shape_indices_map, + returns_tuple_map, ) # if we need to copy input buffers for cudagraph @@ -743,5 +1123,13 @@ def copy_and_call(*args): return self.split_gm(*list_args) return VllmSerializableFunction( - graph, example_inputs, self.prefix, copy_and_call + graph, + example_inputs, + self.prefix, + copy_and_call, + shape_env, + submod_names, + inductor_compiled_artifacts, + sym_shape_indices_map, + returns_tuple_map, ) diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 16e34c2711e9..111b45d2ffa4 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -11,6 +11,7 @@ from torch.utils import _pytree as pytree import vllm.envs as envs +from vllm.compilation.compiler_interface import get_inductor_factors from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger @@ -36,18 +37,135 @@ class VllmSerializableFunction(SerializableCallable): serializing the Dynamo fx graph plus example inputs. """ - def __init__(self, graph_module, example_inputs, prefix, optimized_call): + class InductorCompiledArtifacts: + def __init__(self): + # dict from submodule name to byte hash + self.submodule_bytes = {} + # dict from byte hash to bytes + self.submodule_bytes_store = {} + # dict from byte hash to loaded module + self.loaded_submodule_store = {} + + def insert(self, submod_name: str, shape: str, entry: bytes): + hasher = hashlib.sha256() + hasher.update(entry) + hex_digest = hasher.hexdigest() + self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest + if hex_digest not in self.submodule_bytes_store: + self.submodule_bytes_store[hex_digest] = entry + logger.debug( + "Using stored inductor cache artifact for submod %s " + "with shape %s (%s bytes) at hash %s", + submod_name, + shape, + len(entry), + hex_digest, + ) + else: + logger.debug( + "Inserting inductor artifact for submod %s with shape %s " + "(%s bytes) at hash %s", + submod_name, + shape, + len(entry), + hex_digest, + ) + + def get(self, submod_name: str, shape: str) -> bytes: + logger.debug( + "Getting inductor artifact for submod %s with shape %s", + submod_name, + shape, + ) + return self.submodule_bytes_store[ + self.submodule_bytes[f"{submod_name}_{shape}"] + ] + + def get_loaded(self, submod_name: str, shape: str): + logger.debug( + "Getting inductor artifact for submod %s with shape %s", + submod_name, + shape, + ) + return self.loaded_submodule_store[ + self.submodule_bytes[f"{submod_name}_{shape}"] + ] + + def size_bytes(self) -> int: + return sum(len(entry) for entry in self.submodule_bytes_store.values()) + + def num_artifacts(self) -> int: + return len(self.submodule_bytes_store) + + def num_entries(self) -> int: + return len(self.submodule_bytes) + + def load_all(self) -> None: + import concurrent.futures + + # check if already loaded + if len(self.loaded_submodule_store) == len(self.submodule_bytes_store): + return + + from torch._inductor.standalone_compile import AOTCompiledArtifact + + def _load_entry(entry_bytes) -> AOTCompiledArtifact: + # Unpickle the bundled cache entry first + entry = pickle.loads(entry_bytes) + return AOTCompiledArtifact.deserialize(entry) 
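+            # Deserialize every unique artifact below using a thread pool; the
+            # store is keyed by content hash, so an artifact shared by several
+            # submodules is only deserialized once.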
+ + with concurrent.futures.ThreadPoolExecutor() as executor: + entries = list(self.submodule_bytes_store.values()) + loaded_entries = list(executor.map(_load_entry, entries)) + + for i, k in enumerate(self.submodule_bytes_store.keys()): + self.loaded_submodule_store[k] = loaded_entries[i] + + logger.debug("loaded all %s submodules", self.num_artifacts()) + + def __getstate__(self): + return { + "submodule_bytes": self.submodule_bytes, + "submodule_bytes_store": self.submodule_bytes_store, + } + + def __setstate__(self, state): + self.submodule_bytes = state["submodule_bytes"] + self.submodule_bytes_store = state["submodule_bytes_store"] + self.loaded_submodule_store = {} + + def __init__( + self, + graph_module, + example_inputs, + prefix, + optimized_call, + shape_env=None, + submod_names=None, + inductor_compiled_artifacts=None, + sym_shape_indices_map=None, + returns_tuple_map=None, + ): assert isinstance(graph_module, torch.fx.GraphModule) self.graph_module = graph_module self.example_inputs = example_inputs self.prefix = prefix self.optimized_call = optimized_call - self.shape_env = None - sym_input = next( - (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None - ) - if sym_input is not None: - self.shape_env = sym_input.node.shape_env + self.shape_env = shape_env + # Store submodule names for VllmBackendWithCache + self.submod_names = submod_names or [] + # Store inductor cache for serialization/deserialization + self.inductor_compiled_artifacts = inductor_compiled_artifacts + # Store sym_shape_indices for each submodule + self.sym_shape_indices_map = sym_shape_indices_map or {} + # Store returns_tuple for each submodule + self.returns_tuple_map = returns_tuple_map or {} + if shape_env is None: + sym_input = next( + (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None + ) + if sym_input is not None: + self.shape_env = sym_input.node.shape_env def __call__(self, *args, **kwargs): return self.optimized_call(*args, **kwargs) @@ -98,13 +216,102 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction from torch.fx._graph_pickler import GraphPickler from torch.fx.experimental.symbolic_shapes import ShapeEnv - from vllm.compilation.backends import VllmBackend - state = pickle.loads(data) fake_mode = FakeTensorMode(shape_env=ShapeEnv()) state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) - vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"]) + + # Check if we should use VllmBackendWithCache + if envs.VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS: + from vllm.compilation.backends import VllmBackendWithCache + + # Check if inductor cache has actual artifacts + inductor_compiled_artifacts = state.get("inductor_compiled_artifacts") + submod_names = state.get("submod_names") + + has_artifacts = ( + inductor_compiled_artifacts is not None + and hasattr(inductor_compiled_artifacts, "num_artifacts") + and inductor_compiled_artifacts.num_artifacts() > 0 + ) + + num_artifacts = ( + inductor_compiled_artifacts.num_artifacts() + if inductor_compiled_artifacts + else 0 + ) + num_submods = len(submod_names) if submod_names else 0 + + logger.info( + "VllmBackendWithCache check: has_artifacts=%s, " + "num_artifacts=%d, num_submods=%d", + has_artifacts, + num_artifacts, + num_submods, + ) + + if not has_artifacts or not submod_names: + # Cache doesn't exist yet or is incomplete + # Fall back to standard compilation path 
which will populate it + logger.warning( + "VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS is set " + "but inductor cache is empty (artifacts=%d) or submod_names " + "is missing (len=%d). Falling back to standard compilation " + "path to populate cache.", + num_artifacts, + num_submods, + ) + # Continue to fallback path below instead of raising + else: + # Cache exists, use VllmBackendWithCache + logger.info( + "Loading from VllmBackendWithCache with %d artifacts " + "and %d submodules", + num_artifacts, + num_submods, + ) + + sym_shape_indices_map = state.get("sym_shape_indices_map", {}) + returns_tuple_map = state.get("returns_tuple_map", {}) + + vllm_backend_with_cache = VllmBackendWithCache( + inductor_compiled_artifacts=inductor_compiled_artifacts, + vllm_config=get_current_vllm_config(), + prefix=state["prefix"], + submod_names=submod_names, + sym_shape_indices_map=sym_shape_indices_map, + returns_tuple_map=returns_tuple_map, + ) + + # Get the split_gm from the deserialized state and populate it + # with piecewise backends from cache + split_gm = state["graph_module"] + + # Populate split_gm with piecewise backends from cache + # This replaces the submodules with cached piecewise backends + populated_split_gm = vllm_backend_with_cache.create_split_gm_from_cache( + split_gm + ) + + # Recompile the populated split_gm to generate its forward method. + populated_split_gm.recompile() + + def optimized_call_with_cache(*args, **kwargs): + # Execute the compiled forward method + return populated_split_gm(*args, **kwargs) + + fn: VllmSerializableFunction = cls( + **state, optimized_call=optimized_call_with_cache + ) + logger.info("Successfully created VllmBackendWithCache function") + return fn + + # Fall back to standard VllmBackend + from vllm.compilation.backends import VllmBackend + + vllm_backend: VllmBackend = VllmBackend( + get_current_vllm_config(), state["prefix"] + ) def optimized_call(*example_inputs): """ @@ -145,6 +352,11 @@ def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]: # model is created) config_hash = vllm_config.compute_hash() factors.append(config_hash) + + # 2. 
inductor factors if applicable + if envs.VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS: + factors.extend(get_inductor_factors()) + return factors diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 0a3f0769db94..e273de68e006 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -12,12 +12,16 @@ import torch import torch._inductor.compile_fx import torch.fx as fx +from torch._inductor.standalone_compile import AOTCompiledArtifact import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.logger import init_logger from vllm.utils.torch_utils import is_torch_equal_or_newer +logger = init_logger(__name__) + class CompilerInterface: """ @@ -200,22 +204,50 @@ def compile( if compiler_config is not None: current_config.update(compiler_config) set_inductor_config(current_config, runtime_shape) - set_functorch_config() if isinstance(runtime_shape, int): dynamic_shapes = "from_example_inputs" else: dynamic_shapes = "from_tracing_context" + # Check if PyTorch version supports 'aot' parameter in standalone_compile + # This was added in PyTorch 2.10+ from torch._inductor import standalone_compile - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, + supports_aot = is_torch_equal_or_newer("2.10.0.dev") + + if not supports_aot and envs.VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS: + logger.error( + "CRITICAL: VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS " + "is enabled but PyTorch version does not support 'aot' " + "parameter in standalone_compile. This requires PyTorch " + "2.10.0+. Falling back to non-AOT mode." + ) + + compile_kwargs = { + "dynamic_shapes": dynamic_shapes, + "options": { + "config_patches": current_config, + }, + } + + use_aot: bool = ( + supports_aot and envs.VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS ) + # Only add 'aot' parameter if both supported and enabled + if use_aot: + compile_kwargs["aot"] = True # type: ignore[assignment] + + compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs) + + if use_aot: + assert isinstance(compiled_graph, AOTCompiledArtifact) + # just return the compiled graph and a key + # since we can serialize the bytes using to_bytes + # and reload it using the key when reading + return compiled_graph, None + # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) @@ -282,9 +314,11 @@ def initialize_cache( # set flags so that Inductor and Triton store their cache # in the cache_dir, then users only need to copy the cache_dir # to another machine to reuse the cache. 
- inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache") - os.makedirs(inductor_cache, exist_ok=True) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + inductor_compiled_artifacts = os.path.join( + self.base_cache_dir, "inductor_compiled_artifacts" + ) + os.makedirs(inductor_compiled_artifacts, exist_ok=True) + os.environ["TORCHINDUCTOR_COMPILED_ARTIFACTS_DIR"] = inductor_compiled_artifacts triton_cache = os.path.join(self.base_cache_dir, "triton_cache") os.makedirs(triton_cache, exist_ok=True) os.environ["TRITON_CACHE_DIR"] = triton_cache @@ -309,7 +343,6 @@ def compile( current_config["fx_graph_remote_cache"] = False set_inductor_config(current_config, runtime_shape) - set_functorch_config() # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -599,10 +632,6 @@ def set_inductor_config(config, runtime_shape): ) -def set_functorch_config(): - torch._functorch.config.bundled_autograd_cache = False - - class EagerAdaptor(CompilerInterface): name = "eager" diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 4a4903035cf9..1ce1fc1ace17 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -30,6 +30,42 @@ IGNORE_COMPILE_KEY = "_ignore_compile_vllm" + +def save_compile_cache(compiled_module: torch.nn.Module) -> None: + # Save the AOT compiled function artifacts + if hasattr(compiled_module, "save_aot_compiled_function"): + compiled_module.save_aot_compiled_function() + + # Save the compiler manager cache (vllm_compile_cache.py) + if not hasattr(compiled_module, "aot_compiled_fn"): + return + + artifacts = compiled_module.aot_compiled_fn._artifacts + compiled_fn = artifacts.compiled_fn + + # Access the split_gm's optimized_call if available + if hasattr(compiled_fn, "optimized_call"): + split_gm = compiled_fn.optimized_call + + # Find any piecewise backend with compiler_manager + for name in dir(split_gm): + if name.startswith("submod_"): + submod = getattr(split_gm, name, None) + if submod is None: + continue + + # Unwrap cudagraph wrapper if needed + if hasattr(submod, "runnable"): + submod = submod.runnable + + # Check if it has vllm_backend with compiler_manager + if hasattr(submod, "vllm_backend") and hasattr( + submod.vllm_backend, "compiler_manager" + ): + submod.vllm_backend.compiler_manager.save_to_file() + return + + _T = TypeVar("_T", bound=type[nn.Module]) @@ -246,6 +282,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): if self.do_not_compile: return + self.aot_compile_loaded_from_cache = False + compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_mode=vllm_config.compilation_config.mode @@ -303,6 +341,7 @@ def __call__(self, *args, **kwargs): loaded_fn = torch.compiler.load_compiled_function(f) _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) self.aot_compiled_fn = loaded_fn + self.aot_compile_loaded_from_cache = True except Exception as e: if os.path.exists(aot_compilation_path): logger.warning( @@ -400,10 +439,9 @@ def patched_inline_call(self_): if envs.VLLM_USE_AOT_COMPILE: self.aot_compiled_fn = self.aot_compile(*args, **kwargs) output = self.aot_compiled_fn(self, *args, **kwargs) - assert aot_compilation_path is not None - assert cache_dir is not None - os.makedirs(cache_dir, exist_ok=True) - self.aot_compiled_fn.save_compiled_function(aot_compilation_path) + # Store the path for later saving after warmup + 
self._aot_compilation_path = aot_compilation_path + self._aot_cache_dir = cache_dir else: output = self.compiled_callable(*args, **kwargs) return output @@ -415,7 +453,33 @@ def patched_inline_call(self_): model_output = self.forward(*args, **kwargs) return model_output + def save_aot_compiled_function(self): + """Save the AOT compiled function after warmup is complete.""" + if not envs.VLLM_USE_AOT_COMPILE: + return + + if getattr(self, "aot_compiled_fn", None) is None: + logger.debug("No AOT compiled function to save") + return + + if getattr(self, "aot_compile_loaded_from_cache", False): + logger.debug("AOT compiled function was loaded from cache, skipping save") + return + + aot_compilation_path = getattr(self, "_aot_compilation_path", None) + cache_dir = getattr(self, "_aot_cache_dir", None) + + if aot_compilation_path is None or cache_dir is None: + logger.debug("No AOT compilation path found, skipping save") + return + + logger.info("Saving AOT compiled function to %s", aot_compilation_path) + os.makedirs(cache_dir, exist_ok=True) + self.aot_compiled_fn.save_compiled_function(aot_compilation_path) + logger.info("AOT compiled function saved successfully") + cls.__call__ = __call__ + cls.save_aot_compiled_function = save_aot_compiled_function return cls diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 2931580afbbb..0b56fd706acf 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -2,10 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +import io +import pickle from collections.abc import Callable +from pickle import Pickler from typing import Any +import torch._functorch.config import torch.fx as fx +from torch._inductor.runtime.triton_heuristics import CachingAutotuner import vllm.envs as envs from vllm.compilation.backends import VllmBackend @@ -33,6 +38,8 @@ def __init__( sym_shape_indices: list[int], compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend, + get_compiled_graph_for_size: Callable | None = None, + returns_tuple: bool | None = None, ): """ The backend for piecewise compilation. 
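
[Annotation, not part of the patch] The decorators.py hunks above split AOT compilation and persistence into two phases: __call__ now only records the target path, and save_aot_compiled_function() writes the artifact once warmup has finished. Below is a minimal sketch of that deferred-save pattern; the Wrapper class and its attributes are illustrative stand-ins, not the actual vLLM classes — only the ordering (compile, remember the path, save after warmup, skip the save on a cache hit) mirrors the diff.

import os

class Wrapper:
    def __init__(self):
        self.aot_compiled_fn = None          # set on first call
        self.loaded_from_cache = False       # True when a cached artifact was loaded
        self._path = None
        self._cache_dir = None

    def first_call(self, compile_fn, path, cache_dir, *args):
        # Compile once but do not serialize yet: later warmup runs may
        # still add per-size entries to the artifact.
        self.aot_compiled_fn = compile_fn(*args)
        self._path = path
        self._cache_dir = cache_dir
        return self.aot_compiled_fn(*args)

    def save_after_warmup(self):
        # Nothing compiled, or the artifact came from cache: nothing to save.
        if self.aot_compiled_fn is None or self.loaded_from_cache:
            return
        if self._path is None or self._cache_dir is None:
            return
        os.makedirs(self._cache_dir, exist_ok=True)
        self.aot_compiled_fn.save_compiled_function(self._path)
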
@@ -49,19 +56,20 @@ def __init__( self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles self.vllm_backend = vllm_backend + self.get_compiled_graph_for_size = get_compiled_graph_for_size self.is_first_graph = piecewise_compile_index == 0 self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) + self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes.copy()) self.first_run_finished = False self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - self.sym_shape_indices = sym_shape_indices + self.returns_tuple = returns_tuple self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" @@ -79,6 +87,71 @@ def __init__( runnable=self.compiled_graph_for_general_shape, ) + self.populate_precompiled_entries() + + def populate_precompiled_entries(self): + if self.get_compiled_graph_for_size is None: + return + + for shape, entry in self.concrete_size_entries.items(): + if entry.compiled: + continue + entry.runnable = self.get_compiled_graph_wrapper( + self.get_compiled_graph_for_size(str(shape)) + ) + entry.compiled = True + logger.debug( + "setting runnable for shape %s to precompiled graph wrapper", shape + ) + self.to_be_compiled_sizes.remove(shape) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + def get_compiled_graph_wrapper(self, compiled_graph): + from torch._inductor.compile_fx import graph_returns_tuple + + # For deserialized functions from cache, the graph might be + # empty. In that case, we can't check graph_returns_tuple + # from the graph itself. Use the stored returns_tuple value. + if not self.graph.graph.nodes: + # Empty graph - use stored returns_tuple value + if self.returns_tuple is None: + # No stored value, assume it returns a tuple and just pass through + def compiled_graph_wrapper_for_cache(*args): + return compiled_graph(*args) + + return compiled_graph_wrapper_for_cache + else: + # Use the stored returns_tuple value + returns_tuple = self.returns_tuple + + def compiled_graph_wrapper_with_tuple(*args): + graph_output = compiled_graph(*args) + if returns_tuple: + return graph_output + else: + # Don't unpack - the AOTCompiledArtifact is returning a list + # but something else in the call chain expects to unpack it + return graph_output + + return compiled_graph_wrapper_with_tuple + + returns_tuple = graph_returns_tuple(self.graph) + + def compiled_graph_wrapper(*args): + graph_output = compiled_graph(*args) + # unpack the tuple if needed + # TODO(rzou): the implication is that we're not + # reading the python bytecode correctly in vLLM? 
+ if returns_tuple or not isinstance(graph_output, (tuple, list)): + return graph_output + else: + return graph_output[0] + + return compiled_graph_wrapper + def check_for_ending_compilation(self): if self.is_last_graph and not self.to_be_compiled_sizes: # no specific sizes to compile @@ -86,9 +159,61 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) + def to_bytes(self) -> dict[str, bytes]: + if not hasattr(self.compiled_graph_for_general_shape, "serialize"): + return {} + + class InductorCompiledArtifactsPickler(Pickler): + def reducer_override(self, obj): + if isinstance(obj, CachingAutotuner): + obj.prepare_for_pickle() + return pickle.loads, ( + pickle.dumps( + obj, + ), + ) + return NotImplemented + + def serialize(fn) -> bytes: + assert hasattr(fn, "serialize"), "fn must have serialize method" + with torch._functorch.config.patch("bundled_autograd_cache", True): + entry = fn.serialize() + # entry.pre_save() + + f = io.BytesIO() + InductorCompiledArtifactsPickler(f).dump(entry) + result = f.getvalue() + return result + + out = {"None": serialize(self.compiled_graph_for_general_shape)} + + for entry in self.concrete_size_entries.values(): + if not entry.compiled: + logger.debug( + "entry with shape %s not compiled, so cannot get its bytes", + entry.runtime_shape, + ) + continue + out[str(entry.runtime_shape)] = serialize(entry.runnable) + + return out + def __call__(self, *args) -> Any: + logger.debug( + "calling piecewise backend on runtime_shape %s with " + "remaining compile sizes %s", + args[self.sym_shape_indices[0]], + self.to_be_compiled_sizes, + ) + if not self.first_run_finished: self.first_run_finished = True + # Always wrap the general shape graph on first run if it has a + # serialize method (meaning it's a deserialized function from cache) + if hasattr(self.compiled_graph_for_general_shape, "serialize"): + self.compiled_graph_for_general_shape = self.get_compiled_graph_wrapper( + self.compiled_graph_for_general_shape + ) self.check_for_ending_compilation() return self.compiled_graph_for_general_shape(*args) @@ -101,19 +226,30 @@ def __call__(self, *args) -> Any: entry = self.concrete_size_entries[runtime_shape] if not entry.compiled: + assert self.get_compiled_graph_for_size is None entry.compiled = True self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( - self.graph, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=self.piecewise_compile_index, - num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, + + shape_arg = args[self.sym_shape_indices[0]] + logger.debug( + "compiling runnable for piecewise backend on " + "runtime_shape %s with remaining compile sizes %s", + shape_arg, + self.to_be_compiled_sizes, ) + with torch._functorch.config.patch("bundled_autograd_cache", True): + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape, + ) + # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: self.check_for_ending_compilation() diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 4b10c85209f6..adccdb83f109 100644 --- 
a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -35,11 +35,13 @@ def __init__( ): vllm_config = get_current_vllm_config() self.vllm_config = vllm_config + self.backend = None if compiled_callable is None: # default compilation settings # compiling the forward method backend = vllm_config.compilation_config.init_backend(vllm_config) + self.backend = backend options = None if isinstance(backend, str) and backend == "inductor": options = ( diff --git a/vllm/envs.py b/vllm/envs.py index 0c45f93ec057..f58dede72550 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -91,6 +91,7 @@ VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False VLLM_USE_AOT_COMPILE: bool = False VLLM_FORCE_AOT_LOAD: bool = False + VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS: bool = False VLLM_TORCH_PROFILER_WITH_STACK: bool = True VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False VLLM_USE_TRITON_AWQ: bool = False @@ -253,6 +254,12 @@ def use_aot_compile() -> bool: return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" +def use_backend_with_cache() -> bool: + return ( + os.environ.get("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "0") == "1" + ) + + def env_with_choices( env_name: str, default: str | None, @@ -520,6 +527,10 @@ def get_vllm_port() -> int | None: # to load will result in a hard error when this is enabled. # Will be ignored when VLLM_USE_AOT_COMPILE is disabled. "VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1", + # Enable the new VllmBackendWithCache backend that reconstructs + # compiled models directly from cached inductor artifacts without + # re-splitting graph modules. This reduces overhead during model loading. + "VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS": use_backend_with_cache, # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a110ad54a05e..f4217530691c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2989,6 +2989,38 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: return None + def save_compiler_cache(self) -> None: + """Save the compiler cache after warmup is complete.""" + from vllm.compilation.wrapper import ( + TorchCompileWrapperWithCustomDispatcher, + ) + + # The compiled model is typically at get_model().model for most models + # or get_model().get_language_model().model for multimodal models + wrapper = None + model = self.get_model() + + # Try to find the TorchCompileWrapperWithCustomDispatcher instance + if hasattr(model, "model") and isinstance( + model.model, TorchCompileWrapperWithCustomDispatcher + ): + wrapper = model.model + elif hasattr(model, "get_language_model"): + language_model = model.get_language_model() + if hasattr(language_model, "model") and isinstance( + language_model.model, TorchCompileWrapperWithCustomDispatcher + ): + wrapper = language_model.model + + if wrapper is None: + logger.debug("Model not compiled (wrapper is None), skipping cache save") + return + + # Save both AOT compiled function and compiler manager cache + from vllm.compilation.decorators import save_compile_cache + + save_compile_cache(wrapper) + def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." 
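
[Annotation, not part of the patch] The gpu_model_runner.py hunk above has save_compiler_cache() locate the torch.compile wrapper before delegating to save_compile_cache. A minimal sketch of that lookup is below, assuming only the attribute layout shown in the diff (model.model for text models, model.get_language_model().model for multimodal models); the standalone function name is hypothetical.

def find_compile_wrapper(model):
    # "model" is whatever the runner's get_model() returns.
    from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher

    candidate = getattr(model, "model", None)
    if isinstance(candidate, TorchCompileWrapperWithCustomDispatcher):
        return candidate

    # Multimodal models wrap the compiled language model one level deeper.
    if hasattr(model, "get_language_model"):
        candidate = getattr(model.get_language_model(), "model", None)
        if isinstance(candidate, TorchCompileWrapperWithCustomDispatcher):
            return candidate

    return None
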
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 29b6532e4366..17579953649f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,6 +14,7 @@ import vllm.envs as envs from vllm.config import VllmConfig +from vllm.config.compilation import CUDAGraphMode from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -354,23 +355,48 @@ def compile_or_warm_up_model(self) -> None: # warm up sizes that are not in cudagraph capture sizes, # but users still want to compile for better performance, # e.g. for the max-num-batched token size in chunked prefill. - warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() - if not self.model_config.enforce_eager: - warmup_sizes = [ - x - for x in warmup_sizes - if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes - ] - # We skip EPLB here since we don't want to record dummy metrics - for size in sorted(warmup_sizes, reverse=True): - logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False) + + aot_compile_loaded_from_cache = False + if envs.VLLM_USE_AOT_COMPILE: + from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher + + if self.model_config.is_multimodal_model: + compiled_model = self.get_model().get_language_model().model + else: + compiled_model = self.get_model().model + if isinstance( + compiled_model, TorchCompileWrapperWithCustomDispatcher + ) and hasattr(compiled_model, "aot_compile_loaded_from_cache"): + aot_compile_loaded_from_cache = ( + compiled_model.aot_compile_loaded_from_cache + ) + + if not aot_compile_loaded_from_cache: + warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + if ( + not self.model_config.enforce_eager + and self.vllm_config.compilation_config.cudagraph_mode + != CUDAGraphMode.NONE + ): + warmup_sizes = [ + x + for x in warmup_sizes + if x + not in self.vllm_config.compilation_config.cudagraph_capture_sizes + ] + # We skip EPLB here since we don't want to record dummy metrics + for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) + self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False) self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) # Warmup and tune the kernels used during model execution before # cuda graph capture. kernel_warmup(self) + # Save compiler cache after warmup is complete + self.model_runner.save_compiler_cache() + cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: cuda_graph_memory_bytes = self.model_runner.capture_model()
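
[Annotation, not part of the patch] The gpu_worker.py hunk above fixes the ordering of warmup, cache saving, and CUDA graph capture when the new backend is enabled. The sketch below condenses that control flow; every name on "worker" is a hypothetical placeholder rather than the real vLLM worker API — only the cache-hit short circuit and the step ordering are taken from the diff.

def compile_or_warm_up(worker):
    # Warm starts skip per-size warmup: the loaded AOT artifact already
    # contains compiled entries for every compile size.
    if not worker.aot_compile_loaded_from_cache:
        sizes = [
            s
            for s in worker.compile_sizes
            if s not in worker.cudagraph_capture_sizes
        ]
        for size in sorted(sizes, reverse=True):
            worker.dummy_run(size)      # triggers piecewise compilation

    worker.kernel_warmup()              # tune kernels before graph capture
    worker.save_compiler_cache()        # persist inductor artifacts once warm
    if not worker.enforce_eager:
        worker.capture_cuda_graphs()    # cudagraphs wrap the compiled graphs
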