|
7 | 7 | from collections.abc import Callable |
8 | 8 | from dataclasses import asdict, field |
9 | 9 | from pathlib import Path |
10 | | -from typing import TYPE_CHECKING, Any, ClassVar |
| 10 | +from typing import TYPE_CHECKING, Any, ClassVar, Literal |
11 | 11 |
|
12 | 12 | from pydantic import TypeAdapter, field_validator |
13 | 13 | from pydantic.dataclasses import dataclass |
14 | 14 |
|
| 15 | +import vllm.envs as envs |
15 | 16 | from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass |
16 | 17 | from vllm.config.utils import config |
17 | 18 | from vllm.logger import init_logger |
@@ -208,6 +209,15 @@ class CompilationConfig: |
208 | 209 | """The directory to store the compiled graph, to accelerate Inductor |
209 | 210 | compilation. By default, it will use model-related information to generate |
210 | 211 | a cache directory.""" |
| 212 | + compile_cache_save_format: Literal["binary", "unpacked"] = field( |
| 213 | + default_factory=lambda: envs.VLLM_COMPILE_CACHE_SAVE_FORMAT |
| 214 | + ) |
| 215 | + """Format for saving torch compile cache:\n |
| 216 | + - "binary": saves as binary file (multiprocess safe)\n |
| 217 | + - "unpacked": saves as directory structure for inspection/debugging |
| 218 | + (NOT multiprocess safe)\n |
| 219 | + Defaults to `VLLM_COMPILE_CACHE_SAVE_FORMAT` if not specified. |
| 220 | + """ |
211 | 221 | backend: str = "" |
212 | 222 | """The backend for compilation. It needs to be a string: |
213 | 223 |
|
@@ -479,6 +489,7 @@ def compute_hash(self) -> str: |
479 | 489 | factors.append(self.inductor_compile_config) |
480 | 490 | factors.append(self.inductor_passes) |
481 | 491 | factors.append(self.pass_config.uuid()) |
| 492 | + factors.append(self.compile_cache_save_format) |
482 | 493 | return hashlib.sha256(str(factors).encode()).hexdigest() |
483 | 494 |
|
484 | 495 | def __repr__(self) -> str: |
@@ -520,6 +531,16 @@ def validate_cudagraph_mode_before(cls, value: Any) -> Any: |
520 | 531 | return CUDAGraphMode[value.upper()] |
521 | 532 | return value |
522 | 533 |
|
| 534 | + @field_validator("compile_cache_save_format") |
| 535 | + @classmethod |
| 536 | + def validate_compile_cache_save_format(cls, value: str) -> str: |
| 537 | + if value not in ("binary", "unpacked"): |
| 538 | + raise ValueError( |
| 539 | + f"compile_cache_save_format must be 'binary' or 'unpacked', " |
| 540 | + f"got: {value}" |
| 541 | + ) |
| 542 | + return value |
| 543 | + |
523 | 544 | def __post_init__(self) -> None: |
524 | 545 | if self.level is not None: |
525 | 546 | logger.warning( |
|
0 commit comments