
Commit 53bea8c

Misc fixes for FP4 MOE and Quant
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
1 parent 3758757 commit 53bea8c

6 files changed: +36 −19 lines changed


csrc/quantization/fp4/nvfp4_quant_kernels.cu

Lines changed: 14 additions & 2 deletions
@@ -31,6 +31,7 @@
 
 namespace vllm {
 
+#define round_up(x, y) ((x + y - 1) / y * y)
 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
 __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
@@ -42,10 +43,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                 "Vec size is not matched.");
 
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    int sf_m = round_up(numRows, 128);
+    int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
+    int sf_n = round_up(sf_n_unpadded, 4) / 4;
+    for (int row = numRows; row < sf_m; row += 1) {
+      for (int col = sf_n_unpadded; col < sf_n; col += 1) {
+        SFout[row * sf_n + col] = 0x00;
+      }
+    }
+  }
+
   // Get the global scaling factor, which will be applied to the SF.
   // Note SFScale is the same as next GEMM's alpha, which is
   // (448.f / (Alpha_A / 6.f)).
-  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+  float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
 
   // Input tensor row/col loops.
   for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
@@ -64,7 +76,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
                                       rowIdx, colIdx, numCols, SFout);
 
       out_pos =
-          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
     }
   }
 }
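The new thread-0 loop zero-fills the rows of the swizzled scale-factor buffer that exist only because of padding (rows are rounded up to a multiple of 128), so the host no longer has to pre-zero the buffer. As a rough illustration, here is a minimal Python sketch (not part of the commit) of the same round_up arithmetic; it assumes a scale-factor group size (CVT_FP4_SF_VEC_SIZE) of 16, which this diff does not show.

# Hedged sketch, not part of the commit: host-side view of the padding
# arithmetic used by the kernel's new zero-fill loop.
# ASSUMPTION: CVT_FP4_SF_VEC_SIZE == 16 (one FP8 scale per 16 FP4 values).
def round_up(x: int, y: int) -> int:
    return (x + y - 1) // y * y

num_rows, num_cols = 100, 256
sf_vec_size = 16

sf_m = round_up(num_rows, 128)           # scale rows, padded to a multiple of 128
sf_n_unpadded = num_cols // sf_vec_size  # scale columns actually written by the quant loop
print(sf_m, sf_n_unpadded)               # 128 16 -> rows 100..127 are padding and get zeroed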

tests/kernels/quantization/test_nvfp4_quant.py

Lines changed: 4 additions & 2 deletions
@@ -168,9 +168,11 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
     out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
 
     out, out_scale = ops.scaled_fp4_quant(x, global_scale)
-
     scale_ans = recover_swizzled_scales(out_scale, m, n)
     out_ans = cast_from_fp4(out, m, n)
-
+    print(f"out_ans: {out_ans}")
+    print(f"out_ref: {out_ref}")
+    print(f"scale_ans: {scale_ans}")
+    print(f"scale_ref: {scale_ref}")
     torch.testing.assert_close(out_ans, out_ref)
     torch.testing.assert_close(scale_ans, scale_ref)

vllm/_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -1381,7 +1381,7 @@ def scaled_fp4_quant(
     rounded_m = round_up(m, 128)
     scale_n = n // block_size
     rounded_n = round_up(scale_n, 4)
-    output_scale = torch.zeros(
+    output_scale = torch.empty(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
 
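Switching the allocation from torch.zeros to torch.empty removes a full memset of the scale buffer on every call; the padded region is now expected to be cleared by the kernel change above. A minimal sketch of the shape arithmetic follows, assuming block_size is 16 (its value is not shown in this hunk) and using a CPU tensor purely for illustration.

import torch

def round_up(x: int, y: int) -> int:
    return (x + y - 1) // y * y

# Hedged sketch of the swizzled-scale allocation in scaled_fp4_quant.
# ASSUMPTION: block_size == 16; the real buffer lives on the input's device.
m, n, block_size = 100, 256, 16
rounded_m = round_up(m, 128)          # 128
scale_n = n // block_size             # 16
rounded_n = round_up(scale_n, 4)      # 16
output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
print(output_scale.shape)             # torch.Size([128, 4])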

vllm/envs.py

Lines changed: 1 addition & 1 deletion
@@ -154,7 +154,7 @@
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
-    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
+    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
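This flips the default FlashInfer MoE backend from "throughput" to "latency". A hedged example of restoring the previous behavior by setting the environment variable before vLLM reads it:

import os

# Hedged sketch: select the FlashInfer MoE backend explicitly instead of
# relying on the new "latency" default. Set this before vLLM reads its envs.
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "throughput"  # or "latency"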

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 0 deletions
@@ -138,6 +138,8 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
             logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
             return Fp8MoeBackend.FLASHINFER_TRTLLM
         else:
+            if block_quant:
+                raise ValueError("FlashInfer FP8 MoE CUTLASS backend does not support block quantization")
             logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
             return Fp8MoeBackend.FLASHINFER_CUTLASS
 
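The new guard makes block-quantized FP8 MoE fail fast when the FlashInfer CUTLASS path is selected, instead of proceeding with an unsupported configuration. A hedged usage sketch; whether this branch is reached depends on the platform (SM100) and the FlashInfer environment variables, so the call below only raises in that configuration.

from vllm.model_executor.layers.quantization.fp8 import get_fp8_moe_backend

# Hedged sketch: on a setup that selects the FlashInfer CUTLASS FP8 MoE
# backend, block quantization is now rejected up front.
try:
    backend = get_fp8_moe_backend(block_quant=True)
except ValueError as err:
    print(err)  # FlashInfer FP8 MoE CUTLASS backend does not support block quantization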

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 14 additions & 13 deletions
@@ -1034,14 +1034,14 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
             return None
         return remapped_name
 
-    if any("mla_attn" in key for key in params_dict):
-        attn_str = "mla_attn.mla_attn"
-        logger.debug_once(
-            f"Found mla_attn with k_scale and v_scale in "
-            f"the checkpoint, using {attn_str} as attn_str"
-        )
-    else:
-        attn_str = "attn"
+    # if any("mla_attn" in key for key in params_dict):
+    #     attn_str = "mla_attn.mla_attn"
+    #     logger.debug_once(
+    #         f"Found mla_attn with k_scale and v_scale in "
+    #         f"the checkpoint, using {attn_str} as attn_str"
+    #     )
+    # else:
+    attn_str = "attn"
     # Define scale name mapping patterns in order of precedence
     scale_mapping_patterns = [
         # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
@@ -1068,13 +1068,14 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
         if re.search(pattern, name):
             remapped_name = re.sub(pattern, replacement, name)
             if remapped_name not in params_dict:
+                # find the scale type in params_dict
+                params_scale_name = "<not found>"
                 scale_type = name.split(".")[-1]
+                print(params_dict.keys())
                 logger.warning_once(
-                    "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
-                    scale_type,
-                    name,
-                    remapped_name,
-                    scale_type,
+                    f"Found {scale_type} in the checkpoint (e.g. {name}), but not found the remapped name in the model "
+                    f" (e.g. {remapped_name}). {scale_type} is not loaded."
+                    # f"Expected format is {params_scale_name} ",  # noqa: E501
                 )
                 return None
             return remapped_name
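For context, maybe_remap_kv_scale_name tries each (pattern, replacement) pair and only accepts a remap whose target exists in the model's params_dict; otherwise it warns and returns None so the scale is skipped. A minimal sketch of that step, using a hypothetical pattern rather than one of vLLM's actual scale_mapping_patterns:

import re

# Hedged sketch: the pattern, replacement, name and params_dict below are
# hypothetical examples, not values taken from vLLM.
pattern = r"\.self_attn\.k_proj\.k_scale$"
replacement = ".self_attn.attn.k_scale"
name = "model.layers.0.self_attn.k_proj.k_scale"
params_dict = {"model.layers.0.self_attn.attn.k_scale": None}

if re.search(pattern, name):
    remapped_name = re.sub(pattern, replacement, name)
    # Only accept the remap if the target parameter actually exists.
    result = remapped_name if remapped_name in params_dict else None
    print(result)  # model.layers.0.self_attn.attn.k_scale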
