
Commit cac4c10

[BUG] Make 'binary' default option for saving torch compile artifacts when using standalone_compile (vllm-project#27616)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
1 parent: f7d2946

5 files changed: +43 −5 lines

docs/design/torch_compile.md

Lines changed: 2 additions & 0 deletions

@@ -27,6 +27,8 @@ With all these factors taken into consideration, usually we can guarantee that t
 
 A unique aspect of vLLM's `torch.compile` integration is that we guarantee all the compilation finishes before we serve any requests. No requests will trigger new compilations. Otherwise, the engine would be blocked on that request, and the response time would have unexpected spikes.
 
+By default, the cache saves compiled artifacts as binary files. If you would like to inspect the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit it and set the environment variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
+
 ## Python Code Compilation
 
 In the very verbose logs, we can see:
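
To make the new knob concrete, here is a minimal sketch of both ways to select the unpacked format. It assumes the module path `vllm.config.compilation` shown in the diffs below; how the config object is handed to an engine is out of scope here.

    import os

    # Option 1: environment variable; set it before vLLM reads its envs.
    os.environ["VLLM_COMPILE_CACHE_SAVE_FORMAT"] = "unpacked"

    # Option 2: the explicit config field added by this commit; it defaults
    # to the env value, so setting it directly takes precedence.
    from vllm.config.compilation import CompilationConfig

    config = CompilationConfig(compile_cache_save_format="unpacked")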

vllm/compilation/backends.py

Lines changed: 3 additions & 1 deletion

@@ -51,7 +51,9 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
         and hasattr(torch._inductor, "standalone_compile")
     ):
         logger.debug("Using InductorStandaloneAdaptor")
-        return InductorStandaloneAdaptor()
+        return InductorStandaloneAdaptor(
+            compilation_config.compile_cache_save_format
+        )
     else:
         logger.debug("Using InductorAdaptor")
         return InductorAdaptor()
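
The net effect of this hunk is that the adaptor carries the configured format instead of hardcoding one. A short sketch of the flow, under the assumption that `make_compiler` and `CompilationConfig` are importable as named in the hunk headers:

    from vllm.compilation.backends import make_compiler
    from vllm.config.compilation import CompilationConfig

    cfg = CompilationConfig()  # compile_cache_save_format defaults to "binary"
    compiler = make_compiler(cfg)
    # On a torch build exposing torch._inductor.standalone_compile (with the
    # standalone path enabled), this returns InductorStandaloneAdaptor
    # initialized with save_format == "binary".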

vllm/compilation/compiler_interface.py

Lines changed: 6 additions & 3 deletions

@@ -6,7 +6,7 @@
 import os
 from collections.abc import Callable
 from contextlib import ExitStack
-from typing import Any
+from typing import Any, Literal
 from unittest.mock import patch
 
 import torch

@@ -175,6 +175,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
 
     name = "inductor_standalone"
 
+    def __init__(self, save_format: Literal["binary", "unpacked"]):
+        self.save_format = save_format
+
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         factors = get_inductor_factors()
         hash_str = hashlib.md5(

@@ -220,7 +223,7 @@ def compile(
         assert key is not None
         path = os.path.join(self.cache_dir, key)
         if not envs.VLLM_DISABLE_COMPILE_CACHE:
-            compiled_graph.save(path=path, format="unpacked")
+            compiled_graph.save(path=path, format=self.save_format)
             compilation_counter.num_compiled_artifacts_saved += 1
         return compiled_graph, (key, path)

@@ -237,7 +240,7 @@ def load(
         assert isinstance(handle[1], str)
         path = handle[1]
         inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
-            path=path, format="unpacked"
+            path=path, format=self.save_format
        )
         from torch._inductor.compile_fx import graph_returns_tuple
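
For reference, a round-trip sketch of the torch APIs this adaptor wraps. Only `save`/`load` and their `format=` keyword come from the diff; the toy function, the `standalone_compile` call signature, and the cache path are illustrative assumptions.

    import torch

    def double_plus_one(x):
        return x * 2 + 1

    gm = torch.fx.symbolic_trace(double_plus_one)
    example_inputs = [torch.randn(4)]

    # standalone_compile is the attribute whose presence gates this adaptor.
    artifact = torch._inductor.standalone_compile(gm, example_inputs)

    # "binary" (the new default) writes a single file; "unpacked" writes a
    # directory tree whose generated code can be read and breakpointed.
    artifact.save(path="/tmp/compile_cache_demo", format="binary")

    loaded = torch._inductor.CompiledArtifact.load(
        path="/tmp/compile_cache_demo", format="binary"  # must match the save
    )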

vllm/config/compilation.py

Lines changed: 22 additions & 1 deletion

@@ -7,11 +7,12 @@
 from collections.abc import Callable
 from dataclasses import asdict, field
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, Literal
 
 from pydantic import TypeAdapter, field_validator
 from pydantic.dataclasses import dataclass
 
+import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.config.utils import config
 from vllm.logger import init_logger

@@ -208,6 +209,15 @@ class CompilationConfig:
     """The directory to store the compiled graph, to accelerate Inductor
     compilation. By default, it will use model-related information to generate
     a cache directory."""
+    compile_cache_save_format: Literal["binary", "unpacked"] = field(
+        default_factory=lambda: envs.VLLM_COMPILE_CACHE_SAVE_FORMAT
+    )
+    """Format for saving torch compile cache:\n
+    - "binary": saves as binary file (multiprocess safe)\n
+    - "unpacked": saves as directory structure for inspection/debugging
+      (NOT multiprocess safe)\n
+    Defaults to `VLLM_COMPILE_CACHE_SAVE_FORMAT` if not specified.
+    """
     backend: str = ""
     """The backend for compilation. It needs to be a string:

@@ -479,6 +489,7 @@ def compute_hash(self) -> str:
         factors.append(self.inductor_compile_config)
         factors.append(self.inductor_passes)
         factors.append(self.pass_config.uuid())
+        factors.append(self.compile_cache_save_format)
         return hashlib.sha256(str(factors).encode()).hexdigest()
 
     def __repr__(self) -> str:

@@ -520,6 +531,16 @@ def validate_cudagraph_mode_before(cls, value: Any) -> Any:
             return CUDAGraphMode[value.upper()]
         return value
 
+    @field_validator("compile_cache_save_format")
+    @classmethod
+    def validate_compile_cache_save_format(cls, value: str) -> str:
+        if value not in ("binary", "unpacked"):
+            raise ValueError(
+                f"compile_cache_save_format must be 'binary' or 'unpacked', "
+                f"got: {value}"
+            )
+        return value
+
     def __post_init__(self) -> None:
         if self.level is not None:
             logger.warning(
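
Since the field is typed as a `Literal` on a pydantic dataclass, type checking alone would already reject other strings; the explicit validator mainly guarantees a clear error message. A sketch of the fail-fast behavior, with the bad value invented for illustration:

    from vllm.config.compilation import CompilationConfig

    CompilationConfig(compile_cache_save_format="unpacked")  # accepted

    try:
        CompilationConfig(compile_cache_save_format="zip")  # rejected
    except Exception as err:  # pydantic surfaces the validator's ValueError
        print(err)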

vllm/envs.py

Lines changed: 10 additions & 0 deletions

@@ -218,6 +218,7 @@
 VLLM_USE_FBGEMM: bool = False
 VLLM_GC_DEBUG: str = ""
 VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
+VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
 
 
 def get_default_cache_root():

@@ -1442,6 +1443,15 @@ def get_vllm_port() -> int | None:
     "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
         "VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
     ),
+    # Format for saving torch.compile cache artifacts:
+    # - "binary": saves as a binary file.
+    #   Safe for multiple vllm serve processes accessing the same compile cache.
+    # - "unpacked": saves as a directory structure (for inspection/debugging).
+    #   NOT multiprocess safe: race conditions may occur with multiple processes.
+    #   Allows viewing and setting breakpoints in Inductor's code output files.
+    "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
+        "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
+    ),
 }
 
 # --8<-- [end:env-vars-definition]
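
Finally, a sketch of the variable's read-time behavior, assuming `env_with_choices` validates against the choice list when the lambda runs and that `vllm.envs` resolves attributes lazily, as its lambda table suggests:

    import os

    # Set before the attribute is first read.
    os.environ["VLLM_COMPILE_CACHE_SAVE_FORMAT"] = "unpacked"

    import vllm.envs as envs

    assert envs.VLLM_COMPILE_CACHE_SAVE_FORMAT == "unpacked"
    # A value outside {"binary", "unpacked"} would fail env_with_choices'
    # validation instead of silently falling back to the default.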
