#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <cute/tensor.hpp>

#include "../dispatch.h"

// Adapted from
// https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/moe/moe_align_kernel.cu

namespace llm::kernel::moe {

namespace {
constexpr int32_t WARP_SIZE = 32;
| 15 | + |
| 16 | +// map p_idx to f_idx |
| 17 | +template <typename scalar_t> |
| 18 | +__global__ void row_id_map_kernel( |
| 19 | + const scalar_t* __restrict__ topk_ids, // [m, topk] |
| 20 | + scalar_t* __restrict__ sorted_token_idxes, // [n_padded_tokens+] |
| 21 | + scalar_t* __restrict__ cu_sum, // [n_experts+1] |
| 22 | + int n_tokens // m * topk |
| 23 | +) { |
| 24 | + const int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 25 | + const int stride = blockDim.x * gridDim.x; |
| 26 | + |
| 27 | + for (int i = tid; i < n_tokens; i += stride) { |
| 28 | + const auto e_idx = topk_ids[i]; |
| 29 | + // N.B. token ids for each expert is not sorted |
| 30 | + const auto p_idx = atomicAdd(&cu_sum[e_idx], 1); |
| 31 | + sorted_token_idxes[p_idx] = i; |
| 32 | + } |
| 33 | +} |
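// Illustrative trace (assumed inputs, not taken from the code): suppose
// cusum_kernel below has produced cu_sum = [0, 4, 8] for n_experts = 2,
// block_size = 4, and topk_ids = [0, 1, 0, 0]. A single thread walking the
// four entries would scatter
//   i=0 (expert 0): p_idx = 0 -> sorted_token_idxes[0] = 0
//   i=1 (expert 1): p_idx = 4 -> sorted_token_idxes[4] = 1
//   i=2 (expert 0): p_idx = 1 -> sorted_token_idxes[1] = 2
//   i=3 (expert 0): p_idx = 2 -> sorted_token_idxes[2] = 3
// Note that cu_sum is consumed: each entry ends up advanced by that expert's
// token count.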

template <typename scalar_t>
__global__ void cusum_kernel(
    const scalar_t* __restrict__ topk_ids,   // [m, topk]
    scalar_t* __restrict__ expert_ids,       // [n_blocks+]
    scalar_t* __restrict__ n_padded_tokens,  // [1]
    int n_experts,
    int block_size,
    size_t n_tokens,                         // m * topk
    scalar_t* __restrict__ cu_sum            // [n_experts+1]
) {
  using namespace cute;

  const int tid = threadIdx.x;
  const int stride = blockDim.x;

  const auto curr_expert = threadIdx.x;

  // token count for each expert: [n_experts] (shared memory sized by launcher)
  extern __shared__ int token_counts[];

  // init token counts for each expert
  if (curr_expert < n_experts) {
    token_counts[curr_expert] = 0;
  }

  __syncthreads();

  // process this thread's shard of tokens
  for (int i = tid; i < n_tokens; i += stride) {
    // accumulate token counts for each expert
    atomicAdd(&token_counts[topk_ids[i]], 1);
  }

  __syncthreads();

  if (tid == 0) {
    // cumulative sum of per-expert counts, each rounded up to block_size
    cu_sum[0] = 0;
    for (int e_idx = 1; e_idx <= n_experts; ++e_idx) {
      cu_sum[e_idx] = cu_sum[e_idx - 1] +
                      cute::round_up(token_counts[e_idx - 1], block_size);
    }
    *n_padded_tokens = cu_sum[n_experts];
  }

  __syncthreads();

  // fill in the expert id for each token block
  if (curr_expert < n_experts) {
    for (int i = cu_sum[curr_expert]; i < cu_sum[curr_expert + 1];
         i += block_size) {
      expert_ids[i / block_size] = curr_expert;
    }
  }
}
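// Worked example (illustrative numbers, assumed inputs): with n_experts = 2,
// block_size = 4, and topk_ids = [0, 1, 0, 0], the per-expert counts are
// {expert 0: 3, expert 1: 1}, so
//   cu_sum          = [0, 4, 8]   (each count rounded up to block_size)
//   n_padded_tokens = 8
//   expert_ids      = [0, 1]      (one entry per block of block_size slots)
// Slots [0, 4) of the padded layout belong to expert 0, slots [4, 8) to
// expert 1; slots beyond an expert's token count are padding.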

template <typename scalar_t>
__global__ void align_block_kernel(
    const scalar_t* __restrict__ topk_ids,      // [m, topk]
    scalar_t* __restrict__ sorted_token_idxes,  // [n_padded_tokens+]
    scalar_t* __restrict__ expert_ids,          // [n_blocks+]
    scalar_t* __restrict__ n_padded_tokens,     // [1]
    int n_experts,
    int block_size,
    int n_tokens                                // m * topk
) {
  using namespace cute;

  // which shard and expert this thread takes care of
  const int n_shards = blockDim.x;
  const int curr_shard = threadIdx.x;
  const int curr_expert = threadIdx.x;

  const int tid = threadIdx.x;
  const int stride = blockDim.x;

  extern __shared__ int s_mem[];

  // [n_experts+1]
  Tensor cu_sum = make_tensor(make_smem_ptr(reinterpret_cast<int*>(s_mem)),
                              make_layout(make_shape(n_experts + 1)));
  // [n_shards+1, n_experts]
  // token_counts(0, _) = 0 is used to facilitate the prefix sum
  Tensor token_counts =
      make_tensor(make_smem_ptr(reinterpret_cast<int*>(s_mem + n_experts + 1)),
                  make_layout(make_shape(n_shards + 1, n_experts)));

  // init token counts for each expert in this shard
  for (int e_idx = 0; e_idx < n_experts; ++e_idx) {
    token_counts(curr_shard + 1, e_idx) = 0;
  }

  // count tokens per expert within this thread's shard
  for (int i = tid; i < n_tokens; i += stride) {
    ++token_counts(curr_shard + 1, topk_ids[i]);
  }

  __syncthreads();

  // calculate the prefix sum over shards for each expert;
  // the total number of tokens per expert ends up in token_counts(n_shards, _)
  if (curr_expert < n_experts) {
    token_counts(0, curr_expert) = 0;
    for (int i = 1; i <= n_shards; ++i) {
      token_counts(i, curr_expert) += token_counts(i - 1, curr_expert);
    }
  }

  __syncthreads();

  if (tid == 0) {
    // calculate the cumulative sum for each expert
    cu_sum[0] = 0;
    for (int e_idx = 1; e_idx <= n_experts; ++e_idx) {
      cu_sum[e_idx] =
          cu_sum[e_idx - 1] +
          cute::round_up(token_counts(n_shards, e_idx - 1), block_size);
    }
    *n_padded_tokens = cu_sum[n_experts];
  }

  __syncthreads();

  // each thread fills the expert id for each token block
  if (curr_expert < n_experts) {
    for (int i = cu_sum[curr_expert]; i < cu_sum[curr_expert + 1];
         i += block_size) {
      expert_ids[i / block_size] = curr_expert;
    }
  }

  // scatter: expert segment start (cu_sum) + offset of earlier shards
  // (token_counts(curr_shard, e_idx)) gives this token's padded slot
  for (int i = tid; i < n_tokens; i += stride) {
    const auto e_idx = topk_ids[i];
    const auto idx = token_counts(curr_shard, e_idx)++;
    sorted_token_idxes[cu_sum[e_idx] + idx] = i;
  }
}
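// Shared-memory sizing sketch (illustrative; mirrors the launcher below):
// align_block_kernel needs (n_experts + 1) ints for cu_sum plus
// (n_shards + 1) * n_experts ints for token_counts. For example, with
// n_experts = 4 and threads = max(n_experts, WARP_SIZE) = 32 shards:
//   smem_size = ((32 + 1) * 4 + (4 + 1)) * sizeof(int) = 548 bytes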

// reduce along topk dimension, assuming contiguous memory
template <typename scalar_t, int TOPK>
__global__ void topk_sum_kernel(
    scalar_t* __restrict__ out,          // [n_tokens, dim]
    const scalar_t* __restrict__ input,  // [n_tokens, topk, dim]
    int64_t dim) {
  // one block per token
  const int64_t t_idx = blockIdx.x;
  for (int64_t i = threadIdx.x; i < dim; i += blockDim.x) {
    scalar_t sum = 0.0;
    CUTE_UNROLL
    for (int k = 0; k < TOPK; ++k) {
      sum += input[(t_idx * TOPK * dim) + (k * dim) + i];
    }
    out[(t_idx * dim) + i] = sum;
  }
}

}  // namespace

void permute_align_block(
    torch::Tensor topk_ids,            // [n_tokens, topk]
    int64_t n_experts,
    int64_t block_size,
    torch::Tensor sorted_token_idxes,  // [n_padded_permuted_tokens+]
    torch::Tensor experts_ids,         // [n_blocks+]
    torch::Tensor n_padded_tokens,     // [1]
    torch::Tensor cu_sum               // [n_experts+1]
) {
  const auto n_flatten_tokens = topk_ids.numel();
  auto* stream = at::cuda::getCurrentCUDAStream().stream();
  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "align_block_kernel", [&] {
    if (n_flatten_tokens <= 1024 && n_experts <= 64) {
      // small problem: a single block counts, pads, and scatters in one pass
      const int threads = std::max<int>(n_experts, WARP_SIZE);
      const int smem_size =
          ((threads + 1) * n_experts + (n_experts + 1)) * sizeof(int);

      align_block_kernel<scalar_t><<<1, threads, smem_size, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_idxes.data_ptr<scalar_t>(),
          experts_ids.data_ptr<scalar_t>(),
          n_padded_tokens.data_ptr<scalar_t>(),
          n_experts,
          block_size,
          n_flatten_tokens);
    } else {
      // each thread handles one expert; assumes n_experts <= 1024
      size_t smem_size = 1024 * sizeof(int);
      cusum_kernel<scalar_t>
          <<<1, 1024, smem_size, stream>>>(topk_ids.data_ptr<scalar_t>(),
                                           experts_ids.data_ptr<scalar_t>(),
                                           n_padded_tokens.data_ptr<scalar_t>(),
                                           n_experts,
                                           block_size,
                                           n_flatten_tokens,
                                           cu_sum.data_ptr<scalar_t>());

      constexpr int threads = 256;
      int n_blocks = cute::ceil_div(n_flatten_tokens, threads);
      n_blocks = std::min(n_blocks, 65535);
      row_id_map_kernel<scalar_t><<<n_blocks, threads, 0, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_idxes.data_ptr<scalar_t>(),
          cu_sum.data_ptr<scalar_t>(),
          n_flatten_tokens);
    }
  });
}
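// Usage sketch (illustrative only; buffer names and the padding bound are
// assumptions of this comment, not part of the API): callers pre-allocate the
// outputs, typically with an upper bound of n_experts * (block_size - 1)
// extra slots for padding.
//
//   const int64_t n_flat = topk_ids.numel();
//   const int64_t max_padded = n_flat + n_experts * (block_size - 1);
//   auto opts = topk_ids.options();
//   auto sorted_token_idxes = torch::empty({max_padded}, opts);
//   auto experts_ids =
//       torch::empty({(max_padded + block_size - 1) / block_size}, opts);
//   auto n_padded_tokens = torch::zeros({1}, opts);
//   auto cu_sum = torch::zeros({n_experts + 1}, opts);
//   permute_align_block(topk_ids, n_experts, block_size, sorted_token_idxes,
//                       experts_ids, n_padded_tokens, cu_sum);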

void sum_out(const torch::Tensor& input,  // [n_tokens, topk, dim]
             torch::Tensor& output)       // [n_tokens, dim]
{
  const auto n_tokens = input.size(0);
  const auto topk = input.size(1);
  const auto dim = input.size(2);

  // one block per token
  const auto threads = std::min<int>(dim, 1024);
  auto* stream = at::cuda::getCurrentCUDAStream().stream();

#define DISPATCH_TOPK_SUM_KERNEL_CASE(TOPK)                                    \
  case TOPK: {                                                                 \
    DISPATCH_FLOATING_TYPES(input.scalar_type(), "sum_kernel", [&] {           \
      topk_sum_kernel<scalar_t, TOPK><<<n_tokens, threads, 0, stream>>>(       \
          output.data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), dim); \
    });                                                                        \
    break;                                                                     \
  }

  switch (topk) {
    DISPATCH_TOPK_SUM_KERNEL_CASE(2);
    DISPATCH_TOPK_SUM_KERNEL_CASE(3);
    DISPATCH_TOPK_SUM_KERNEL_CASE(4);
    DISPATCH_TOPK_SUM_KERNEL_CASE(8);
    default:
      // use torch::sum_out for other cases
      torch::sum_out(output, input, /*dim=*/1);
      break;
  }
}
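// Usage sketch (illustrative; tensor names are assumptions of this comment):
// combine the weighted per-expert outputs back into one row per token after
// un-permutation.
//
//   // gathered: [n_tokens, topk, dim]
//   auto out = torch::empty({gathered.size(0), gathered.size(2)},
//                           gathered.options());
//   sum_out(gathered, out);  // out[t, :] = sum over k of gathered[t, k, :]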

}  // namespace llm::kernel::moe