Skip to content

Commit 0c22d3c

Browse files
committed
Address comment
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
1 parent 326fa96 commit 0c22d3c

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

csrc/quantization/fp4/nvfp4_quant_kernels.cu

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@
3131

3232
namespace vllm {
3333

34-
#define round_up(x, y) ((x + y - 1) / y * y)
34+
template<typename Int>
35+
__host__ __device__ inline Int round_up(Int x, Int y)
36+
{
37+
static_assert(std::is_integral_v<Int>, "round_up argument must be integral type");
38+
return (x + y - 1) / y * y;
39+
}
40+
3541
// Use UE4M3 by default.
3642
template <class Type, bool UE8M0_SF = false>
3743
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
@@ -43,14 +49,14 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
4349
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
4450
"Vec size is not matched.");
4551

46-
int sf_m = round_up(numRows, 128);
52+
int sf_m = round_up<int>(numRows, 128);
4753
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
48-
int sf_n_uint32 = round_up(sf_n_unpadded, 4) / 4;
54+
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
4955
for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
5056
// Each thread writes 4 uint32_t elements.
51-
for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_uint32;
57+
for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
5258
col += blockDim.x * 4) {
53-
SFout[row * sf_n_uint32 + col] = 0x00000000;
59+
SFout[row * sf_n_int + col] = 0x00;
5460
}
5561
}
5662

0 commit comments

Comments
 (0)