Commit 90b139c

Enable Fbgemm NVFP4 on Dense models (vllm-project#25609)
Signed-off-by: Saman Keon <samanamp@outlook.com>
Parent: 4492e3a

File tree: 3 files changed (+89, -5 lines)

benchmarks/kernels/bench_nvfp4_gemm.py

Lines changed: 61 additions & 4 deletions
@@ -3,6 +3,7 @@
 import argparse
 import copy
 import itertools
+import os
 
 import torch
 from weight_shapes import WEIGHT_SHAPES
@@ -23,21 +24,45 @@
     "torch-bf16": dict(enabled=True),
     "nvfp4": dict(no_a_quant=False, enabled=True),
     "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
+    "fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True),
+    "fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True),
 }
 
+_needs_fbgemm = any(
+    v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False)
+)
+if _needs_fbgemm:
+    try:
+        from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
+            triton_scale_nvfp4_quant,
+        )
+    except ImportError:
+        print(
+            "WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
+            "These providers will be skipped. Please install fbgemm_gpu with: "
+            "'pip install fbgemm-gpu-genai' to run them."
+        )
+        # Disable FBGEMM providers so the benchmark can run.
+        for cfg in PROVIDER_CFGS.values():
+            if cfg.get("fbgemm"):
+                cfg["enabled"] = False
+
 _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
 
 
-def _quant_weight_nvfp4(b: torch.Tensor, device: str):
+def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg):
     # Compute global scale for weight
     b_amax = torch.abs(b).max().to(torch.float32)
     b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
-    b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
+    if "fbgemm" in cfg and cfg["fbgemm"]:
+        b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)
+    else:
+        b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
     return b_fp4, scale_b_fp4, b_global_scale
 
 
 def build_nvfp4_runner(cfg, a, b, dtype, device):
-    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
+    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg)
 
     # Compute global scale for activation
     # NOTE: This is generally provided ahead-of-time by the model checkpoint.
@@ -46,6 +71,35 @@ def build_nvfp4_runner(cfg, a, b, dtype, device):
 
     # Alpha for the GEMM operation
     alpha = 1.0 / (a_global_scale * b_global_scale)
+    if "fbgemm" in cfg and cfg["fbgemm"]:
+        if cfg["no_a_quant"]:
+            a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
+
+            def run():
+                return torch.ops.fbgemm.f4f4bf16(
+                    a_fp4,
+                    b_fp4,
+                    scale_a_fp4,
+                    scale_b_fp4,
+                    global_scale=alpha,
+                    use_mx=False,
+                )
+
+            return run
+        else:
+
+            def run():
+                a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
+                return torch.ops.fbgemm.f4f4bf16(
+                    a_fp4,
+                    b_fp4,
+                    scale_a_fp4,
+                    scale_b_fp4,
+                    global_scale=alpha,
+                    use_mx=False,
+                )
+
+            return run
 
     if cfg["no_a_quant"]:
         # Pre-quantize activation
@@ -130,10 +184,13 @@ def prepare_shapes(args):
 
     for K, N, model in prepare_shapes(args):
         print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
+        save_dir = f"bench_nvfp4_res_n{N}_k{K}"
+        os.makedirs(save_dir, exist_ok=True)
+
        benchmark.run(
            print_data=True,
            show_plots=True,
-           save_path=f"bench_nvfp4_res_n{N}_k{K}",
+           save_path=save_dir,
            N=N,
            K=K,
        )
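
Editor's note: for readers unfamiliar with the FBGEMM path, the new fbgemm-nvfp4 providers reduce to the call sequence below. This is a minimal sketch, not part of the commit; it assumes fbgemm-gpu-genai is installed and an NVFP4-capable GPU is present, and the tensor shapes and variable names are illustrative.

# Minimal sketch (not from the commit) of the fbgemm-nvfp4 provider path.
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_scale_nvfp4_quant,
)

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0

a = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")   # activation
b = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")  # weight

# Per-tensor global scales, mirroring the benchmark's formula.
a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a.abs().max().to(torch.float32)
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b.abs().max().to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)

# Quantize both operands to NVFP4 with FBGEMM's Triton quantizer.
a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)

# FP4 x FP4 -> BF16 GEMM; alpha folds the two global scales back out.
out = torch.ops.fbgemm.f4f4bf16(
    a_fp4, b_fp4, scale_a_fp4, scale_b_fp4,
    global_scale=alpha, use_mx=False,
)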

vllm/envs.py

Lines changed: 4 additions & 1 deletion
@@ -201,6 +201,7 @@
     VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
     VLLM_USE_NCCL_SYMM_MEM: bool = False
     VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
+    VLLM_USE_FBGEMM: bool = False
 
 
 def get_default_cache_root():
@@ -1452,7 +1453,8 @@ def get_vllm_port() -> Optional[int]:
     # NCCL header path
     "VLLM_NCCL_INCLUDE_PATH":
     lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
-
+    # Flag to enable FBGemm kernels on model execution
+    "VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
@@ -1548,6 +1550,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
         "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
         "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
+        "VLLM_USE_FBGEMM",
     ]
     for key in environment_variables_to_hash:
         # if this goes out of sync with environment_variables,
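
Editor's note: the new flag is a plain on/off integer environment variable, and it is also appended to the compute_hash() list above, so toggling it changes the hash vLLM derives from its environment variables. A small sketch of the parse semantics, using the same bool(int(...)) expression as the lambda in the diff:

# Sketch of how VLLM_USE_FBGEMM is parsed: unset or "0" disables FBGEMM
# kernels, any non-zero integer string enables them.
import os

os.environ["VLLM_USE_FBGEMM"] = "1"
assert bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))) is True

del os.environ["VLLM_USE_FBGEMM"]
assert bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))) is False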

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

Lines changed: 24 additions & 0 deletions
@@ -30,8 +30,20 @@ def __init__(self):
         if envs.VLLM_USE_TRTLLM_FP4_GEMM:
             assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
             self.backend = "flashinfer-trtllm"
+            logger.info_once("Using flashinfer-trtllm for FP4")
+        elif envs.VLLM_USE_FBGEMM:
+            self.backend = "fbgemm"
+            try:
+                import fbgemm_gpu  # noqa: F401
+            except ImportError as exc:
+                raise ImportError(
+                    "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
+                    "Please install with: pip install fbgemm-gpu-genai"
+                ) from exc
+            logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
         elif has_flashinfer():
             self.backend = "flashinfer-cutlass"
+            logger.info_once("Using flashinfer-cutlass for FP4")
         else:
             self.backend = "cutlass"
         self.group_size = 16
@@ -116,6 +128,9 @@ def process_weights_after_loading(self, layer) -> None:
             layer.weight_packed = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
+            if self.backend == "fbgemm":
+                swizzled_weight_scale = swizzled_weight_scale.view(-1).view(
+                    torch.uint8)
             layer.weight_scale = Parameter(swizzled_weight_scale,
                                            requires_grad=False)
             layer.weight_packed = Parameter(layer.weight_packed.data,
@@ -153,6 +168,15 @@ def apply_weights(self,
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
         elif self.backend == "flashinfer-cutlass":
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        elif self.backend == "fbgemm":
+            out = torch.ops.fbgemm.f4f4bf16(
+                x_fp4,
+                layer.weight_packed,
+                x_blockscale.view(-1).view(torch.uint8),
+                layer.weight_scale,
+                layer.alpha,
+                use_mx=False,
+            ).to(output_dtype)
         else:
             out = cutlass_scaled_fp4_mm(*mm_args)
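
Editor's note: the one non-obvious piece of the fbgemm branch is the .view(-1).view(torch.uint8) applied to the block scales. A short sketch (illustrative tensor, not from the commit) of what that reinterpretation does:

# The swizzled block scales are float8_e4m3fn (one byte per element);
# Tensor.view(dtype) reinterprets those same bytes as uint8 without copying,
# which is the form the diff passes to torch.ops.fbgemm.f4f4bf16.
import torch

scales = torch.rand(16, 4).to(torch.float8_e4m3fn)
as_u8 = scales.view(-1).view(torch.uint8)  # flatten, then reinterpret bytes

assert as_u8.dtype == torch.uint8
assert as_u8.numel() == scales.numel()
assert as_u8.data_ptr() == scales.data_ptr()  # same storage, no copy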
