Skip to content

Commit 173b356

Browse files
authored
[PERF] Remove TRTLLM Gen attn kernel limitation max_seq_len <=131072 (#28755)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
1 parent 638e419 commit 173b356

File tree

2 files changed

+2
-19
lines changed

2 files changed

+2
-19
lines changed

vllm/config/vllm.py

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -483,21 +483,6 @@ def __post_init__(self):
483483
"Overriding cudagraph_mode to PIECEWISE."
484484
)
485485
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
486-
elif (
487-
current_platform.is_cuda()
488-
and current_platform.is_device_capability(100)
489-
and self.model_config.max_model_len > 131072
490-
and not self.model_config.use_mla
491-
):
492-
# Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
493-
logger.warning_once(
494-
"NVIDIA Blackwell TRTLLM attention cannot support "
495-
"max_model_len >= 131072 (found "
496-
f"{self.model_config.max_model_len}), causing dynamic "
497-
"dispatching that breaks full cudagraphs. "
498-
"Overriding cudagraph_mode to PIECEWISE."
499-
)
500-
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
501486

502487
# disable cudagraph when enforce eager execution
503488
if self.model_config is not None and self.model_config.enforce_eager:

vllm/utils/flashinfer.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -319,14 +319,12 @@ def use_trtllm_attention(
319319
# Environment variable not set - use auto-detection
320320
if is_prefill:
321321
# Prefill auto-detection
322-
use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
322+
use_trtllm = kv_cache_dtype == "auto"
323323
if use_trtllm:
324324
logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
325325
else:
326326
# Decode auto-detection
327-
use_trtllm = (
328-
num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
329-
)
327+
use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
330328
if use_trtllm:
331329
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
332330
return use_trtllm

0 commit comments

Comments (0)