#include <cute/int_tuple.hpp>
#include <cute/layout.hpp>

#include "mha_traits_sm80.h"
#include "static_dispatch.h"
87
namespace llm {

// Forward declaration of the kernel launcher (defined in the .cu TU).
// The compile-time template parameters mirror the runtime dispatch performed
// by run_mha_kernel_sm80 below: element type and head dimension, plus the
// boolean feature flags resolved via DISPATCH_BOOL.
template <typename Dtype,
          int HEAD_DIM,
          bool EVEN_K,
          bool ALIBI,
          bool SOFT_CAP,
          bool LOCAL,
          typename Params>
void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream);

// User-facing entry point: normalizes params, then dispatches to the proper
// compile-time instantiation of launch_mha_kernel_sm80 based on runtime
// parameter values.
//
// Template parameters:
//   Dtype    - element type of the attention inputs.
//   HEAD_DIM - compile-time head dimension the kernel is instantiated for.
//   Params   - parameter struct; must provide normalize(), head_dim,
//              alibi_slopes_ptr, logits_soft_cap and sliding_window.
//
// stream defaults to nullptr, i.e. the legacy default CUDA stream.
template <typename Dtype, int HEAD_DIM, typename Params>
void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
  // Normalize params for performance optimization.
  params.normalize();

  // Map each runtime condition onto a compile-time bool so the launcher is
  // instantiated without per-element branches:
  //   EVEN_K   - runtime head_dim exactly matches HEAD_DIM (no K tail).
  //   ALIBI    - ALiBi slopes were provided.
  //   SOFT_CAP - logits soft-capping is enabled (> 0).
  //   LOCAL    - sliding-window (local) attention is enabled (>= 0).
  DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
    DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
      DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
        DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
          launch_mha_kernel_sm80<Dtype,
                                 HEAD_DIM,
                                 EVEN_K,
                                 ALIBI,
                                 SOFT_CAP,
                                 LOCAL,
                                 Params>(params, stream);
        });
      });
    });
  });
}

}  // namespace llm