@@ -116,6 +116,26 @@ struct sm90_fp8_config_default {
116116 ClusterShape, KernelSchedule, EpilogueSchedule>>;
117117};
118118
119+ template <typename InType, typename OutType, bool EnableBias>
120+ struct sm90_fp8_config_M8192_K6144 {
121+ // M >= 8192, K >= 6144
122+ static_assert (std::is_same<InType, cutlass::float_e4m3_t >());
123+ using KernelSchedule =
124+ cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
125+ using EpilogueSchedule =
126+ typename cutlass::epilogue::TmaWarpSpecializedCooperative;
127+ using TileShape = Shape<_256, _128, _128>;
128+ using ClusterShape = Shape<_2, _1, _1>;
129+
130+ using Cutlass3xGemm = conditional_t <
131+ EnableBias,
132+ cutlass_3x_gemm_sm90_fp8<InType, OutType, c3x::ScaledEpilogueBias,
133+ TileShape, ClusterShape, KernelSchedule,
134+ EpilogueSchedule>,
135+ cutlass_3x_gemm_sm90_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
136+ ClusterShape, KernelSchedule, EpilogueSchedule>>;
137+ };
138+
119139template <typename InType, typename OutType, bool EnableBias>
120140struct sm90_fp8_config_M128 {
121141 // M in (64, 128]
@@ -273,6 +293,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
273293 using Cutlass3xGemmDefault =
274294 typename sm90_fp8_config_default<InType, OutType,
275295 EnableBias>::Cutlass3xGemm;
296+ using Cutlass3xGemmM8192_K6144 =
297+ typename sm90_fp8_config_M8192_K6144<InType, OutType,
298+ EnableBias>::Cutlass3xGemm;
276299 using Cutlass3xGemmM128 =
277300 typename sm90_fp8_config_M128<InType, OutType, EnableBias>::Cutlass3xGemm;
278301
@@ -291,6 +314,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
291314
292315 uint32_t const m = a.size (0 );
293316 uint32_t const n = b.size (1 );
317+ uint32_t const k = a.size (1 );
294318
295319 if (m <= 16 ) {
296320 // m in [1, 16]
@@ -312,6 +336,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
312336 // m in (64, 128]
313337 return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM128>(
314338 out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
339+ } else if (m >= 8192 && k >= 6144 ) {
340+ return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM8192_K6144>(
341+ out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
315342 } else {
316343 // m in (128, inf)
317344 return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmDefault>(
0 commit comments