@@ -350,25 +350,47 @@ def __post_init__(self):
                 self.compilation_config.cudagraph_mode = (
                     CUDAGraphMode.FULL_AND_PIECEWISE
                 )
+            else:
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
-                # pooling models and encoder-decoder models
-                # do not support full cudagraphs
-                if self.model_config is not None and (
-                    self.model_config.pooler_config is not None
-                    or self.model_config.is_encoder_decoder
-                ):
+        # if cudagraph_mode has full cudagraphs, we need to check support
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            # decode context parallel does not support full cudagraphs
+            if self.parallel_config.decode_context_parallel_size > 1:
+                logger.warning_once(
+                    "Decode context parallel (DCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            elif self.model_config is not None:
+                if self.model_config.pooler_config is not None:
+                    logger.warning_once(
+                        "Pooling models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-
-                # decode context parallel do not support full cudagraphs now.
-                if self.parallel_config.decode_context_parallel_size > 1:
-                    logger.warning(
-                        "Decode context parallel (DCP) is enabled, which is "
-                        "incompatible with full CUDA graphs. Set "
-                        "cudagraph_mode to PIECEWISE."
+                elif self.model_config.is_encoder_decoder:
+                    logger.warning_once(
+                        "Encoder-decoder models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
+                    )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+                elif (
+                    current_platform.is_cuda()
+                    and current_platform.is_device_capability(100)
+                    and self.model_config.max_model_len > 131072
+                    and not self.model_config.use_mla
+                ):
+                    # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
+                    logger.warning_once(
+                        "NVIDIA Blackwell TRTLLM attention does not support "
+                        "max_model_len > 131072 (found "
+                        f"{self.model_config.max_model_len}), causing dynamic "
+                        "dispatching that breaks full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
                     )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-            else:
-                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
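
To make the new control flow easier to scan outside the diff, here is a minimal sketch of the override cascade in plain Python. The `CUDAGraphMode` enum and the `resolve_cudagraph_mode` helper below are simplified stand-ins invented for illustration, not vLLM's real classes; only the branch ordering mirrors the diff: the DCP check runs first, then the pooling / encoder-decoder / Blackwell TRTLLM checks, and the first match downgrades full cudagraphs to PIECEWISE.

```python
from enum import Enum


class CUDAGraphMode(Enum):
    # Simplified stand-in for vLLM's CUDAGraphMode (illustration only).
    NONE = 0
    PIECEWISE = 1
    FULL_AND_PIECEWISE = 2

    def has_full_cudagraphs(self) -> bool:
        # In this sketch only FULL_AND_PIECEWISE captures full graphs.
        return self is CUDAGraphMode.FULL_AND_PIECEWISE


def resolve_cudagraph_mode(
    mode: CUDAGraphMode,
    dcp_size: int = 1,
    is_pooling: bool = False,
    is_encoder_decoder: bool = False,
    blackwell_trtllm_long_context: bool = False,
) -> CUDAGraphMode:
    """Hypothetical helper mirroring the diff's elif chain: the first
    incompatible feature downgrades full cudagraphs to PIECEWISE."""
    if not mode.has_full_cudagraphs():
        return mode
    if dcp_size > 1:
        # DCP is checked before any per-model check, as in the diff.
        return CUDAGraphMode.PIECEWISE
    if is_pooling or is_encoder_decoder or blackwell_trtllm_long_context:
        return CUDAGraphMode.PIECEWISE
    return mode


# DCP wins even when a model-level check would also fire.
assert (
    resolve_cudagraph_mode(
        CUDAGraphMode.FULL_AND_PIECEWISE, dcp_size=2, is_pooling=True
    )
    is CUDAGraphMode.PIECEWISE
)
# A plain decoder-only model on a supported platform keeps full graphs.
assert (
    resolve_cudagraph_mode(CUDAGraphMode.FULL_AND_PIECEWISE)
    is CUDAGraphMode.FULL_AND_PIECEWISE
)
```

Note that the elif ordering in the real code only determines which warning fires; every incompatible branch lands on PIECEWISE, which is why the sketch can collapse the model checks into boolean flags.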