diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 1e6e455210c8..d2d6c6c47de0 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -483,21 +483,6 @@ def __post_init__(self):
                     "Overriding cudagraph_mode to PIECEWISE."
                 )
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-            elif (
-                current_platform.is_cuda()
-                and current_platform.is_device_capability(100)
-                and self.model_config.max_model_len > 131072
-                and not self.model_config.use_mla
-            ):
-                # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
-                logger.warning_once(
-                    "NVIDIA Blackwell TRTLLM attention cannot support "
-                    "max_model_len >= 131072 (found "
-                    f"{self.model_config.max_model_len}), causing dynamic "
-                    "dispatching that breaks full cudagraphs. "
-                    "Overriding cudagraph_mode to PIECEWISE."
-                )
-                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 79e5a4c30259..1209d64901bf 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -319,14 +319,12 @@ def use_trtllm_attention(
     # Environment variable not set - use auto-detection
     if is_prefill:
         # Prefill auto-detection
-        use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
+        use_trtllm = kv_cache_dtype == "auto"
         if use_trtllm:
             logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
     else:
         # Decode auto-detection
-        use_trtllm = (
-            num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
-        )
+        use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
        if use_trtllm:
             logger.warning_once("Using TRTLLM decode attention (auto-detected).")
     return use_trtllm
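
For context, a minimal sketch of what the auto-detection in use_trtllm_attention reduces to after this change: max_seq_len no longer gates either path, so only the KV-cache dtype (and, for decode, the token count) decides. The standalone helper name below is hypothetical; the parameter names mirror the hunk above.

# Illustrative sketch only, not part of the diff. The helper name is hypothetical;
# is_prefill, num_tokens, and kv_cache_dtype mirror the parameters shown above.
def auto_detect_trtllm(is_prefill: bool, num_tokens: int, kv_cache_dtype: str) -> bool:
    if is_prefill:
        # Prefill: only the default ("auto") KV-cache dtype is required now.
        return kv_cache_dtype == "auto"
    # Decode: small batches (<= 256 tokens) with the default KV-cache dtype.
    return num_tokens <= 256 and kv_cache_dtype == "auto"

# A decode batch over a >128K-token context now auto-selects TRTLLM attention;
# the removed max_seq_len <= 131072 check would previously have rejected it.
assert auto_detect_trtllm(is_prefill=False, num_tokens=128, kv_cache_dtype="auto")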