Commit 0dc9532

youzhedian, hongchao, and zou3519 authored
[BUGFIX ] fix undefined silu_and_mul_nvfp4_quant (vllm-project#23929)
Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: Richard Zou <zou3519@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: Richard Zou <zou3519@gmail.com>
Co-authored-by: Richard Zou <zou3519@users.noreply.github.com>
1 parent 72a6913 commit 0dc9532

3 files changed, with 7 additions and 4 deletions

csrc/ops.h

Lines changed: 2 additions & 2 deletions
@@ -130,8 +130,8 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
                         torch::Tensor& scale);
 
-#ifndef USE_ROCM
-
+#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
+    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
 void silu_and_mul_nvfp4_quant(torch::Tensor& out,
                               torch::Tensor& output_block_scale,
                               torch::Tensor& input,

csrc/torch_bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
   ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
 
-#ifndef USE_ROCM
+#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
+    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
   ops.def(
       "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
       "Tensor input, Tensor input_global_scale) -> ()");

vllm/compilation/fix_functionalization.py

Lines changed: 3 additions & 1 deletion
@@ -97,7 +97,9 @@ def __call__(self, graph: torch.fx.Graph):
                                      node,
                                      mutated_args,
                                      args=('result', 'input', 'scale'))
-            elif at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
+            elif hasattr(
+                    torch.ops._C, "silu_and_mul_nvfp4_quant"
+            ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
                 mutated_args = {1: 'result', 2: 'result_block_scale'}
                 self.defunctionalize(graph,
                                      node,
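
Note on the Python change: the guard relies on `and` short-circuiting, so the attribute lookup on the right-hand side is skipped when the op was never registered. The sketch below illustrates that pattern in isolation. It is not vLLM code: `ops_without_nvfp4`, `ops_with_nvfp4`, and both helper functions are hypothetical stand-ins, and it assumes that a missing op on `torch.ops._C` surfaces as an `AttributeError`, which is what `hasattr` relies on.

```python
from types import SimpleNamespace

# Hypothetical stand-in for torch.ops._C on a build where the NVFP4
# kernels were compiled out: the op is absent from the namespace.
ops_without_nvfp4 = SimpleNamespace(
    silu_and_mul_quant=SimpleNamespace(default="silu_and_mul_quant.default"),
)

# Hypothetical stand-in for a build that does register the op.
ops_with_nvfp4 = SimpleNamespace(
    silu_and_mul_quant=SimpleNamespace(default="silu_and_mul_quant.default"),
    silu_and_mul_nvfp4_quant=SimpleNamespace(
        default="silu_and_mul_nvfp4_quant.default"),
)


def is_nvfp4_target(ops, at_target):
    """Guarded check: hasattr runs first, so the lookup on the right
    is only evaluated when the op actually exists."""
    return (hasattr(ops, "silu_and_mul_nvfp4_quant")
            and at_target == ops.silu_and_mul_nvfp4_quant.default)


def unguarded_is_nvfp4_target(ops, at_target):
    """Pre-fix shape of the check: raises AttributeError if the op is missing."""
    return at_target == ops.silu_and_mul_nvfp4_quant.default
```

With the guarded form, a build without the op falls through to the next branch instead of failing during the graph pass; the unguarded form raises as soon as the comparison is reached.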
