Skip to content

Commit 8da2f28

Browse files
authored
[ROCm][BugFix]Fix get_cu_count in rocm_aiter_fa.py (#28618)
Signed-off-by: ganyi <ygan@amd.com>
1 parent 86d15bf commit 8da2f28

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from vllm.logger import init_logger
1919
from vllm.platforms import current_platform
2020
from vllm.utils.math_utils import cdiv
21+
from vllm.utils.platform_utils import get_cu_count
2122
from vllm.v1.attention.backends.utils import (
2223
AttentionCGSupport,
2324
AttentionMetadataBuilder,
@@ -38,7 +39,7 @@ def block_size(x, head_dim):
3839
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
3940

4041
def num_programs(total_tokens):
41-
return min(total_tokens, current_platform.get_cu_count())
42+
return min(total_tokens, get_cu_count())
4243

4344
@triton.jit
4445
def cp_mha_gather_cache_kernel(

0 commit comments

Comments
 (0)