Skip to content

Commit dd70437

Browse files
authored
Remove cuda hard-code in compute_causal_conv1d_metadata (vllm-project#25555)
Signed-off-by: Icey <1790571317@qq.com>
1 parent 99b3a50 commit dd70437

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

vllm/v1/attention/backends/utils.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -947,6 +947,7 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
947947
nums_dict = {} # type: ignore
948948
batch_ptr = None
949949
token_chunk_offset_ptr = None
950+
device = query_start_loc_p.device
950951
for BLOCK_M in [8]: # cover all BLOCK_M values
951952
nums = -(-seqlens // BLOCK_M)
952953
nums_dict[BLOCK_M] = {}
@@ -968,11 +969,11 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
968969
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
969970
PAD_SLOT_ID,
970971
dtype=torch.int32,
971-
device='cuda')
972+
device=device)
972973
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
973974
PAD_SLOT_ID,
974975
dtype=torch.int32,
975-
device='cuda')
976+
device=device)
976977
else:
977978
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
978979
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)

0 commit comments

Comments (0)