File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
vllm/v1/attention/backends Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -947,6 +947,7 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
947947 nums_dict = {} # type: ignore
948948 batch_ptr = None
949949 token_chunk_offset_ptr = None
950+ device = query_start_loc_p.device
950951 for BLOCK_M in [8]:  # cover all BLOCK_M values
951952 nums = -(-seqlens // BLOCK_M)
952953 nums_dict[BLOCK_M] = {}
@@ -968,11 +969,11 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
968969 batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
969970 PAD_SLOT_ID,
970971 dtype=torch.int32,
971- device='cuda')
972+ device=device)
972973 token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
973974 PAD_SLOT_ID,
974975 dtype=torch.int32,
975- device='cuda')
976+ device=device)
976977 else:
977978 if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
978979 batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
You can’t perform that action at this time.
0 commit comments