
Commit 6930085

feat: MxInt4 x Bf16 TRT-LLM Gen MoE support (#2159)

## πŸ“Œ Description

Add the MxInt4 x BF16 TRT-LLM Gen MoE path.

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * MXInt4 MoE inference path with public API and test coverage; MXInt4 + BF16 supported end-to-end.
  * Exposed the new MXInt4 op and helper in the package exports.
* **Refactor**
  * Block-scale/interleave routines generalized to support uint8 and bfloat16 inputs and outputs.
  * GEMM/BatchedGemm configs now include an element-wise activation option and are arch-aware (CUDA arch).
* **Tests**
  * Added MXInt4 quantization and runtime tests for MoE.
* **Chores**
  * Updated packaged artifact path/checksum.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>

1 parent 4db4ac0 · commit 6930085

File tree

21 files changed: +1167 −172 lines changed

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 19 additions & 8 deletions
@@ -240,13 +240,14 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
   }
 }
 
+template <typename T>
 __global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded,
-                                              int numCols, int numColsPadded, uint8_t const* SFIn,
-                                              uint8_t* SFOutput) {
+                                              int numCols, int numColsPadded, T const* SFIn,
+                                              T* SFOutput) {
   for (int rowIdx = blockIdx.x; rowIdx < numRowsPadded; rowIdx += gridDim.x) {
     for (int batchIdx = 0; batchIdx < numBatches; batchIdx++) {
       for (int colIdx = threadIdx.x; colIdx < numColsPadded; colIdx += blockDim.x) {
-        uint8_t sf = 0;
+        T sf = 0;
         if (rowIdx < numRows && colIdx < numCols) {
           int64_t inOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
           sf = SFIn[inOffset];
@@ -287,19 +288,29 @@ __global__ void block_scale_interleave_reverse_kernel(int numBatches, int numRow
 }
 
 // This is intended for weight loading, so m and n are large, b <= 256
-void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
-                                uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount,
-                                cudaStream_t stream) {
+template <typename T>
+void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, T const* SFIn,
+                                T* SFOutput, int multiProcessorCount, cudaStream_t stream) {
   // Each thread reads 1 int8 value
   dim3 block(std::min(n_padded, 1024));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = std::max(1u, 4096u / block.x);
   dim3 grid(std::min(m_padded, multiProcessorCount * numBlocksPerSM));
 
-  block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn,
-                                                            SFOutput);
+  block_scale_interleave_kernel<T>
+      <<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn, SFOutput);
 }
 
+// Explicit template instantiations for the types used by other compilation units
+template void invokeBlockScaleInterleave<uint8_t>(int b, int m, int m_padded, int n, int n_padded,
+                                                  uint8_t const* SFIn, uint8_t* SFOutput,
+                                                  int multiProcessorCount, cudaStream_t stream);
+template void invokeBlockScaleInterleave<__nv_bfloat16>(int b, int m, int m_padded, int n,
+                                                        int n_padded, __nv_bfloat16 const* SFIn,
+                                                        __nv_bfloat16* SFOutput,
+                                                        int multiProcessorCount,
+                                                        cudaStream_t stream);
+
 // This is intended for weight loading, so m and n are large, b <= 256
 void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput,
                                        int multiProcessorCount, cudaStream_t stream) {
csrc/nv_internal/tensorrt_llm/kernels/quantization.h

Lines changed: 3 additions & 3 deletions
@@ -67,9 +67,9 @@ void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* i
                                        void* input_global_scale, void* mask, bool use_silu_and_mul,
                                        int m_topk, int k, int n_experts, cudaStream_t stream);
 
-void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
-                                uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount,
-                                cudaStream_t stream = 0);
+template <typename T>
+void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, T const* SFIn,
+                                T* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
 
 void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput,
                                        int multiProcessorCount, cudaStream_t stream = 0);

csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp

Lines changed: 58 additions & 21 deletions
@@ -137,6 +137,41 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn,
   }
 }
 
+template <typename T>
+void blockScaleInterleaveHost(TensorView blockScale, TensorView interleavedBlockScale) {
+  auto blockScaleShape = blockScale.sizes();
+  auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
+  auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0];
+  auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];
+
+  auto expert_out_size = tensorrt_llm::computeSwizzledLayoutSFSize(rows, cols);
+  auto rows_padded = PadUpFn(rows, 128);
+  auto cols_padded = PadUpFn(cols, 4);
+
+  for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
+    T* interleavedBlockScalePtr =
+        static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
+    for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
+      auto globalRowIdx = eIdx * rows + rIdx;
+      T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
+      for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
+        T sf_ori = 0;
+        if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
+          sf_ori = blockScalePtr[cIdx];
+        }
+        int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
+                                      tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
+        interleavedBlockScalePtr[sf_index] = sf_ori;
+      }
+    }
+  }
+}
+
+template void blockScaleInterleaveHost<uint8_t>(TensorView blockScale,
+                                                TensorView interleavedBlockScale);
+template void blockScaleInterleaveHost<__nv_bfloat16>(TensorView blockScale,
+                                                      TensorView interleavedBlockScale);
+
 // Interleave (and possibly pad) the weights block scaling factor.
 // blockScale: [num_experts, rows, cols] or [rows, cols]
 // Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4)
@@ -148,7 +183,8 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal
     CHECK_CPU(blockScale);
   }
   CHECK_CONTIGUOUS(blockScale);
-  CHECK_INPUT_TYPE(blockScale, dl_uint8);
+  TVM_FFI_ICHECK(blockScale.dtype() == dl_uint8 || blockScale.dtype() == dl_bfloat16)
+      << "Block Scale must be uint8 or bfloat16.";
   auto blockScaleShape = blockScale.sizes();
   TVM_FFI_ICHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3)
       << "Block Scale should be 2D or 3D tensor.";
@@ -166,27 +202,28 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal
     const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount();
     const cudaStream_t stream = get_stream(blockScale.device());
 
-    tensorrt_llm::kernels::invokeBlockScaleInterleave(
-        num_experts, rows, rows_padded, cols, cols_padded,
-        static_cast<uint8_t*>(blockScale.data_ptr()),
-        static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
+    if (blockScale.dtype() == dl_uint8) {
+      tensorrt_llm::kernels::invokeBlockScaleInterleave(
+          num_experts, rows, rows_padded, cols, cols_padded,
+          static_cast<uint8_t*>(blockScale.data_ptr()),
+          static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
+    } else if (blockScale.dtype() == dl_bfloat16) {
+      tensorrt_llm::kernels::invokeBlockScaleInterleave(
+          num_experts, rows, rows_padded, cols, cols_padded,
+          static_cast<__nv_bfloat16*>(blockScale.data_ptr()),
+          static_cast<__nv_bfloat16*>(interleavedBlockScale.data_ptr()), smCount, stream);
+    } else {
+      TVM_FFI_LOG_AND_THROW(NotImplementedError)
+          << "block_scale_interleave only supports uint8 and bfloat16.";
+    }
   } else {
-    for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
-      uint8_t* interleavedBlockScalePtr =
-          static_cast<uint8_t*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
-      for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
-        auto globalRowIdx = eIdx * rows + rIdx;
-        uint8_t* blockScalePtr = static_cast<uint8_t*>(blockScale.data_ptr()) + globalRowIdx * cols;
-        for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
-          uint8_t sf_ori = 0;
-          if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
-            sf_ori = blockScalePtr[cIdx];
-          }
-          int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
-                                        tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
-          interleavedBlockScalePtr[sf_index] = sf_ori;
-        }
-      }
+    if (blockScale.dtype() == dl_uint8) {
+      blockScaleInterleaveHost<uint8_t>(blockScale, interleavedBlockScale);
+    } else if (blockScale.dtype() == dl_bfloat16) {
+      blockScaleInterleaveHost<__nv_bfloat16>(blockScale, interleavedBlockScale);
+    } else {
+      TVM_FFI_LOG_AND_THROW(NotImplementedError)
+          << "blockScaleInterleaveHost only supports uint8 and bfloat16.";
     }
   }
 }
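Both the CUDA path and the new `blockScaleInterleaveHost<T>` fallback write into a buffer sized for the SWIZZLED_128x4 layout, i.e. `num_experts * pad_up(rows, 128) * pad_up(cols, 4)` elements. A small standalone sketch of that sizing (the `padUp` helper and the example shape are assumptions used only for illustration):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Round x up to the next multiple of m, matching the padding used by the interleave routines.
static int padUp(int x, int m) { return ((x + m - 1) / m) * m; }

int main() {
  int numExperts = 8, rows = 70, cols = 6;  // hypothetical per-expert block-scale shape

  // Each expert's interleaved block scale occupies pad_up(rows, 128) * pad_up(cols, 4)
  // elements in the SWIZZLED_128x4 layout: here 128 * 8 = 1024.
  int expertOutSize = padUp(rows, 128) * padUp(cols, 4);
  std::vector<uint8_t> interleaved(static_cast<size_t>(numExperts) * expertOutSize, 0);

  std::printf("per-expert swizzled size: %d, total elements: %zu\n", expertOutSize,
              interleaved.size());
  return 0;
}
```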
