Skip to content

Commit cdd7025

Browse files
authored
[kernel] Improve FP8 PTPC on Hopper for larger shapes (#28692)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
1 parent 0854248 commit cdd7025

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,26 @@ struct sm90_fp8_config_default {
116116
ClusterShape, KernelSchedule, EpilogueSchedule>>;
117117
};
118118

119+
template <typename InType, typename OutType, bool EnableBias>
120+
struct sm90_fp8_config_M8192_K6144 {
121+
// M >= 8192, K >= 6144
122+
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
123+
using KernelSchedule =
124+
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
125+
using EpilogueSchedule =
126+
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
127+
using TileShape = Shape<_256, _128, _128>;
128+
using ClusterShape = Shape<_2, _1, _1>;
129+
130+
using Cutlass3xGemm = conditional_t<
131+
EnableBias,
132+
cutlass_3x_gemm_sm90_fp8<InType, OutType, c3x::ScaledEpilogueBias,
133+
TileShape, ClusterShape, KernelSchedule,
134+
EpilogueSchedule>,
135+
cutlass_3x_gemm_sm90_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
136+
ClusterShape, KernelSchedule, EpilogueSchedule>>;
137+
};
138+
119139
template <typename InType, typename OutType, bool EnableBias>
120140
struct sm90_fp8_config_M128 {
121141
// M in (64, 128]
@@ -273,6 +293,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
273293
using Cutlass3xGemmDefault =
274294
typename sm90_fp8_config_default<InType, OutType,
275295
EnableBias>::Cutlass3xGemm;
296+
using Cutlass3xGemmM8192_K6144 =
297+
typename sm90_fp8_config_M8192_K6144<InType, OutType,
298+
EnableBias>::Cutlass3xGemm;
276299
using Cutlass3xGemmM128 =
277300
typename sm90_fp8_config_M128<InType, OutType, EnableBias>::Cutlass3xGemm;
278301

@@ -291,6 +314,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
291314

292315
uint32_t const m = a.size(0);
293316
uint32_t const n = b.size(1);
317+
uint32_t const k = a.size(1);
294318

295319
if (m <= 16) {
296320
// m in [1, 16]
@@ -312,6 +336,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
312336
// m in (64, 128]
313337
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM128>(
314338
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
339+
} else if (m >= 8192 && k >= 6144) {
340+
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM8192_K6144>(
341+
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
315342
} else {
316343
// m in (128, inf)
317344
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmDefault>(

0 commit comments

Comments
 (0)