From 322685c77f264edbbe8e5680cf3ce91c113c2008 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 17 Oct 2025 13:30:12 -0400
Subject: [PATCH 1/4] Use PIECEWISE cudagraphs on Blackwell if max_model_len > 131072

Signed-off-by: mgoin
---
 vllm/config/vllm.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 7ee522ea9f0c..bba1773ebde8 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -370,6 +370,44 @@ def __post_init__(self):
             else:
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
+        # if cudagraph_mode has full cudagraphs, we need to check supported
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            if self.model_config is not None:
+                if self.model_config.pooler_config is not None:
+                    logger.warning_once(
+                        "Pooling models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+                # encoder-decoder models do not support full cudagraphs
+                elif self.model_config.is_encoder_decoder:
+                    logger.warning_once(
+                        "Encoder-decoder models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+                elif (
+                    current_platform.is_cuda()
+                    and current_platform.is_device_capability(100)
+                    and self.model_config.max_model_len > 131072
+                ):
+                    logger.warning_once(
+                        "NVIDIA Blackwell TRTLLM attention cannot support "
+                        "max_model_len > 131072 (found "
+                        f"{self.model_config.max_model_len}), causing dynamic "
+                        "dispatching that breaks full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            # decode context parallel do not support full cudagraphs
+            if self.parallel_config.decode_context_parallel_size > 1:
+                logger.warning_once(
+                    "Decode context parallel (DCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+
         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
             logger.info("Cudagraph is disabled under eager mode")

From e83e7af03b299184d47f3689340c64f7a4f96fa5 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 17 Oct 2025 13:34:08 -0400
Subject: [PATCH 2/4] Cleanup

Signed-off-by: mgoin
---
 vllm/config/vllm.py | 38 ++++++++++----------------------------
 1 file changed, 10 insertions(+), 28 deletions(-)

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index bba1773ebde8..92df0f1f9fa3 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -350,36 +350,26 @@ def __post_init__(self):
                 self.compilation_config.cudagraph_mode = (
                     CUDAGraphMode.FULL_AND_PIECEWISE
                 )
-
-                # pooling models and encoder-decoder models
-                # do not support full cudagraphs
-                if self.model_config is not None and (
-                    self.model_config.pooler_config is not None
-                    or self.model_config.is_encoder_decoder
-                ):
-                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-
-                # decode context parallel do not support full cudagraphs now.
-                if self.parallel_config.decode_context_parallel_size > 1:
-                    logger.warning(
-                        "Decode context parallel (DCP) is enabled, which is "
-                        "incompatible with full CUDA graphs. Set "
-                        "cudagraph_mode to PIECEWISE."
-                    )
-                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
             else:
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
-        # if cudagraph_mode has full cudagraphs, we need to check supported
+        # if cudagraph_mode has full cudagraphs, we need to check support
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
-            if self.model_config is not None:
+            # decode context parallel does not support full cudagraphs
+            if self.parallel_config.decode_context_parallel_size > 1:
+                logger.warning_once(
+                    "Decode context parallel (DCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            elif self.model_config is not None:
                 if self.model_config.pooler_config is not None:
                     logger.warning_once(
                         "Pooling models do not support full cudagraphs. "
                         "Overriding cudagraph_mode to PIECEWISE."
                     )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-                # encoder-decoder models do not support full cudagraphs
                 elif self.model_config.is_encoder_decoder:
                     logger.warning_once(
                         "Encoder-decoder models do not support full cudagraphs. "
@@ -399,14 +389,6 @@ def __post_init__(self):
                         "Overriding cudagraph_mode to PIECEWISE."
                     )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-            # decode context parallel do not support full cudagraphs
-            if self.parallel_config.decode_context_parallel_size > 1:
-                logger.warning_once(
-                    "Decode context parallel (DCP) is enabled, which is "
-                    "incompatible with full CUDA graphs. "
-                    "Overriding cudagraph_mode to PIECEWISE."
-                )
-                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
             logger.info("Cudagraph is disabled under eager mode")

From d901ca70d52181a2b9d9db6c8a7b55baf2b86dc1 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 17 Oct 2025 13:39:37 -0400
Subject: [PATCH 3/4] Add comment

Signed-off-by: mgoin
---
 vllm/config/vllm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 92df0f1f9fa3..310886548671 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -381,6 +381,7 @@ def __post_init__(self):
                     and current_platform.is_device_capability(100)
                     and self.model_config.max_model_len > 131072
                 ):
+                    # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
                     logger.warning_once(
                         "NVIDIA Blackwell TRTLLM attention cannot support "
                         "max_model_len > 131072 (found "

From 9a57a2202b4cc6459cd80de44ddf79bbd6b8f883 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 17 Oct 2025 13:57:04 -0400
Subject: [PATCH 4/4] Skip this case if MLA is used

Signed-off-by: mgoin
---
 vllm/config/vllm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 310886548671..fa7310f13b03 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -380,6 +380,7 @@ def __post_init__(self):
                     current_platform.is_cuda()
                     and current_platform.is_device_capability(100)
                     and self.model_config.max_model_len > 131072
+                    and not self.model_config.use_mla
                 ):
                     # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
                     logger.warning_once(
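
For reviewers, the net effect of the series is a single cascade in __post_init__ that downgrades FULL_AND_PIECEWISE to PIECEWISE whenever full CUDA graphs cannot be used: decode context parallel is checked first, then pooling models, encoder-decoder models, and finally the Blackwell (device capability 10.0) TRTLLM long-context case, which PATCH 4/4 skips for MLA models. Below is a minimal standalone sketch of that precedence; the Scenario dataclass, the is_blackwell flag, and the resolve_cudagraph_mode helper are hypothetical stand-ins for vLLM's real config objects, not the actual implementation.

from dataclasses import dataclass
from enum import Enum, auto


class CUDAGraphMode(Enum):
    # Simplified stand-in for vLLM's CUDAGraphMode enum.
    NONE = auto()
    PIECEWISE = auto()
    FULL_AND_PIECEWISE = auto()

    def has_full_cudagraphs(self) -> bool:
        return self is CUDAGraphMode.FULL_AND_PIECEWISE


@dataclass
class Scenario:
    # Hypothetical bundle of the config fields the guard inspects.
    is_pooling_model: bool = False
    is_encoder_decoder: bool = False
    decode_context_parallel_size: int = 1
    is_blackwell: bool = False  # CUDA device capability 10.0
    max_model_len: int = 8192
    use_mla: bool = False


def resolve_cudagraph_mode(mode: CUDAGraphMode, s: Scenario) -> CUDAGraphMode:
    # Downgrade FULL_AND_PIECEWISE to PIECEWISE for setups that cannot use
    # full CUDA graphs, in the order the patch series converges on.
    if not mode.has_full_cudagraphs():
        return mode
    # 1. Decode context parallel (DCP) is incompatible with full CUDA graphs.
    if s.decode_context_parallel_size > 1:
        return CUDAGraphMode.PIECEWISE
    # 2. Pooling and encoder-decoder models do not support full CUDA graphs.
    if s.is_pooling_model or s.is_encoder_decoder:
        return CUDAGraphMode.PIECEWISE
    # 3. On Blackwell, TRTLLM attention with max_model_len > 131072 causes
    #    dynamic dispatching that breaks full CUDA graphs; PATCH 4/4 skips
    #    this override when the model uses MLA.
    if s.is_blackwell and s.max_model_len > 131072 and not s.use_mla:
        return CUDAGraphMode.PIECEWISE
    return mode


if __name__ == "__main__":
    long_ctx = Scenario(is_blackwell=True, max_model_len=262144)
    print(resolve_cudagraph_mode(CUDAGraphMode.FULL_AND_PIECEWISE, long_ctx))
    # -> CUDAGraphMode.PIECEWISE

    long_ctx_mla = Scenario(is_blackwell=True, max_model_len=262144, use_mla=True)
    print(resolve_cudagraph_mode(CUDAGraphMode.FULL_AND_PIECEWISE, long_ctx_mla))
    # -> CUDAGraphMode.FULL_AND_PIECEWISE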