Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,25 +350,47 @@ def __post_init__(self):
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_AND_PIECEWISE
)
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

# pooling models and encoder-decoder models
# do not support full cudagraphs
if self.model_config is not None and (
self.model_config.pooler_config is not None
or self.model_config.is_encoder_decoder
):
# if cudagraph_mode has full cudagraphs, we need to check support
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
if self.parallel_config.decode_context_parallel_size > 1:
logger.warning_once(
"Decode context parallel (DCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config is not None:
if self.model_config.pooler_config is not None:
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# decode context parallel do not support full cudagraphs now.
if self.parallel_config.decode_context_parallel_size > 1:
logger.warning(
"Decode context parallel (DCP) is enabled, which is "
"incompatible with full CUDA graphs. Set "
"cudagraph_mode to PIECEWISE."
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and self.model_config.max_model_len > 131072
and not self.model_config.use_mla
):
# Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
logger.warning_once(
"NVIDIA Blackwell TRTLLM attention cannot support "
"max_model_len >= 131072 (found "
f"{self.model_config.max_model_len}), causing dynamic "
"dispatching that breaks full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
Expand Down