 
 #include <algorithm>
 #include <cassert>
-#include <map>
-#include <vector>
+#include <cfloat>  // FLT_MIN
 
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
@@ -209,6 +208,20 @@ void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
 
 namespace vllm {
 
+// Used to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
@@ -224,58 +237,50 @@ __global__ void reshape_and_cache_kernel(
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
-    // Padding token that should be ignored.
     return;
   }
 
   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
+  const int h_block_count = head_size / x;  // number of x-wide chunks per head
 
-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
-
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int x_idx = head_offset / x;
-    const int x_offset = head_offset % x;
-
-    const int64_t tgt_key_idx =
-        block_idx * num_heads * (head_size / x) * block_size * x +
-        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
-        block_offset * x + x_offset;
-    const int64_t tgt_value_idx =
-        block_idx * num_heads * head_size * block_size +
-        head_idx * head_size * block_size + head_offset * block_size +
-        block_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_idx] = tgt_key;
-      value_cache[tgt_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
-    }
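+  // One thread handles one (head, x-wide chunk) pair for this token;
+  // threads beyond num_heads * h_block_count have no work and exit early.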
+  const int h_block_idx = threadIdx.x;
+  if (h_block_idx >= num_heads * h_block_count) {
+    return;
   }
-}
 
-// Used by vectorization_utils to copy/convert one element
-template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
-struct CopyWithScaleOp {
-  float scale;
+  const int head_idx = h_block_idx / h_block_count;
+  const int h_block = h_block_idx % h_block_count;
 
-  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      dst = static_cast<OutT>(src);
-    } else {
-      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
-    }
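+  // Key cache layout: [num_blocks, num_heads, head_size/x, block_size, x];
+  // value cache layout: [num_blocks, num_heads, head_size, block_size].
+  // Compute the source pointers and cache offsets for this token's chunk.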
+  const scalar_t* __restrict__ key_src =
+      key + token_idx * key_stride + head_idx * head_size + h_block * x;
+  const int64_t src_value_start =
+      token_idx * value_stride + head_idx * head_size + h_block * x;
+
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * num_heads * h_block_count * block_size * x +
+      head_idx * h_block_count * block_size * x + h_block * block_size * x +
+      block_offset * x;
+  const int64_t tgt_value_start =
+      block_idx * num_heads * h_block_count * x * block_size +
+      head_idx * h_block_count * x * block_size + h_block * x * block_size +
+      block_offset;
+
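+  // Use 16-byte vectors: 8 elements for 16-bit scalar types, 4 for 32-bit.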
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
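+  // The scale is only read on the FP8 path; 0.f is a harmless placeholder
+  // when kv_dt == kAuto.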
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+
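+  // The x key elements of this chunk are contiguous in both source and
+  // destination, so they can be copied with vectorized loads/stores.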
+  vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op);
+
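+  // In the value cache, consecutive elements of a head are block_size apart,
+  // so the x value elements are written one at a time with that stride.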
+  const scalar_t* __restrict__ value_src = value + src_value_start;
+  cache_t* __restrict__ value_dst = value_cache + tgt_value_start;
+#pragma unroll
+  for (int i = 0; i < x; i++) {
+    v_op(value_dst[i * block_size], value_src[i]);
   }
-};
+}
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
@@ -601,9 +606,10 @@ void reshape_and_cache(
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
+  int head_div_x = head_size / x;
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
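+  // Launch one thread per (head, x-wide chunk) instead of one per element.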
+  dim3 block(std::min(num_heads * head_div_x, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 