3131
3232namespace vllm {
3333
// Round `x` up to the nearest multiple of `y`.
//
// Replaces the former function-style macro, which evaluated its arguments
// more than once and performed no type checking. `constexpr` implies
// `inline`, so linkage is unchanged, and the result is now usable in
// constant expressions.
//
// Preconditions (not checked): y > 0, and x + y - 1 must not overflow Int.
template <typename Int>
__host__ __device__ constexpr Int round_up(Int x, Int y) {
  static_assert(std::is_integral_v<Int>,
                "round_up argument must be integral type");
  return (x + y - 1) / y * y;
}
3541// Use UE4M3 by default.
3642template <class Type , bool UE8M0_SF = false >
3743__global__ void __launch_bounds__ (512 , VLLM_BLOCKS_PER_SM(512 ))
@@ -43,14 +49,14 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
4349 static_assert (sizeof (PackedVec) == sizeof (Type) * CVT_FP4_ELTS_PER_THREAD,
4450 " Vec size is not matched." );
4551
46- int sf_m = round_up (numRows, 128 );
52+ int sf_m = round_up< int > (numRows, 128 );
4753 int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
48- int sf_n_uint32 = round_up (sf_n_unpadded, 4 ) / 4 ;
54+ int sf_n_int = round_up< int > (sf_n_unpadded, 4 ) / 4 ;
4955 for (int row = numRows + blockIdx .x ; row < sf_m; row += gridDim .x ) {
5056 // Each thread writes 4 uint32_t elements.
51- for (int col = sf_n_unpadded + threadIdx .x * 4 ; col < sf_n_uint32 ;
57+ for (int col = sf_n_unpadded + threadIdx .x * 4 ; col < sf_n_int ;
5258 col += blockDim .x * 4 ) {
53- SFout[row * sf_n_uint32 + col] = 0x00000000 ;
59+ SFout[row * sf_n_int + col] = 0x00 ;
5460 }
5561 }
5662
0 commit comments