diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ca876459d404d..962a35cd10dff 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) { #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +#if defined(GGML_USE_HIP) && defined(RDNA4) +#define AMD_WMMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) + // The Volta instructions are in principle available on Turing or newer but they are effectively unusable: #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #define VOLTA_MMA_AVAILABLE @@ -283,6 +287,14 @@ static bool amd_mfma_available(const int cc) { #endif //!defined(GGML_HIP_NO_MMQ_MFMA) } +static bool amd_wmma_available(const int cc) { +#if !defined(GGML_HIP_NO_WMMA) + return GGML_CUDA_CC_IS_RDNA4(cc); +#else + return false; +#endif //!defined(GGML_HIP_NO_WMMA) +} + static bool volta_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA; } diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 8a5e08ef667e0..1e30d5e04fe74 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -39,6 +39,10 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __float22bfloat162_rn(x); } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index a7a28fd1ae660..c490a77b6f41e 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -74,6 +74,33 @@ namespace ggml_cuda_mma { static constexpr int J = J_; #if defined(GGML_USE_HIP) +#if defined(RDNA4) + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 16) { + return 8 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#else static constexpr int ne = I * J / 64; T x[ne] = {0}; @@ -119,6 +146,7 @@ namespace ggml_cuda_mma { return -1; } } +#endif // defined(RDNA4) #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -236,6 +264,32 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 4 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } #else static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -285,6 +339,34 @@ namespace ggml_cuda_mma { struct tile { static constexpr int I = I_; static constexpr int J = J_; + +#if defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 4 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#else static constexpr int ne = I * J / WARP_SIZE; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; @@ -320,6 +402,7 @@ namespace ggml_cuda_mma { return -1; } } +#endif // defined(AMD_WMMA_AVAILABLE) }; template @@ -353,6 +436,11 @@ namespace ggml_cuda_mma { const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); xi[0] = xs[0]; } +#elif defined(AMD_WMMA_AVAILABLE) + constexpr int nbytes = sizeof(t.x); + // Special case for RDNA3 fp16 and bf16 wmma, the size is 32 bytes. + constexpr int alignment = nbytes > ggml_cuda_get_max_cpy_bytes() ? ggml_cuda_get_max_cpy_bytes() : nbytes; + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); #else #pragma unroll for (int l = 0; l < t.ne; ++l) { @@ -639,12 +727,44 @@ namespace ggml_cuda_mma { : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) { +#ifdef AMPERE_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#elif defined(AMD_WMMA_AVAILABLE) + using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast(D.x[0]); + const bf16x8_t& a_frag = reinterpret_cast(A.x[0]); + const bf16x8_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE + } + static __device__ __forceinline__ void mma( tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { #if defined(AMD_MFMA_AVAILABLE) diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 153dd5a97d5a7..5c51a22256a48 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } else { - if (src1_ncols > 16) { + if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) { return false; } } @@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const case GGML_TYPE_F32: return ampere_mma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc); + return ampere_mma_available(cc) || amd_wmma_available(cc); default: return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 45724e0911ec8..6d663c376db1d 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -2,6 +2,7 @@ #include "mma.cuh" #include "common.cuh" +#include "convert.cuh" using namespace ggml_cuda_mma; @@ -27,20 +28,34 @@ static __global__ void mul_mat_f( const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + // Special case for tf32, just dummy mma layout as wmma doesn't support it. + constexpr int tile_B_I = std::is_same_v ? 8 : 16; + constexpr int tile_C_J = std::is_same_v ? 8 : 16; + typedef tile<16, 8, T> tile_A; + typedef tile tile_B; + typedef tile<16, tile_C_J, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - - if (!I_16_supported && !I_32_supported) { - NO_DEVICE_CODE; - return; - } + constexpr bool supported = I_16_supported || I_32_supported; constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. typedef tile tile_A; typedef tile<8, 8, T> tile_B; typedef tile tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (!supported) { + NO_DEVICE_CODE; + return; + } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -161,11 +176,19 @@ static __global__ void mul_mat_f( if constexpr (!has_ids) { const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); +#if !defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; +#else + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); +#endif // !defined(GGML_USE_HIP) } else { const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); +#if !defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; +#else + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); +#endif // !defined(GGML_USE_HIP) } } } else { @@ -239,7 +262,7 @@ static __global__ void mul_mat_f( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } //This kernel is for larger batch sizes of mul_mat_id @@ -253,20 +276,34 @@ static __global__ void mul_mat_f_ids( const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + // Special case for tf32, just dummy mma layout as wmma doesn't support it. + constexpr int tile_B_I = std::is_same_v ? 8 : 16; + constexpr int tile_C_J = std::is_same_v ? 8 : 16; + typedef tile<16, 8, T> tile_A; + typedef tile tile_B; + typedef tile<16, tile_C_J, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); + constexpr bool supported = I_16_supported || I_32_supported; - if (!I_16_supported && !I_32_supported) { - NO_DEVICE_CODE; - return; - } - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster. + constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. typedef tile tile_A; typedef tile<8, 8, T> tile_B; typedef tile tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (!supported) { + NO_DEVICE_CODE; + return; + } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -408,7 +445,11 @@ static __global__ void mul_mat_f_ids( #pragma unroll for (int j0 = 0; j0 < tile_B::I; ++j0) { const float2 tmp = vals_buf[curr_buf][j0]; +#if !defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; +#else + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); +#endif // !defined(GGML_USE_HIP) } if (itB + 1 < ntB) { @@ -492,7 +533,7 @@ static __global__ void mul_mat_f_ids( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } template @@ -554,7 +595,8 @@ void mul_mat_f_cuda( cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A_16; typedef tile<32, 8, T> tile_A_32; - typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, T> tile_B_16; + typedef tile< 8, 8, T> tile_B_8; GGML_ASSERT(ncols_x % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); @@ -581,7 +623,8 @@ void mul_mat_f_cuda( constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;