Commit 1ffe934

[torch.compile] caching of config fields should be opt-out by default (#26468)

Authored by vnadathur, WorldExplored, and ProExpertProg.

Signed-off-by: vnadathur <glvikramn@gmail.com>
Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
Co-authored-by: WorldExplored <srreyansh.sethi@gmail.com>
Co-authored-by: Srreyansh Sethi <107075589+worldexplored@users.noreply.github.com>
Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

1 parent 2c8b918 commit 1ffe934

File tree

11 files changed: +602 / -193 lines
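At a glance, the commit replaces hand-rolled per-config MD5 factor lists with two shared helpers from vllm.config.utils: get_hash_factors(config, ignored_factors), which collects every field of a config dataclass except those explicitly opted out, and hash_factors(factors), which digests the collected factors with SHA-256. The sketch below illustrates that opt-out contract only; sketch_get_hash_factors, sketch_hash_factors, and Cfg are hypothetical stand-ins, and the real helpers additionally route each value through normalize_value (exercised by the tests below).

# Illustrative sketch of the opt-out hashing contract -- not the vLLM
# implementation. Every dataclass field participates in the hash unless
# its name is listed in `ignored_factors`.
import hashlib
import json
from dataclasses import dataclass, fields


def sketch_get_hash_factors(config, ignored_factors: set) -> dict:
    # Opt-out semantics: inclusion is the default; exclusion is explicit.
    return {
        f.name: getattr(config, f.name)
        for f in fields(config)
        if f.name not in ignored_factors
    }


def sketch_hash_factors(factors: dict) -> str:
    # Sort keys so dict insertion order cannot change the digest.
    payload = json.dumps(factors, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()  # 64 lowercase hex chars


@dataclass
class Cfg:  # hypothetical config used only for this illustration
    dtype: str = "auto"
    swap_space: float = 4.0  # runtime-only knob, opted out below


digest = sketch_hash_factors(sketch_get_hash_factors(Cfg(), {"swap_space"}))
assert len(digest) == 64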

tests/config/test_config_utils.py

Lines changed: 166 additions & 0 deletions (new file)

@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum

import pytest

from vllm.config.utils import get_hash_factors, hash_factors, normalize_value

# Helpers


def endswith_fqname(obj, suffix: str) -> bool:
    # normalize_value(type) returns fully-qualified name.
    # Compare suffix to avoid brittle import paths.
    out = normalize_value(obj)
    return isinstance(out, str) and out.endswith(suffix)


def expected_path(p_str: str = ".") -> str:
    import pathlib

    p = pathlib.Path(p_str)
    return p.expanduser().resolve().as_posix()


# Minimal dataclass to test get_hash_factors.
# Avoid importing heavy vLLM configs.
@dataclass
class SimpleConfig:
    a: object
    b: object | None = None


class DummyLogprobsMode(Enum):
    RAW_LOGITS = "raw_logits"


def test_hash_factors_deterministic():
    """Test that hash_factors produces consistent SHA-256 hashes."""
    factors = {"a": 1, "b": "test"}
    hash1 = hash_factors(factors)
    hash2 = hash_factors(factors)

    assert hash1 == hash2
    # Dict key insertion order should not affect the hash.
    factors_reordered = {"b": "test", "a": 1}
    assert hash_factors(factors_reordered) == hash1
    assert len(hash1) == 64
    assert all(c in "0123456789abcdef" for c in hash1)


@pytest.mark.parametrize(
    "inp, expected",
    [
        (None, None),
        (True, True),
        (1, 1),
        (1.0, 1.0),
        ("x", "x"),
        (b"ab", "6162"),
        (bytearray(b"ab"), "6162"),
        ([1, 2], (1, 2)),
        ({"b": 2, "a": 1}, (("a", 1), ("b", 2))),
    ],
)
def test_normalize_value_matrix(inp, expected):
    """Parametric input→expected normalization table."""
    assert normalize_value(inp) == expected


def test_normalize_value_enum():
    # Enums normalize to (module.QualName, value).
    # DummyLogprobsMode uses a string payload.
    out = normalize_value(DummyLogprobsMode.RAW_LOGITS)
    assert isinstance(out, tuple)
    assert out[0].endswith("DummyLogprobsMode")
    # Expect string payload 'raw_logits'.
    assert out[1] == "raw_logits"


def test_normalize_value_set_order_insensitive():
    # Sets are unordered; normalize_value sorts elements for determinism.
    assert normalize_value({3, 1, 2}) == normalize_value({1, 2, 3})


def test_normalize_value_path_normalization():
    from pathlib import Path  # local import to avoid global dependency

    # Paths expand/resolve to absolute strings.
    # Stabilizes hashing across working dirs.
    assert normalize_value(Path(".")) == expected_path(".")


def test_normalize_value_uuid_and_to_json():
    # Objects may normalize via uuid() or to_json_string().
    class HasUUID:
        def uuid(self):
            return "test-uuid"

    class ToJson:
        def to_json_string(self):
            return '{"x":1}'

    assert normalize_value(HasUUID()) == "test-uuid"
    assert normalize_value(ToJson()) == '{"x":1}'


@pytest.mark.parametrize(
    "bad",
    [
        (lambda x: x),
        (type("CallableInstance", (), {"__call__": lambda self: 0}))(),
        (lambda: (lambda: 0))(),  # nested function instance
    ],
)
def test_error_cases(bad):
    """Inputs expected to raise TypeError."""
    # Reject functions/lambdas/callable instances
    # to avoid under-hashing.
    with pytest.raises(TypeError):
        normalize_value(bad)


def test_enum_vs_int_disambiguation():
    # int stays primitive
    nf_int = normalize_value(1)
    assert nf_int == 1

    # enum becomes ("module.QualName", value)
    nf_enum = normalize_value(DummyLogprobsMode.RAW_LOGITS)
    assert isinstance(nf_enum, tuple) and len(nf_enum) == 2
    enum_type, enum_val = nf_enum
    assert enum_type.endswith(".DummyLogprobsMode")
    assert enum_val == "raw_logits"

    # Build factor dicts from configs with int vs enum
    f_int = get_hash_factors(SimpleConfig(1), set())
    f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set())
    # The int case remains a primitive value
    assert f_int["a"] == 1
    # The enum case becomes a tagged tuple ("module.QualName", "raw_logits")
    assert isinstance(f_enum["a"], tuple) and f_enum["a"][1] == "raw_logits"
    # Factor dicts must differ so we don't collide primitives with Enums.
    assert f_int != f_enum
    # Hash digests must differ correspondingly
    assert hash_factors(f_int) != hash_factors(f_enum)

    # Hash functions produce stable hex strings
    h_int = hash_factors(f_int)
    h_enum = hash_factors(f_enum)
    assert isinstance(h_int, str) and len(h_int) == 64
    assert isinstance(h_enum, str) and len(h_enum) == 64


def test_classes_are_types():
    """Types normalize to FQNs; include real vLLM types."""
    # Only classes allowed; functions/lambdas are rejected.
    # Canonical form is the fully-qualified name.
    assert isinstance(normalize_value(str), str)

    class LocalDummy:
        pass

    assert endswith_fqname(LocalDummy, ".LocalDummy")
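Taken together, these tests pin down normalize_value's behavior for primitives, bytes, enums, paths, containers, classes, and objects exposing uuid() or to_json_string(). The following is a hypothetical dispatch consistent with those assertions; the real function lives in vllm.config.utils and may differ in detail.

import enum
import pathlib


def sketch_normalize_value(v):
    # Hypothetical reimplementation inferred from the tests; not vLLM's code.
    if v is None or isinstance(v, (bool, int, float, str)):
        return v  # primitives pass through unchanged
    if isinstance(v, (bytes, bytearray)):
        return bytes(v).hex()  # b"ab" -> "6162"
    if isinstance(v, enum.Enum):
        t = type(v)
        return (f"{t.__module__}.{t.__qualname__}", v.value)  # tagged tuple
    if isinstance(v, pathlib.Path):
        return v.expanduser().resolve().as_posix()  # stable across working dirs
    if isinstance(v, (list, tuple)):
        return tuple(sketch_normalize_value(x) for x in v)
    if isinstance(v, set):
        return tuple(sorted(sketch_normalize_value(x) for x in v))  # determinism
    if isinstance(v, dict):
        return tuple(sorted((k, sketch_normalize_value(x)) for k, x in v.items()))
    if isinstance(v, type):
        return f"{v.__module__}.{v.__qualname__}"  # classes -> FQN
    if hasattr(v, "uuid") and callable(v.uuid):
        return v.uuid()
    if hasattr(v, "to_json_string"):
        return v.to_json_string()
    # Functions, lambdas, and callable instances are rejected: hashing them
    # by identity or repr would under-hash the configuration.
    raise TypeError(f"cannot normalize {type(v).__name__} for hashing")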

vllm/compilation/backends.py

Lines changed: 83 additions & 22 deletions

@@ -4,12 +4,14 @@
 import ast
 import dataclasses
 import hashlib
+import json
 import operator
 import os
 import pprint
 import time
 from collections.abc import Callable, Sequence
 from contextlib import contextmanager
+from functools import partial
 from typing import Any

 import torch
@@ -23,7 +25,9 @@
     should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config.utils import hash_factors
 from vllm.logger import init_logger
+from vllm.logging_utils import lazy
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -580,35 +584,47 @@ def configure_post_pass(self):
     def __call__(
         self, graph: fx.GraphModule, example_inputs
     ) -> VllmSerializableFunction:
-        from .caching import _compute_code_hash, compilation_config_hash_factors
-
         vllm_config = self.vllm_config
+        # Minimal hashing here with existing utilities, reused below.
+
+        env_factors = envs.compile_factors()
+        env_hash = hash_factors(env_factors)
+        # Compute config/compiler/code hashes once and reuse
+        config_hash = vllm_config.compute_hash()
+        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
+        forward_code_files = list(sorted(self.compilation_config.traced_files))
+
+        logger.debug(
+            "Traced files (to be considered for compilation cache):\n%s",
+            lazy(lambda: "\n".join(forward_code_files)),
+        )
+        hash_content = []
+        for filepath in forward_code_files:
+            hash_content.append(filepath)
+            if filepath == "<string>":
+                # This means the function was dynamically generated, with
+                # e.g. exec(). We can't actually check these.
+                continue
+            try:
+                with open(filepath) as f:
+                    hash_content.append(f.read())
+            except Exception:
+                logger.warning("Failed to read file %s", filepath)
+                continue
+        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
+        # Clear after consumption
+        self.compilation_config.traced_files.clear()
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affects the compilation. if none of the factors change,
             # the cache dir will be the same so that we can reuse the compiled
             # graph.
-
-            factors = compilation_config_hash_factors(vllm_config)
-            # 2. factors come from the code files that are traced by Dynamo (
-            # it mainly summarizes how the model is used in forward pass)
-            code_hash = _compute_code_hash(self.compilation_config.traced_files)
-            self.compilation_config.traced_files.clear()
-            factors.append(code_hash)
-
-            # 3. compiler hash
-            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
-            factors.append(compiler_hash)
-
-            # combine all factors to generate the cache dir
-            hash_key = hashlib.md5(
-                str(factors).encode(), usedforsecurity=False
-            ).hexdigest()[:10]
-
+            factors = [env_hash, config_hash, code_hash, compiler_hash]
+            # Use SHA-256 for cache key hashing to be consistent across
+            # compute_hash functions. Truncate for a short cache dir name.
+            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
             cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT,
-                "torch_compile_cache",
-                hash_key,
+                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
             )
             self.compilation_config.cache_dir = cache_dir

@@ -621,6 +637,7 @@ def __call__(
             os.makedirs(local_cache_dir, exist_ok=True)
             self.compilation_config.local_cache_dir = local_cache_dir

+        # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
         disable_cache = not is_compile_cache_enabled(
             self.compilation_config.inductor_compile_config
         )
@@ -638,6 +655,50 @@ def __call__(
             local_cache_dir, disable_cache, self.prefix
         )

+        # Reuses existing cache key
+
+        logger.debug(
+            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
+            env_hash,
+            config_hash,
+            compiler_hash,
+            code_hash,
+            local_cache_dir,
+        )
+
+        # Persist and log only hash-relevant factors together.
+        try:
+            logger.debug(
+                "Compile env factors (raw):\n%s\nVllm config hash: %s",
+                lazy(partial(pprint.pformat, env_factors, width=120)),
+                config_hash,
+            )
+            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
+            if not os.path.exists(meta_path):
+                with open(meta_path, "w") as f:
+                    json.dump(
+                        {
+                            "env": env_factors,  # raw factors used for env_hash
+                            "config_hash": config_hash,
+                            "code_hash": code_hash,
+                            "compiler_hash": compiler_hash,
+                        },
+                        f,
+                        indent=2,
+                        sort_keys=True,
+                    )
+        except Exception:
+            # Best-effort only; metadata write failures are non-fatal.
+            logger.warning(
+                (
+                    "Could not write compile cache metadata at %s; continuing without "
+                    "metadata. Compiled cache remains valid; diagnostics may be "
+                    "limited."
+                ),
+                local_cache_dir,
+                exc_info=True,
+            )

         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
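After this change the cache directory name derives from exactly four factors: the environment hash, the config hash, the traced-code hash, and the compiler hash. A toy reproduction follows, with placeholder digests standing in for the real values computed in __call__ above.

import hashlib

# Placeholder digests; in vLLM these come from hash_factors(envs.compile_factors()),
# vllm_config.compute_hash(), the traced-file SHA-256, and
# compiler_manager.compute_hash(vllm_config), respectively.
env_hash = "0" * 64
config_hash = "1" * 64
code_hash = "2" * 64
compiler_hash = "3" * 64

factors = [env_hash, config_hash, code_hash, compiler_hash]
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
print(hash_key)  # short directory name under $VLLM_CACHE_ROOT/torch_compile_cache/

Because the raw factors are also persisted beside the compiled artifacts in cache_key_factors.json, an unexpected cache miss can be diagnosed by diffing that file between two runs.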

vllm/compilation/pass_manager.py

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ def uuid(self):
         affects compilation caching. Its uuid depends on the UUIDs of all
         dependent passes and the pass config. See InductorPass for more info.
         """
-        state = {"pass_config": self.pass_config.uuid(), "passes": []}
+        state = {"pass_config": self.pass_config.compute_hash(), "passes": []}
         for pass_ in self.passes:
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())

vllm/config/cache.py

Lines changed: 23 additions & 8 deletions

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import hashlib
 from dataclasses import field
 from typing import TYPE_CHECKING, Any, Literal

@@ -160,13 +159,29 @@ def compute_hash(self) -> str:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        factors: list[Any] = []
-        factors.append(self.cache_dtype)
-        factors.append(self.mamba_cache_dtype)
-        factors.append(self.mamba_ssm_cache_dtype)
-        # `cpu_offload_gb` does not use `torch.compile` yet.
-        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
-        return hash_str
+        ignored_factors = {
+            # Runtime/derived knobs that don't affect compiled graph shape
+            "gpu_memory_utilization",
+            "swap_space",
+            "is_attention_free",
+            "num_gpu_blocks_override",
+            "enable_prefix_caching",
+            "prefix_caching_hash_algo",
+            # `cpu_offload_gb` does not use `torch.compile` yet.
+            "cpu_offload_gb",
+            "cpu_kvcache_space_bytes",
+            "mamba_page_size_padded",
+            # Post-init/derived counters
+            "num_gpu_blocks",
+            "num_cpu_blocks",
+            # WIP feature toggle not impacting compiled graph shape
+            "kv_sharing_fast_prefill",
+        }
+
+        from vllm.config.utils import get_hash_factors, hash_factors
+
+        factors = get_hash_factors(self, ignored_factors)
+        return hash_factors(factors)

     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
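The practical consequence of the opt-out default: a field added to CacheConfig later changes the compile-cache hash automatically unless someone deliberately lists it in ignored_factors, which fails safe (a spurious recompile) rather than silently reusing a stale cache. A hypothetical before/after, reusing the sketch helpers defined after the file tree above:

from dataclasses import dataclass


@dataclass
class CacheCfgV1:  # hypothetical stand-in for a config class
    cache_dtype: str = "auto"


@dataclass
class CacheCfgV2(CacheCfgV1):
    new_knob: bool = False  # new field: hashed by default, no extra wiring


h1 = sketch_hash_factors(sketch_get_hash_factors(CacheCfgV1(), set()))
h2 = sketch_hash_factors(sketch_get_hash_factors(CacheCfgV2(), set()))
assert h1 != h2  # the new field is opted in automatically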
