Skip to content

Commit e757a62

Browse files
authored
[Bug] Fix Cutlass Scaled MM Compilation Error (vllm-project#24887)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent aae725a commit e757a62

File tree

3 files changed

+53
-41
lines changed

3 files changed

+53
-41
lines changed

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
146146

147147
using ElementAB = typename Gemm::ElementAB;
148148
using ElementD = typename Gemm::ElementD;
149+
using ElementBlockScale = typename Gemm::ElementBlockScale;
149150

150151
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
151152

@@ -166,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
166167
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
167168
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
168169

169-
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
170-
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
171-
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
172-
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
170+
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
171+
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
172+
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
173+
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
173174

174-
auto mainloop_args = [&](){
175-
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
176-
if (swap_ab) {
177-
return typename GemmKernel::MainloopArguments{
178-
b_ptr, b_stride, a_ptr, a_stride,
179-
b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
180-
};
181-
}
182-
else {
183-
return typename GemmKernel::MainloopArguments{
184-
a_ptr, a_stride, b_ptr, b_stride,
185-
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
186-
};
187-
}
188-
}();
175+
typename GemmKernel::MainloopArguments mainloop_args{};
176+
mainloop_args.layout_SFA = layout_SFA;
177+
mainloop_args.layout_SFB = layout_SFB;
178+
if (swap_ab) {
179+
mainloop_args.ptr_A = b_ptr;
180+
mainloop_args.dA = b_stride;
181+
mainloop_args.ptr_B = a_ptr;
182+
mainloop_args.dB = a_stride;
183+
mainloop_args.ptr_SFA = b_scales_ptr;
184+
mainloop_args.ptr_SFB = a_scales_ptr;
185+
} else {
186+
mainloop_args.ptr_A = a_ptr;
187+
mainloop_args.dA = a_stride;
188+
mainloop_args.ptr_B = b_ptr;
189+
mainloop_args.dB = b_stride;
190+
mainloop_args.ptr_SFA = a_scales_ptr;
191+
mainloop_args.ptr_SFB = b_scales_ptr;
192+
}
189193
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
190194

191195
auto c_ptr = static_cast<ElementD*>(out.data_ptr());

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
125125

126126
using ElementAB = typename Gemm::ElementAB;
127127
using ElementD = typename Gemm::ElementD;
128+
using ElementBlockScale = typename Gemm::ElementBlockScale;
128129

129130
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
130131

@@ -143,17 +144,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
143144
LayoutSFB layout_SFB =
144145
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
145146

146-
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
147-
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
148-
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
149-
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
150-
151-
auto mainloop_args = [&](){
152-
return typename GemmKernel::MainloopArguments{
153-
a_ptr, a_stride, b_ptr, b_stride,
154-
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
155-
};
156-
}();
147+
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
148+
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
149+
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
150+
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
151+
152+
typename GemmKernel::MainloopArguments mainloop_args{};
153+
mainloop_args.ptr_A = a_ptr;
154+
mainloop_args.dA = a_stride;
155+
mainloop_args.ptr_B = b_ptr;
156+
mainloop_args.dB = b_stride;
157+
mainloop_args.ptr_SFA = a_scales_ptr;
158+
mainloop_args.layout_SFA = layout_SFA;
159+
mainloop_args.ptr_SFB = b_scales_ptr;
160+
mainloop_args.layout_SFB = layout_SFB;
157161
auto prob_shape = cute::make_shape(m, n, k, 1);
158162

159163
auto c_ptr = static_cast<ElementD*>(out.data_ptr());

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
115115

116116
using ElementAB = typename Gemm::ElementAB;
117117
using ElementD = typename Gemm::ElementD;
118+
using ElementBlockScale = typename Gemm::ElementBlockScale;
118119

119120
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
120121

@@ -135,17 +136,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
135136
LayoutSFB layout_SFB =
136137
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
137138

138-
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
139-
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
140-
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
141-
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
142-
143-
auto mainloop_args = [&](){
144-
return typename GemmKernel::MainloopArguments{
145-
a_ptr, a_stride, b_ptr, b_stride,
146-
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
147-
};
148-
}();
139+
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
140+
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
141+
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
142+
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
143+
144+
typename GemmKernel::MainloopArguments mainloop_args{};
145+
mainloop_args.ptr_A = a_ptr;
146+
mainloop_args.dA = a_stride;
147+
mainloop_args.ptr_B = b_ptr;
148+
mainloop_args.dB = b_stride;
149+
mainloop_args.ptr_SFA = a_scales_ptr;
150+
mainloop_args.layout_SFA = layout_SFA;
151+
mainloop_args.ptr_SFB = b_scales_ptr;
152+
mainloop_args.layout_SFB = layout_SFB;
149153
auto prob_shape = cute::make_shape(m, n, k, 1);
150154

151155
auto c_ptr = static_cast<ElementD*>(out.data_ptr());

0 commit comments

Comments
 (0)