
Commit e35a098

mgoin authored and rtourgeman committed
[Bugfix] Use PIECEWISE cudagraphs on Blackwell if max_model_len > 131072 (vllm-project#27114)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent e646a85 commit e35a098

1 file changed

vllm/config/vllm.py

Lines changed: 37 additions & 15 deletions
```diff
@@ -350,25 +350,47 @@ def __post_init__(self):
                 self.compilation_config.cudagraph_mode = (
                     CUDAGraphMode.FULL_AND_PIECEWISE
                 )
+            else:
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
-        # pooling models and encoder-decoder models
-        # do not support full cudagraphs
-        if self.model_config is not None and (
-            self.model_config.pooler_config is not None
-            or self.model_config.is_encoder_decoder
-        ):
+        # if cudagraph_mode has full cudagraphs, we need to check support
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            # decode context parallel does not support full cudagraphs
+            if self.parallel_config.decode_context_parallel_size > 1:
+                logger.warning_once(
+                    "Decode context parallel (DCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            elif self.model_config is not None:
+                if self.model_config.pooler_config is not None:
+                    logger.warning_once(
+                        "Pooling models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-
-        # decode context parallel do not support full cudagraphs now.
-        if self.parallel_config.decode_context_parallel_size > 1:
-            logger.warning(
-                "Decode context parallel (DCP) is enabled, which is "
-                "incompatible with full CUDA graphs. Set "
-                "cudagraph_mode to PIECEWISE."
+                elif self.model_config.is_encoder_decoder:
+                    logger.warning_once(
+                        "Encoder-decoder models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+                elif (
+                    current_platform.is_cuda()
+                    and current_platform.is_device_capability(100)
+                    and self.model_config.max_model_len > 131072
+                    and not self.model_config.use_mla
+                ):
+                    # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
+                    logger.warning_once(
+                        "NVIDIA Blackwell TRTLLM attention cannot support "
+                        "max_model_len >= 131072 (found "
+                        f"{self.model_config.max_model_len}), causing dynamic "
+                        "dispatching that breaks full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
                     )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-        else:
-            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
```
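For orientation, the fallback order the new block implements can be restated as a minimal self-contained sketch. The helper `resolve_cudagraph_mode`, its flat keyword arguments, and the three-member enum below are hypothetical stand-ins for illustration, not vLLM's actual API; only the decision order and the Blackwell condition (`is_device_capability(100)`, i.e. compute capability 10.0, with `max_model_len > 131072` and no MLA) mirror the diff above.

```python
from enum import Enum


class CUDAGraphMode(Enum):
    # Illustrative subset; the real enum in vLLM has more members.
    NONE = "none"
    PIECEWISE = "piecewise"
    FULL_AND_PIECEWISE = "full_and_piecewise"


def resolve_cudagraph_mode(
    mode: CUDAGraphMode,
    *,
    is_blackwell: bool,  # stand-in for current_platform.is_device_capability(100)
    max_model_len: int,
    use_mla: bool,
    dcp_size: int = 1,
    is_pooling_model: bool = False,
    is_encoder_decoder: bool = False,
) -> CUDAGraphMode:
    """Sketch of the fallback order added in this commit (not vLLM's API)."""
    if mode is not CUDAGraphMode.FULL_AND_PIECEWISE:
        # Only modes with full cudagraphs need the support checks.
        return mode
    # Decode context parallel is checked first and wins over model checks.
    if dcp_size > 1:
        return CUDAGraphMode.PIECEWISE
    if is_pooling_model or is_encoder_decoder:
        return CUDAGraphMode.PIECEWISE
    # On Blackwell, TRTLLM attention cannot cover max_model_len > 131072
    # without MLA, so attention would dispatch dynamically at runtime,
    # which breaks full cudagraph capture; fall back to PIECEWISE.
    if is_blackwell and max_model_len > 131072 and not use_mla:
        return CUDAGraphMode.PIECEWISE
    return mode


# The long-context Blackwell case from the commit title:
assert (
    resolve_cudagraph_mode(
        CUDAGraphMode.FULL_AND_PIECEWISE,
        is_blackwell=True,
        max_model_len=262144,
        use_mla=False,
    )
    is CUDAGraphMode.PIECEWISE
)
```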
