
Commit 326fa96

Fix col start
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
1 parent: 53bea8c

7 files changed: +47 lines, -35 lines


csrc/quantization/fp4/nvfp4_quant_kernels.cu (10 additions, 10 deletions)

@@ -43,17 +43,17 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                 "Vec size is not matched.");
 
-  if (blockIdx.x == 0 && threadIdx.x == 0) {
-    int sf_m = round_up(numRows, 128);
-    int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
-    int sf_n = round_up(sf_n_unpadded, 4) / 4;
-    for(int row = numRows; row < sf_m; row += 1) {
-      for(int col = sf_n_unpadded; col < sf_n; col +=1) {
-        SFout[row * sf_n + col] = 0x00;
-      }
+  int sf_m = round_up(numRows, 128);
+  int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
+  int sf_n_uint32 = round_up(sf_n_unpadded, 4) / 4;
+  for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
+    // Each thread writes 4 uint32_t elements.
+    for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_uint32;
+         col += blockDim.x * 4) {
+      SFout[row * sf_n_uint32 + col] = 0x00000000;
     }
-    }
-  }
+  }
+
   // Get the global scaling factor, which will be applied to the SF.
   // Note SFScale is the same as next GEMM's alpha, which is
   // (448.f / (Alpha_A / 6.f)).
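
For reference, the padded scale-factor extents used in the hunk above come down to a few lines of integer arithmetic. The sketch below is illustrative only: it assumes CVT_FP4_SF_VEC_SIZE is 16 (one FP8 scale per 16 FP4 elements), and the Python helper names are not taken from the repository.

# Illustrative Python sketch of the scale-factor extents computed above.
# Assumption: CVT_FP4_SF_VEC_SIZE == 16 (one FP8 scale per 16 FP4 values).
def round_up(x: int, multiple: int) -> int:
    # Mirrors what round_up(x, m) presumably computes: x rounded up to a multiple of m.
    return (x + multiple - 1) // multiple * multiple

def padded_sf_extents(num_rows: int, num_cols: int, sf_vec_size: int = 16):
    sf_m = round_up(num_rows, 128)                 # row count padded to a multiple of 128
    sf_n_unpadded = num_cols // sf_vec_size        # scales per row before padding
    sf_n_uint32 = round_up(sf_n_unpadded, 4) // 4  # scales packed four per uint32
    return sf_m, sf_n_unpadded, sf_n_uint32

print(padded_sf_extents(100, 256))  # (128, 16, 4)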

tests/kernels/quantization/test_nvfp4_quant.py (0 additions, 4 deletions)

@@ -170,9 +170,5 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
     out, out_scale = ops.scaled_fp4_quant(x, global_scale)
     scale_ans = recover_swizzled_scales(out_scale, m, n)
     out_ans = cast_from_fp4(out, m, n)
-    print(f"out_ans: {out_ans}")
-    print(f"out_ref: {out_ref}")
-    print(f"scale_ans: {scale_ans}")
-    print(f"scale_ref: {scale_ref}")
     torch.testing.assert_close(out_ans, out_ref)
     torch.testing.assert_close(scale_ans, scale_ref)

vllm/envs.py (2 additions, 2 deletions)

@@ -1208,7 +1208,7 @@ def get_vllm_port() -> int | None:
     # - "latency":
     #   Uses TensorRT-LLM kernels optimized for low-latency inference.
     "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
-        "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"]
+        "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
     ),
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
@@ -1313,7 +1313,7 @@ def get_vllm_port() -> int | None:
     "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
         "VLLM_NVFP4_GEMM_BACKEND",
         None,
-        ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
+        ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"],
     ),
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
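
Both environment-variable changes above can be exercised directly from user code. A minimal usage sketch, assuming a vLLM build that contains this commit; the model name below is a placeholder, not something referenced in the diff.

import os

# Select the newly added "cutlass" NVFP4 GEMM backend and make the (now
# default) "latency" FlashInfer MoE backend explicit. Setting these before
# importing vllm ensures envs.py sees them.
os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "cutlass"
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"

from vllm import LLM  # noqa: E402

llm = LLM(model="nvfp4-quantized-model")  # placeholder model id
print(llm.generate("Hello"))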

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py (3 additions, 0 deletions)

@@ -50,6 +50,9 @@ def __init__(self):
         elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
             self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
             assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
+            self.backend = "cutlass"
+            assert cutlass_fp4_supported(), "Cutlass is required for {self.backend}"
 
         if self.backend == "none":
             raise ValueError(

vllm/model_executor/layers/quantization/fp8.py (6 additions, 1 deletion)

@@ -139,7 +139,12 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
             return Fp8MoeBackend.FLASHINFER_TRTLLM
         else:
             if block_quant:
-                raise ValueError("FlashInfer FP8 MoE CUTLASS backend does not support block quantization")
+                raise ValueError(
+                    "FlashInfer FP8 MoE throughput backend does not "
+                    "support block quantization. Please use "
+                    "VLLM_FLASHINFER_MOE_BACKEND=latency "
+                    "instead."
+                )
             logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
             return Fp8MoeBackend.FLASHINFER_CUTLASS

vllm/model_executor/layers/quantization/modelopt.py (13 additions, 4 deletions)

@@ -224,7 +224,10 @@ def is_layer_excluded(self, prefix: str) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )
 
         if isinstance(layer, LinearBase):
             if self.is_layer_excluded(prefix):
@@ -233,7 +236,7 @@ def get_quant_method(
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
            return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptFp8MoEMethod(self, layer)
@@ -905,7 +908,10 @@ def is_layer_excluded(self, prefix: str) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )
 
         skip_layer = self.is_layer_excluded(prefix)
         if isinstance(layer, LinearBase):
@@ -915,7 +921,7 @@ def get_quant_method(
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             if skip_layer:
@@ -958,6 +964,9 @@ def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
             self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
             assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
+            self.backend = "cutlass"
+            assert cutlass_fp4_supported(), "Cutlass is required for {self.backend}"
 
         if self.backend == "none":
             raise ValueError(

vllm/model_executor/model_loader/weight_utils.py (13 additions, 14 deletions)

@@ -1034,14 +1034,14 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
             return None
         return remapped_name
 
-    # if any("mla_attn" in key for key in params_dict):
-    #     attn_str = "mla_attn.mla_attn"
-    #     logger.debug_once(
-    #         f"Found mla_attn with k_scale and v_scale in "
-    #         f"the checkpoint, using {attn_str} as attn_str"
-    #     )
-    # else:
-    attn_str = "attn"
+    if any("mla_attn" in key for key in params_dict):
+        attn_str = "mla_attn.mla_attn"
+        logger.debug_once(
+            f"Found mla_attn with k_scale and v_scale in "
+            f"the checkpoint, using {attn_str} as attn_str"
+        )
+    else:
+        attn_str = "attn"
     # Define scale name mapping patterns in order of precedence
     scale_mapping_patterns = [
         # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
@@ -1068,14 +1068,13 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
         if re.search(pattern, name):
            remapped_name = re.sub(pattern, replacement, name)
            if remapped_name not in params_dict:
-                # find the scale type in params_dict
-                params_scale_name = "<not found>"
                scale_type = name.split(".")[-1]
-                print(params_dict.keys())
                logger.warning_once(
-                    f"Found {scale_type} in the checkpoint (e.g. {name}), but not found the remapped name in the model "
-                    f" (e.g. {remapped_name}). {scale_type} is not loaded."
-                    # f"Expected format is {params_scale_name} ", # noqa: E501
+                    "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
+                    scale_type,
+                    name,
+                    remapped_name,
+                    scale_type,
                )
                return None
            return remapped_name
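
The attn_str selection that the last file re-enables is self-contained enough to show on its own. Below is a standalone sketch for illustration only; the example parameter names are invented, not taken from a real checkpoint.

# Standalone sketch of the re-enabled attn_str selection above.
def pick_attn_str(params_dict: dict) -> str:
    # Models whose parameter names contain "mla_attn" get the doubled
    # "mla_attn.mla_attn" prefix when remapping k_scale / v_scale names;
    # every other model keeps the plain "attn" prefix.
    if any("mla_attn" in key for key in params_dict):
        return "mla_attn.mla_attn"
    return "attn"

# Invented parameter names, purely for illustration:
print(pick_attn_str({"layers.0.self_attn.mla_attn.mla_attn.k_scale": None}))  # -> mla_attn.mla_attn
print(pick_attn_str({"layers.0.self_attn.attn.k_scale": None}))               # -> attn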
