Skip to content

Commit 1064bed

Browse files
committed
Add parallelization and fix codex review comment
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
1 parent 34d8036 commit 1064bed

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

csrc/quantization/fp4/nvfp4_quant_kernels.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
4343
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
4444
"Vec size is not matched.");
4545

46-
if (blockIdx.x == 0 && threadIdx.x == 0) {
47-
int sf_m = round_up(numRows, 128);
48-
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
49-
int sf_n = round_up(sf_n_unpadded, 4) / 4;
50-
for (int row = numRows; row < sf_m; row += 1) {
51-
for (int col = sf_n_unpadded; col < sf_n; col += 1) {
52-
SFout[row * sf_n + col] = 0x00;
53-
}
46+
int sf_m = round_up(numRows, 128);
47+
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
48+
int sf_n_uint32 = round_up(sf_n_unpadded, 4) / 4;
49+
for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
50+
for (int col = (sf_n_unpadded + 3) / 4 + threadIdx.x;
51+
col < sf_n_uint32;
52+
col += blockDim.x) {
53+
SFout[row * sf_n_uint32 + col] = 0x00000000;
5454
}
5555
}
5656

vllm/envs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,7 @@ def get_vllm_port() -> int | None:
11761176
# - "latency":
11771177
# Uses TensorRT-LLM kernels optimized for low-latency inference.
11781178
"VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
1179-
"VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"]
1179+
"VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
11801180
),
11811181
# Control the maximum number of tokens per expert supported by the
11821182
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for

0 commit comments

Comments
 (0)