@@ -13,15 +13,35 @@
 // limitations under the License.
 //
 // This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
+//
 
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing
 
 #include "kernel_matmul.cuh"
 #include "kernel_reduction.cuh"
 
 #include <stdio.h>
 #include <assert.h>
 
+inline bool isSM75GPU() {
+    int device;
+    cudaError_t err = cudaGetDevice(&device);
+    if (err != cudaSuccess) {
+        return false;
+    }
+
+    cudaDeviceProp props;
+    err = cudaGetDeviceProperties(&props, device);
+    if (err != cudaSuccess) {
+        return false;
+    }
+
+    return (props.major == 7) && (props.minor == 5);
+}
+
 template<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>
 static void Kernel_Ex(cudaStream_t stream,
                       const uint4 *Weight,
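Note on the hunk above: `isSM75GPU()` queries `cudaGetDevice()` and `cudaGetDeviceProperties()` on every call, and `fpx_linear_kernel` below calls it on every launch; `cudaGetDeviceProperties()` is comparatively expensive. A minimal sketch of a per-device cached variant follows. This is hypothetical and not part of this commit; `isSM75GPUCached` is an invented name.

// Hypothetical sketch, NOT part of this commit: cache the compute-capability
// check per device so cudaGetDeviceProperties() runs once per device instead
// of on every kernel launch. isSM75GPUCached is an invented name.
#include <cuda_runtime.h>
#include <mutex>
#include <unordered_map>

inline bool isSM75GPUCached() {
    static std::mutex m;
    static std::unordered_map<int, bool> cache;  // device id -> is SM 7.5

    int device;
    if (cudaGetDevice(&device) != cudaSuccess) return false;

    std::lock_guard<std::mutex> lock(m);
    auto it = cache.find(device);
    if (it != cache.end()) return it->second;  // cached answer

    cudaDeviceProp props;
    bool sm75 = (cudaGetDeviceProperties(&props, device) == cudaSuccess)
                && props.major == 7 && props.minor == 5;
    cache[device] = sm75;
    return sm75;
}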
@@ -80,38 +100,51 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream,
     if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
     if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;
 
-    if (Split_K == 1) {
-        switch (N_PowerOf2) {
-            case 8:   Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            case 16:  Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            case 32:  Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            default:  if (N_PowerOf2 % 128 != 0) {
-                          printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
-                          return cudaErrorUnknown;
-                      }
-                      Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+    if (isSM75GPU() && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
+        // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
+        if (Split_K == 1) {
+            Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
+        } else {
+            Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);
         }
-    }
-    else {
-        switch (N_PowerOf2) {
-            case 8:   Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            case 16:  Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            case 32:  Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            default:  if (N_PowerOf2 % 128 != 0) {
-                          printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
-                          return cudaErrorUnknown;
-                      }
-                      Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+    } else {
+        if (Split_K == 1) {
+            switch (N_PowerOf2) {
+                case 8:   Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                case 16:  Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                case 32:  Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                default:  if (N_PowerOf2 % 128 != 0) {
+                              printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
+                              return cudaErrorUnknown;
+                          }
+                          Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+            }
+        }
+        else {
+            switch (N_PowerOf2) {
+                case 8:   Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                case 16:  Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                case 32:  Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                default:  if (N_PowerOf2 % 128 != 0) {
+                              printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
+                              return cudaErrorUnknown;
+                          }
+                          Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+            }
         }
+    }
+
+    if (Split_K != 1) {
         // Reduction for SplitK
         dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1);
         dim3 BlockDim(WARP_SIZE, 1, 1);
         SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(C, Reduction_Workspace, M_Global, N_Global, Split_K);
     }
+
     return cudaGetLastError();
 }
 
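Note on the Split-K path above: when Split_K > 1, each split writes its partial M_Global x N_Global product in float to Reduction_Workspace, and SplitK_Reduction then sums the Split_K partials into the final half-precision C. That implies the workspace must hold M_Global * N_Global * Split_K floats. A hedged host-side sketch follows; the fpx_linear_kernel parameter order shown is an assumption for illustration, not quoted from this file.

// Hypothetical sketch, NOT part of this commit: sizing the Split-K workspace.
// Each of the Split_K partial products is a full M_Global x N_Global tile in
// float, so the workspace holds M_Global * N_Global * Split_K floats. The
// fpx_linear_kernel parameter order below is an assumption for illustration.
float* Reduction_Workspace = nullptr;
if (Split_K > 1) {
    cudaMalloc(&Reduction_Workspace,
               sizeof(float) * M_Global * N_Global * Split_K);
}
cudaError_t status = fpx_linear_kernel(stream, Weight, Scales, B, C,
                                       M_Global, N_Global, K_Global,
                                       Reduction_Workspace, Split_K);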
|
|