
Commit 96e8fee

SM75 (Turing) support for FP6 kernel (#942)
* SM75 support for FP6 kernel
* More consistent argument ordering in benchmark function
* Add a note about SM75 support in the floatx README
* Handle FP6 + SM75 + N>=64 edge case
* Document changes made for FP6 SM75 support
1 parent 39ce823 commit 96e8fee

6 files changed: +126, -28 lines

benchmarks/benchmark_fp6.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def benchmark(m: int, k: int, n: int):
 
 for m in tqdm([1 << i for i in range(10)]):
     for n, k in zip(n_vals, k_vals):
-        results.append(benchmark(m, n, k))
+        results.append(benchmark(m, k, n))
 
 df = pd.DataFrame(results)
 df.to_csv("fp6_llm_benchmark_results.csv", index=False)

torchao/csrc/cuda/fp6_llm/fp6_linear.cu

Lines changed: 59 additions & 26 deletions
@@ -13,15 +13,35 @@
 // limitations under the License.
 //
 // This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
+//
 
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing
 
 #include "kernel_matmul.cuh"
 #include "kernel_reduction.cuh"
 
 #include <stdio.h>
 #include <assert.h>
 
+inline bool isSM75GPU() {
+    int device;
+    cudaError_t err = cudaGetDevice(&device);
+    if (err != cudaSuccess) {
+        return false;
+    }
+
+    cudaDeviceProp props;
+    err = cudaGetDeviceProperties(&props, device);
+    if (err != cudaSuccess) {
+        return false;
+    }
+
+    return (props.major == 7) && (props.minor == 5);
+}
+
 template<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>
 static void Kernel_Ex(cudaStream_t stream,
                       const uint4 *Weight,
@@ -80,38 +100,51 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream,
     if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
     if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;
 
-    if (Split_K == 1) {
-        switch (N_PowerOf2) {
-            case 8:   Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            case 16:  Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            case 32:  Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
-            default:  if (N_PowerOf2 % 128 != 0) {
-                          printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
-                          return cudaErrorUnknown;
-                      }
-                      Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+    if (isSM75GPU() && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
+        // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
+        if (Split_K == 1) {
+            Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
+        } else {
+            Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);
         }
-    }
-    else {
-        switch (N_PowerOf2) {
-            case 8:   Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            case 16:  Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            case 32:  Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
-            default:  if (N_PowerOf2 % 128 != 0) {
-                          printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
-                          return cudaErrorUnknown;
-                      }
-                      Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+    } else {
+        if (Split_K == 1) {
+            switch (N_PowerOf2) {
+                case 8:   Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                case 16:  Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                case 32:  Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+                default:  if (N_PowerOf2 % 128 != 0) {
+                              printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
+                              return cudaErrorUnknown;
+                          }
+                          Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
+            }
+        }
+        else {
+            switch (N_PowerOf2) {
+                case 8:   Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                case 16:  Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                case 32:  Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+                default:  if (N_PowerOf2 % 128 != 0) {
+                              printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
+                              return cudaErrorUnknown;
+                          }
+                          Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
+            }
         }
+    }
+
+    if (Split_K != 1) {
        // Reduction for SplitK
        dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1);
        dim3 BlockDim(WARP_SIZE, 1, 1);
        SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(C, Reduction_Workspace, M_Global, N_Global, Split_K);
     }
+
     return cudaGetLastError();
 }
 
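
Note (illustrative, not part of this commit): the new branch above is only taken when the padded N dimension reaches 64. A minimal host-side sketch of the rounding and branch selection follows; `round_N` is a hypothetical helper, and the rounding steps for N <= 64 are an assumption inferred from the switch cases, since that part of the function is outside the hunk shown.

// Hedged sketch: reproduces the N_PowerOf2 rounding and the new SM75 branch condition.
// The rounding for N <= 64 is assumed from the switch cases; only the last two steps appear in the diff.
#include <cstdio>

static int round_N(int N_Global) {
    if (N_Global <= 8)   return 8;
    if (N_Global <= 16)  return 16;
    if (N_Global <= 32)  return 32;
    if (N_Global <= 64)  return 64;
    if (N_Global <= 128) return 128;
    return ((N_Global - 1) / 128 + 1) * 128;
}

int main() {
    const int tests[] = {1, 8, 33, 64, 100, 129, 512};
    for (int N : tests) {
        int N_PowerOf2 = round_N(N);
        // On SM75, these N values use TilingConfig<4, 1, 4> instead of TilingConfig<4, 1, 8>.
        bool sm75_special = (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0);
        printf("N=%3d -> N_PowerOf2=%3d, SM75-specific tiling: %s\n",
               N, N_PowerOf2, sm75_special ? "yes" : "no");
    }
    return 0;
}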

torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh

Lines changed: 10 additions & 0 deletions
@@ -13,6 +13,10 @@
 // limitations under the License.
 //
 // This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Added __CUDA_ARCH__ guards such that async operations are only executed for SM80 and up
+//
 
 #include "configs.h"
 #include "utils_gmem.cuh"
@@ -140,7 +144,9 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
         for(int j=0; j<REG_PER_THREAD_C_TENSOR_16_16; j++)
             c[i][j] = 0.0f;
     //
+    #if __CUDA_ARCH__ >= 800
     cp_async_wait_all();
+    #endif
     __syncthreads();
 
     /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -175,12 +181,16 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
         if(USE_SEG_4BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy);
         // copying B tile from GlobalMemory to SharedMemory
         CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS> (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy);
+        #if __CUDA_ARCH__ >= 800
         cp_async_group_commit();
+        #endif
         core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs
         core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2);
         core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3);
         // Barriers and Synchronizations
+        #if __CUDA_ARCH__ >= 800
         cp_async_wait_group<PIPELINE_LEVEL_GMEM-2>();
+        #endif
         __syncthreads();
         core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0);
         // Updating global PTRs
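
Note (illustrative, not part of this commit): the guards above skip the cp.async synchronization intrinsics on architectures without the async copy unit; the __syncthreads() barriers stay in place everywhere. A self-contained sketch of the same pattern, using hypothetical names (stage_16B, stage_tile) rather than the kernel's actual helpers, and assuming a 128-thread launch:

// Hedged sketch of the guard pattern: stage a tile through shared memory, using cp.async plus
// commit/wait on SM80+, and an ordinary synchronous copy (no extra waits) on SM75.
#include <cuda_fp16.h>
#include <cstdint>

__device__ __forceinline__ void stage_16B(void* smem_dst, const void* gmem_src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" :: "r"(dst), "l"(gmem_src));
#else
    // SM75: plain vectorized copy through registers; completes before the next statement.
    *reinterpret_cast<float4*>(smem_dst) = *reinterpret_cast<const float4*>(gmem_src);
#endif
}

__global__ void stage_tile(const half* __restrict__ in, half* __restrict__ out) {
    // Assumes blockDim.x == 128 and 16-byte-aligned 'in'; 8 half values (16 bytes) per thread.
    __shared__ __align__(16) half tile[128 * 8];
    stage_16B(&tile[threadIdx.x * 8], &in[threadIdx.x * 8]);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.commit_group;\n" ::);    // only meaningful where cp.async exists
    asm volatile("cp.async.wait_group 0;\n" ::);    // wait for the staged copies to land
#endif
    __syncthreads();                                // required on all architectures
    out[threadIdx.x] = tile[threadIdx.x * 8];
}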

torchao/csrc/cuda/fp6_llm/ptx_mma.cuh

Lines changed: 35 additions & 0 deletions
@@ -13,6 +13,10 @@
 // limitations under the License.
 //
 // This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Replaced m16n8k16 Tensor core operation with two m16n8k8 operations
+// - Accounted for a difference in expected parameters for the ldmatrix operation
 
 /***************************************************************************
  * Copyright 2023 The FLash-LLM Authors. All rights reserved.
@@ -55,6 +59,14 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[
     assert( warp_start_col==0 );
 #endif
 
+#if __CUDA_ARCH__ == 750
+    if (TilingConfig::WARP_COL_MMA_TENSORS==1) {
+        // For .target sm_75, all threads must contain valid addresses for the 'ldmatrix' op. below. Otherwise, the behavior is undefined.
+        // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix
+        // To avoid this, we make threads 16-32 point to the same smem addresses as threads 0-15 by changing the lane id.
+        lane_id = lane_id % 16;
+    }
+#endif
     int col = (lane_id%8) + (lane_id/16)*8;
     int row = (lane_id%16) / 8 * 8;
     uint32_t smem_local_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row]));
@@ -80,6 +92,28 @@ __device__ __forceinline__ void
 __device__ __forceinline__ void
 MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
 {
+#if __CUDA_ARCH__ == 750
+    // m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops.
+    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+                 "{ %0, %1, %2, %3},"
+                 "{ %4, %5},"
+                 "{ %6 },"
+                 "{ %7, %8, %9, %10 };"
+                 : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
+                 : "r"(a[0]), "r"(a[1]),
+                   "r"(b[0]),
+                   "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
+    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+                 "{ %0, %1, %2, %3},"
+                 "{ %4, %5},"
+                 "{ %6 },"
+                 "{ %7, %8, %9, %10 };"
+                 : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
+                 : "r"(a[2]), "r"(a[3]),
+                   "r"(b[1]),
+                   "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
+
+#else
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
                  "{ %0, %1, %2, %3},"
                  "{ %4, %5, %6, %7 },"
@@ -89,6 +123,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
                  : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
                    "r"(b[0]), "r"(b[1]),
                    "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
+#endif
 }
 
 #endif
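
Note (illustrative, not part of this commit): the SM75 path is valid because a K=16 matrix-multiply-accumulate can be split along K into two accumulating K=8 steps; a[0..1] with b[0] cover k = 0..7 and a[2..3] with b[1] cover k = 8..15, accumulating into the same c registers. A plain C++ reference check of that decomposition (ordinary float arithmetic, no Tensor Cores, hypothetical mma_ref helper):

// Hedged sketch: verifies that one 16x8 output tile computed with a single K=16 step equals the
// same tile computed as two accumulating K=8 steps.
#include <cstdio>

// C (16x8) += A (16x16, columns [k0, k0+k_tile)) * B (rows [k0, k0+k_tile), 8 columns)
static void mma_ref(float C[16][8], const float A[16][16], const float B[16][8], int k0, int k_tile) {
    for (int i = 0; i < 16; ++i)
        for (int j = 0; j < 8; ++j)
            for (int k = k0; k < k0 + k_tile; ++k)
                C[i][j] += A[i][k] * B[k][j];
}

int main() {
    float A[16][16], B[16][8], C_k16[16][8] = {}, C_2xk8[16][8] = {};
    for (int i = 0; i < 16; ++i)
        for (int k = 0; k < 16; ++k) A[i][k] = float((i * 7 + k) % 5);
    for (int k = 0; k < 16; ++k)
        for (int j = 0; j < 8; ++j) B[k][j] = float((k * 3 + j) % 4);

    mma_ref(C_k16, A, B, 0, 16);   // one m16n8k16-style step
    mma_ref(C_2xk8, A, B, 0, 8);   // first m16n8k8-style step
    mma_ref(C_2xk8, A, B, 8, 8);   // second m16n8k8-style step, accumulating into the same C

    bool same = true;
    for (int i = 0; i < 16; ++i)
        for (int j = 0; j < 8; ++j) same = same && (C_k16[i][j] == C_2xk8[i][j]);
    printf("K=16 step vs two K=8 steps: %s\n", same ? "identical" : "different");
    return 0;
}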

torchao/csrc/cuda/fp6_llm/utils_gmem.cuh

Lines changed: 20 additions & 1 deletion
@@ -13,6 +13,10 @@
 // limitations under the License.
 //
 // This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Replaced asynchronous copy operations with vectorized loads
+//
 
 #ifndef UTILS_GMEM_CUH
 #define UTILS_GMEM_CUH
@@ -39,7 +43,15 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR,
     GPTR_HALF += lane_id*8;
     #pragma unroll
     for(int i=0; i<SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE/16; i++) {
+        #if __CUDA_ARCH__ == 750
+        if (pred_guard) {
+            float4* SPTR_VEC = reinterpret_cast<float4*>(SPTR_HALF);
+            const float4* GPTR_VEC = reinterpret_cast<const float4*>(GPTR_HALF);
+            SPTR_VEC[0] = GPTR_VEC[0];
+        }
+        #else
         cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard);
+        #endif
         SPTR_HALF += 256; // Forward 512 Bytes
         GPTR_HALF += 256; // Forward 512 Bytes
     }
@@ -82,8 +94,15 @@ __device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ Shar
     #pragma unroll
     for (int i = 0; i < MaxIteration; i++) {
         bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred;
+        #if __CUDA_ARCH__ == 750
+        if (AsyncCopyPred) {
+            float4* SharedPtrVec = reinterpret_cast<float4*>(&(*SharedPTR)[line_offset]);
+            const float4* GlobalPtrVec = reinterpret_cast<const float4*>(GlobalPTR);
+            SharedPtrVec[0] = GlobalPtrVec[0];
+        }
+        #else
         cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
-        //
+        #endif
         GlobalPTR += NumOfGroups * GlobalStride;
         SharedPTR += NumOfGroups;
     }
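
Note (illustrative, not part of this commit): both copy paths move 16 bytes per thread per call and rely on 16-byte-aligned pointers; the SM75 path simply routes the copy through registers instead of the async copy unit. A standalone sketch of such a predicated 16-byte copy (hypothetical copy_16B helper, not the repo's cp_async<16>):

// Hedged sketch: a predicated 16-byte global->shared copy. On SM80+ it issues a guarded cp.async
// (roughly the idea behind cp_async<16>); on SM75 it falls back to a guarded float4 copy, matching
// the replacement made in this file. Both variants require 16-byte-aligned pointers.
#include <cstdint>

__device__ __forceinline__ void copy_16B(void* smem_dst, const void* gmem_src, bool pred) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
    asm volatile("{\n"
                 "  .reg .pred p;\n"
                 "  setp.ne.b32 p, %2, 0;\n"
                 "  @p cp.async.cg.shared.global [%0], [%1], 16;\n"
                 "}\n" :: "r"(dst), "l"(gmem_src), "r"((int)pred));
#else
    // SM75 has no cp.async: do an ordinary vectorized load/store through registers instead.
    if (pred) {
        *reinterpret_cast<float4*>(smem_dst) = *reinterpret_cast<const float4*>(gmem_src);
    }
#endif
}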

torchao/dtypes/floatx/README.md

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape
 - Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations.
 - Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1.
 - On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 for some microbenchmark results.
+- FP6 is supported for >=SM80 (Ampere generation) as well as SM75 (Turing generation) GPUs. However, SM75 support requires manual compilation of the C++/CUDA extensions (see the installation instructions in the [README](https://github.com/pytorch/ao/blob/main/README.md#installation) for details).
 
 ## End-to-End benchmarks
 