Commit cd262df

clean up comments and prints
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
1 parent 622b68a · commit cd262df

File tree: 3 files changed, +19 -17 lines

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 4 additions & 1 deletion

@@ -139,7 +139,10 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
         return Fp8MoeBackend.FLASHINFER_TRTLLM
     else:
         if block_quant:
-            raise ValueError("FlashInfer FP8 MoE CUTLASS backend does not support block quantization")
+            raise ValueError("FlashInfer FP8 MoE throughput backend does not "
+                             "support block quantization. Please use "
+                             "VLLM_FLASHINFER_MOE_BACKEND=latency "
+                             "instead.")
         logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
         return Fp8MoeBackend.FLASHINFER_CUTLASS
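
The new error message points users at the VLLM_FLASHINFER_MOE_BACKEND environment variable. As a rough illustration only (the model id is made up, and the exact backend-selection behaviour is whatever vLLM implements), switching a block-quantized FP8 MoE deployment to the latency backend could look like this:

import os

# Select the FlashInfer FP8 MoE latency backend, as the error message above
# suggests. Setting the variable before vLLM is imported keeps the choice
# visible to whichever code path reads it.
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"

from vllm import LLM  # noqa: E402

# Hypothetical block-quantized FP8 MoE model id, used purely for illustration.
llm = LLM(model="my-org/fp8-block-quant-moe-model")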

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 2 additions & 2 deletions

@@ -224,7 +224,7 @@ def is_layer_excluded(self, prefix: str) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import Attention, MLAAttention  # Avoid circular import

         if isinstance(layer, LinearBase):
             if self.is_layer_excluded(prefix):
@@ -233,7 +233,7 @@ def get_quant_method(
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, Attention) or isinstance(layer, MLAAttention):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptFp8MoEMethod(self, layer)
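
The widened isinstance check above says that MLA attention layers receive the same FP8 KV-cache quantization method as regular attention layers. A minimal sketch of an equivalent predicate, assuming only the two classes imported in the diff (the helper name is invented for illustration):

from vllm.attention.layer import Attention, MLAAttention


def uses_fp8_kv_cache_method(layer) -> bool:
    # isinstance accepts a tuple of types, so this single check behaves the
    # same as the chained "or" in the diff above.
    return isinstance(layer, (Attention, MLAAttention))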

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 13 additions & 14 deletions

@@ -1021,14 +1021,14 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
             return None
         return remapped_name

-    # if any("mla_attn" in key for key in params_dict):
-    #     attn_str = "mla_attn.mla_attn"
-    #     logger.debug_once(
-    #         f"Found mla_attn with k_scale and v_scale in "
-    #         f"the checkpoint, using {attn_str} as attn_str"
-    #     )
-    # else:
-    attn_str = "attn"
+    if any("mla_attn" in key for key in params_dict):
+        attn_str = "mla_attn.mla_attn"
+        logger.debug_once(
+            f"Found mla_attn with k_scale and v_scale in "
+            f"the checkpoint, using {attn_str} as attn_str"
+        )
+    else:
+        attn_str = "attn"
     # Define scale name mapping patterns in order of precedence
     scale_mapping_patterns = [
         # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->

@@ -1055,14 +1055,13 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
         if re.search(pattern, name):
            remapped_name = re.sub(pattern, replacement, name)
            if remapped_name not in params_dict:
-                # find the scale type in params_dict
-                params_scale_name = "<not found>"
                scale_type = name.split(".")[-1]
-                print(params_dict.keys())
                logger.warning_once(
-                    f"Found {scale_type} in the checkpoint (e.g. {name}), but not found the remapped name in the model "
-                    f" (e.g. {remapped_name}). {scale_type} is not loaded."
-                    # f"Expected format is {params_scale_name} ",  # noqa: E501
+                    "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
+                    scale_type,
+                    name,
+                    remapped_name,
+                    scale_type,
                )
                return None
            return remapped_name
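
For readers unfamiliar with maybe_remap_kv_scale_name, the re-enabled branch above only changes which module path the checkpoint's k_scale/v_scale names get mapped onto. The toy sketch below is not vLLM code: the regex and parameter names are simplified assumptions, meant only to show why an MLA checkpoint needs the nested mla_attn.mla_attn path.

import re


def remap_kv_scale(name: str, params_dict: dict) -> str | None:
    # Pick the attention module path the same way the diff above does:
    # MLA models expose their scales under a nested "mla_attn.mla_attn".
    attn_str = (
        "mla_attn.mla_attn"
        if any("mla_attn" in key for key in params_dict)
        else "attn"
    )
    # Simplified stand-in for the real scale_mapping_patterns.
    pattern = r"\.self_attn\.(k|v)_scale$"
    replacement = rf".self_attn.{attn_str}.\1_scale"
    if re.search(pattern, name):
        remapped = re.sub(pattern, replacement, name)
        return remapped if remapped in params_dict else None
    return name


params = {"model.layers.0.self_attn.mla_attn.mla_attn.k_scale": None}
print(remap_kv_scale("model.layers.0.self_attn.k_scale", params))
# model.layers.0.self_attn.mla_attn.mla_attn.k_scale

In the real helper the patterns also cover ModelOpt-style ..._proj.{k,v}_scale names, which is why it keeps an ordered list of (pattern, replacement) tuples rather than a single regex.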
