diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 7ee522ea9f0c..fa7310f13b03 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -350,25 +350,47 @@ def __post_init__(self):
                 self.compilation_config.cudagraph_mode = (
                     CUDAGraphMode.FULL_AND_PIECEWISE
                 )
+            else:
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
-            # pooling models and encoder-decoder models
-            # do not support full cudagraphs
-            if self.model_config is not None and (
-                self.model_config.pooler_config is not None
-                or self.model_config.is_encoder_decoder
-            ):
+        # if cudagraph_mode has full cudagraphs, we need to check support
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            # decode context parallel does not support full cudagraphs
+            if self.parallel_config.decode_context_parallel_size > 1:
+                logger.warning_once(
+                    "Decode context parallel (DCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            elif self.model_config is not None:
+                if self.model_config.pooler_config is not None:
+                    logger.warning_once(
+                        "Pooling models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-
-            # decode context parallel do not support full cudagraphs now.
-            if self.parallel_config.decode_context_parallel_size > 1:
-                logger.warning(
-                    "Decode context parallel (DCP) is enabled, which is "
-                    "incompatible with full CUDA graphs. Set "
-                    "cudagraph_mode to PIECEWISE."
+                elif self.model_config.is_encoder_decoder:
+                    logger.warning_once(
+                        "Encoder-decoder models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+                elif (
+                    current_platform.is_cuda()
+                    and current_platform.is_device_capability(100)
+                    and self.model_config.max_model_len > 131072
+                    and not self.model_config.use_mla
+                ):
+                    # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
+                    logger.warning_once(
+                        "NVIDIA Blackwell TRTLLM attention cannot support "
+                        "max_model_len >= 131072 (found "
+                        f"{self.model_config.max_model_len}), causing dynamic "
+                        "dispatching that breaks full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
                     )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-        else:
-            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager: