Skip to content

Commit c0ec818

Browse files
[torch.compile]: Add VLLM_DEBUG_DUMP_PATH environment variable (vllm-project#25651)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
1 parent b65e56b commit c0ec818

File tree

6 files changed

+44
-17
lines changed

6 files changed

+44
-17
lines changed

vllm/compilation/monitor.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
import os
54
import time
65

76
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
@@ -18,13 +17,12 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
1817
torch_compile_start_time = time.time()
1918

2019
compilation_config: CompilationConfig = vllm_config.compilation_config
21-
if compilation_config.level == CompilationLevel.PIECEWISE and \
22-
compilation_config.debug_dump_path:
20+
path = vllm_config.compile_debug_dump_path()
21+
if compilation_config.level == CompilationLevel.PIECEWISE and path:
2322
import depyf
24-
path = os.path.join(compilation_config.debug_dump_path,
25-
f"rank_{vllm_config.parallel_config.rank}")
23+
path.mkdir(parents=True, exist_ok=True)
2624
global context_manager
27-
context_manager = depyf.prepare_debug(path)
25+
context_manager = depyf.prepare_debug(path.as_posix())
2826
context_manager.__enter__()
2927

3028

vllm/compilation/vllm_inductor_pass.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import functools
44
import operator
55
import time
6-
from pathlib import Path
76
from typing import ClassVar, Optional
87

98
import regex as re
@@ -96,12 +95,10 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
9695
9796
TODO(luka): use pattern object to manually produce pattern graph
9897
"""
99-
debug_dump_path = config.compilation_config.debug_dump_path
98+
debug_dump_path = config.compile_debug_dump_path()
10099
if not debug_dump_path:
101100
return
102101

103-
rank = config.parallel_config.rank
104-
debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
105102
debug_dump_path.mkdir(parents=True, exist_ok=True)
106103

107104
from vllm.utils import unique_filepath

vllm/compilation/wrapper.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,11 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
9292
return
9393

9494
self.compiled_codes.append(new_code)
95-
debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path
96-
if isinstance(debug_dump_dir, str) and debug_dump_dir != "":
97-
rank = self.vllm_config.parallel_config.rank
98-
decompiled_file = os.path.join(debug_dump_dir, f"rank_{rank}",
99-
"transformed_code.py")
100-
if not os.path.exists(decompiled_file):
95+
96+
path = self.vllm_config.compile_debug_dump_path()
97+
if path:
98+
decompiled_file = path / "transformed_code.py"
99+
if not decompiled_file.exists():
101100
try:
102101
# usually the decompilation will succeed for most models,
103102
# as we guarantee a full-graph compilation in Dynamo.

vllm/config/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from contextlib import contextmanager
1313
from dataclasses import field, fields, is_dataclass, replace
1414
from functools import cached_property, lru_cache
15+
from pathlib import Path
1516
from typing import (TYPE_CHECKING, Any, Literal, Optional, Protocol, TypeVar,
1617
Union, cast)
1718

@@ -541,6 +542,17 @@ def __post_init__(self):
541542
# local attention.
542543
self.scheduler_config.disable_hybrid_kv_cache_manager = True
543544

545+
if self.compilation_config.debug_dump_path:
546+
self.compilation_config.debug_dump_path = \
547+
self.compilation_config.debug_dump_path.absolute().expanduser()
548+
if envs.VLLM_DEBUG_DUMP_PATH is not None:
549+
env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser()
550+
if self.compilation_config.debug_dump_path:
551+
logger.warning(
552+
"Config-specified debug dump path is overridden"
553+
" by VLLM_DEBUG_DUMP_PATH to %s", env_path)
554+
self.compilation_config.debug_dump_path = env_path
555+
544556
def update_sizes_for_sequence_parallelism(self,
545557
possible_sizes: list) -> list:
546558
# remove the sizes that not multiple of tp_size when
@@ -672,6 +684,20 @@ def try_verify_and_update_config(self):
672684
f"but got '{self.load_config.load_format}'. "
673685
f"Model: {self.model_config.model}")
674686

687+
def compile_debug_dump_path(self) -> Optional[Path]:
688+
"""Returns a rank-aware path for dumping
689+
torch.compile debug information.
690+
"""
691+
if self.compilation_config.debug_dump_path is None:
692+
return None
693+
tp_rank = self.parallel_config.rank
694+
dp_rank = self.parallel_config.data_parallel_rank
695+
data_parallel_size = self.parallel_config.data_parallel_size
696+
append_path = f"rank_{tp_rank}" if data_parallel_size == 1 \
697+
else f"rank_{tp_rank}_dp_{dp_rank}"
698+
path = self.compilation_config.debug_dump_path / append_path
699+
return path
700+
675701
def __str__(self):
676702
return (
677703
f"model={self.model_config.model!r}, "

vllm/config/compilation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import hashlib
66
from collections import Counter
77
from dataclasses import asdict, field
8+
from pathlib import Path
89
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
910

1011
from pydantic import TypeAdapter, field_validator
@@ -169,7 +170,7 @@ class CompilationConfig:
169170
- 1: dynamo as is.
170171
- 2: dynamo once.
171172
- 3: piecewise compilation."""
172-
debug_dump_path: str = ""
173+
debug_dump_path: Optional[Path] = None
173174
"""The path to dump the debug information."""
174175
cache_dir: str = ""
175176
"""The directory to store the compiled graph, to accelerate Inductor

vllm/envs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@
199199
VLLM_DBO_COMM_SMS: int = 20
200200
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
201201
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
202+
VLLM_DEBUG_DUMP_PATH: Optional[str] = None
202203
VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True
203204
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
204205
VLLM_USE_NCCL_SYMM_MEM: bool = False
@@ -513,6 +514,11 @@ def get_vllm_port() -> Optional[int]:
513514
"VLLM_PATTERN_MATCH_DEBUG":
514515
lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),
515516

517+
# Dump fx graphs to the given directory.
518+
# It will override CompilationConfig.debug_dump_path if set.
519+
"VLLM_DEBUG_DUMP_PATH":
520+
lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None),
521+
516522
# local rank of the process in the distributed setting, used to determine
517523
# the GPU device id
518524
"LOCAL_RANK":

0 commit comments

Comments (0)