12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

#if defined(GGML_USE_HIP) && defined(RDNA4)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)

// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#define VOLTA_MMA_AVAILABLE
@@ -283,6 +287,14 @@ static bool amd_mfma_available(const int cc) {
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
}

static bool amd_wmma_available(const int cc) {
#if !defined(GGML_HIP_NO_WMMA)
return GGML_CUDA_CC_IS_RDNA4(cc);
#else
return false;
#endif //!defined(GGML_HIP_NO_WMMA)
}

static bool volta_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
}
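For context only (not part of the diff): a minimal sketch of how this pair of guards is typically consumed, with the compile-time AMD_WMMA_AVAILABLE macro gating device code and the runtime amd_wmma_available(cc) check gating host-side dispatch. The kernel and launcher names are hypothetical.

// Illustrative sketch, not from the PR.
__global__ void my_wmma_kernel() {
#if defined(AMD_WMMA_AVAILABLE)
    // RDNA4 WMMA code path, compiled only when the macro is defined.
#else
    NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
}

static void launch_for_device(const int cc, cudaStream_t stream) {
    if (amd_wmma_available(cc)) {
        // Pick the WMMA-enabled kernel variant at runtime.
        my_wmma_kernel<<<1, 32, 0, stream>>>();
    }
}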
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/convert.cuh
@@ -39,6 +39,10 @@ template<typename dst_t, typename src_t>
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
return __bfloat162float(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
return __float22bfloat162_rn(x);
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
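For orientation (not part of the diff): the two new float2 branches let callers convert a packed pair of floats into the tile element type in one place; mul_mat_f in mmf.cuh below uses exactly this form. A minimal sketch, assuming T is half2 or nv_bfloat162:

// Sketch only: with T = half2 this resolves to __float22half2_rn,
// with T = nv_bfloat162 to __float22bfloat162_rn.
template <typename T>
static __device__ __forceinline__ void store_packed(T * dst, const float2 v) {
    *dst = ggml_cuda_cast<T, float2>(v);
}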
120 changes: 120 additions & 0 deletions ggml/src/ggml-cuda/mma.cuh
@@ -74,6 +74,33 @@ namespace ggml_cuda_mma {
static constexpr int J = J_;

#if defined(GGML_USE_HIP)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#else
static constexpr int ne = I * J / 64;
T x[ne] = {0};

@@ -119,6 +146,7 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(RDNA4)
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@@ -236,6 +264,32 @@ namespace ggml_cuda_mma {
return -1;
}
}
#elif defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#else
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};
@@ -285,6 +339,34 @@ namespace ggml_cuda_mma {
struct tile<I_, J_, nv_bfloat162> {
static constexpr int I = I_;
static constexpr int J = J_;

#if defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#else
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

@@ -320,6 +402,7 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(AMD_WMMA_AVAILABLE)
};

template <int I, int J>
@@ -353,6 +436,11 @@ namespace ggml_cuda_mma {
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
constexpr int nbytes = sizeof(t.x);
// Special case for RDNA3 fp16 and bf16 WMMA, where the tile size is 32 bytes.
constexpr int alignment = nbytes > ggml_cuda_get_max_cpy_bytes() ? ggml_cuda_get_max_cpy_bytes() : nbytes;
ggml_cuda_memcpy_1<sizeof(t.x), alignment>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
@@ -639,12 +727,44 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#elif defined(AMD_WMMA_AVAILABLE)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
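A short worked example (not in the diff) of the RDNA4 accumulator layout defined by the 16x16 get_i/get_j pair above: with a 32-wide wave, each lane owns 8 consecutive rows of a single column.

// Sketch: enumerate the (row, col) elements owned by the calling lane of a
// tile<16, 16, float> under the RDNA4 mapping above (ne = 16*16/32 = 8).
//   row = get_i(l) = 8 * (threadIdx.x / 16) + l
//   col = get_j(l) = threadIdx.x % 16
// e.g. lane 0 holds rows 0..7 of column 0, lane 17 holds rows 8..15 of column 1.
static __device__ void dump_lane_mapping(const ggml_cuda_mma::tile<16, 16, float> & D) {
    for (int l = 0; l < D.ne; ++l) {
        const int row = D.get_i(l);
        const int col = D.get_j(l);
        printf("lane %d: D.x[%d] -> (%d, %d)\n", (int) threadIdx.x, l, row, col);
    }
}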
6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/mmf.cu
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
} else {
if (src1_ncols > 16) {
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
return false;
}
}
@@ -160,9 +160,9 @@
case GGML_TYPE_F32:
return ampere_mma_available(cc);
case GGML_TYPE_F16:
return volta_mma_available(cc) || turing_mma_available(cc);
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc);
return ampere_mma_available(cc) || amd_wmma_available(cc);
default:
return false;
}
77 changes: 60 additions & 17 deletions ggml/src/ggml-cuda/mmf.cuh
@@ -2,6 +2,7 @@

#include "mma.cuh"
#include "common.cuh"
#include "convert.cuh"

using namespace ggml_cuda_mma;

@@ -27,20 +28,34 @@ static __global__ void mul_mat_f(
const int stride_col_id, const int stride_row_id,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32: use a dummy mma layout, as WMMA does not support it.
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
typedef tile<16, 8, T> tile_A;
typedef tile<tile_B_I, 8, T> tile_B;
typedef tile<16, tile_C_J, float> tile_C;

constexpr bool a_supported = tile_A::supported();
constexpr bool b_supported = tile_B::supported();
constexpr bool c_supported = tile_C::supported();
constexpr bool supported = a_supported && b_supported && c_supported;
#else
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

if (!I_16_supported && !I_32_supported) {
NO_DEVICE_CODE;
return;
}
constexpr bool supported = I_16_supported || I_32_supported;

constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

typedef tile<I_preferred, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<I_preferred, 8, float> tile_C;
#endif // defined(AMD_WMMA_AVAILABLE)
if constexpr (!supported) {
NO_DEVICE_CODE;
return;
}

constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
@@ -161,11 +176,19 @@

if constexpr (!has_ids) {
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp);
#endif // !defined(GGML_USE_HIP)
} else {
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp);
#endif // !defined(GGML_USE_HIP)
}
}
} else {
@@ -239,7 +262,7 @@ static __global__ void mul_mat_f(
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
}

//This kernel is for larger batch sizes of mul_mat_id
@@ -253,20 +276,34 @@ static __global__ void mul_mat_f_ids(
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32: use a dummy mma layout, as WMMA does not support it.
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
typedef tile<16, 8, T> tile_A;
typedef tile<tile_B_I, 8, T> tile_B;
typedef tile<16, tile_C_J, float> tile_C;

constexpr bool a_supported = tile_A::supported();
constexpr bool b_supported = tile_B::supported();
constexpr bool c_supported = tile_C::supported();
constexpr bool supported = a_supported && b_supported && c_supported;
#else
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
constexpr bool supported = I_16_supported || I_32_supported;

if (!I_16_supported && !I_32_supported) {
NO_DEVICE_CODE;
return;
}

constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

typedef tile<I_preferred, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<I_preferred, 8, float> tile_C;
#endif // defined(AMD_WMMA_AVAILABLE)
if constexpr (!supported) {
NO_DEVICE_CODE;
return;
}

constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
@@ -408,7 +445,11 @@ static __global__ void mul_mat_f_ids(
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const float2 tmp = vals_buf[curr_buf][j0];
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp);
#endif // !defined(GGML_USE_HIP)
}

if (itB + 1 < ntB) {
@@ -492,7 +533,7 @@ static __global__ void mul_mat_f_ids(
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
}

template<typename T, int cols_per_block, int nwarps>
Expand Down Expand Up @@ -554,7 +595,8 @@ void mul_mat_f_cuda(
cudaStream_t stream, const mmf_ids_data * ids_data) {
typedef tile<16, 8, T> tile_A_16;
typedef tile<32, 8, T> tile_A_32;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, T> tile_B_16;
typedef tile< 8, 8, T> tile_B_8;

GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
@@ -581,7 +623,8 @@

constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
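A quick arithmetic sketch of the padding change above; the concrete values below are assumptions chosen for illustration only, not taken from the PR.

// Illustrative only: compare the combine-buffer size with 8- vs 16-column padding.
// Assumed values (not from the PR): cols_per_block = 8, nwarps_best = 4, rows_per_block = 32.
constexpr int cols_per_block = 8;
constexpr int nwarps_best    = 4;
constexpr int rows_per_block = 32;
constexpr int pad_wmma  = 16; // tile_B_16::I, used when amd_wmma_available(cc)
constexpr int pad_other = 8;  // tile_B_8::I, used otherwise
constexpr int bytes_wmma  = GGML_PAD(cols_per_block, pad_wmma)  * (nwarps_best*rows_per_block + 4) * 4; // 16 * 132 * 4 = 8448
constexpr int bytes_other = GGML_PAD(cols_per_block, pad_other) * (nwarps_best*rows_per_block + 4) * 4; //  8 * 132 * 4 = 4224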