From 2f7cfcf6c0353b670235aea800369ed3754fc746 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Fri, 7 Nov 2025 18:00:45 +0800 Subject: [PATCH 1/6] mmf for rdna4 --- ggml/src/ggml-cuda/common.cuh | 12 +++ ggml/src/ggml-cuda/mma.cuh | 130 ++++++++++++++++++++++++++ ggml/src/ggml-cuda/mmf.cu | 4 +- ggml/src/ggml-cuda/mmf.cuh | 167 +++++++++++++++++++++++++++------- 4 files changed, 276 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ca876459d404d..fddf4ab5923b1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) { #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA) +#define AMD_WMMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) + // The Volta instructions are in principle available on Turing or newer but they are effectively unusable: #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #define VOLTA_MMA_AVAILABLE @@ -283,6 +287,14 @@ static bool amd_mfma_available(const int cc) { #endif //!defined(GGML_HIP_NO_MMQ_MFMA) } +static bool amd_wmma_available(const int cc) { +#if !defined(GGML_HIP_NO_WMMA) + return GGML_CUDA_CC_IS_RDNA4(cc); +#else + return false; +#endif //!defined(GGML_HIP_NO_WMMA) +} + static bool volta_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index a7a28fd1ae660..a481de01b4f3e 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -74,6 +74,31 @@ namespace ggml_cuda_mma { static constexpr int J = J_; #if defined(GGML_USE_HIP) +#if defined(RDNA4) + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 16) { + return 8 * (threadIdx.x / 16) + l; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } +#else static constexpr int ne = I * J / 64; T x[ne] = {0}; @@ -119,6 +144,7 @@ namespace ggml_cuda_mma { return -1; } } +#endif // defined(RDNA4) #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -236,6 +262,32 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + static constexpr int ne = I * J / 32; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 4 * (threadIdx.x / 16) + l; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } +#endif 
// defined(RDNA4) #else static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -285,6 +337,34 @@ namespace ggml_cuda_mma { struct tile { static constexpr int I = I_; static constexpr int J = J_; + +#if defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + static constexpr int ne = I * J / 32; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 4 * (threadIdx.x / 16) + l; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } +#endif // defined(RDNA4) +#else static constexpr int ne = I * J / WARP_SIZE; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; @@ -320,6 +400,7 @@ namespace ggml_cuda_mma { return -1; } } +#endif // defined(AMD_WMMA_AVAILABLE) }; template @@ -353,6 +434,19 @@ namespace ggml_cuda_mma { const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); xi[0] = xs[0]; } +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + // Special tile size to load <16, 8> as <16, 16> for half2 and __hip_bfloat162 + if constexpr (I == 16 && J == 8 && (std::is_same::value || std::is_same::value)) { + constexpr int RDNA4_WMMA_MEM_N = 4; + using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) int32_t; + reinterpret_cast(t.x[0]) = reinterpret_cast(xs0[t.get_i(0) * stride + t.get_j(0)]); + } else { + constexpr int RDNA4_WMMA_MEM_N = 8; + using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) T; + reinterpret_cast(t.x[0]) = reinterpret_cast(xs0[t.get_i(0) * stride + t.get_j(0)]); + } +#endif // defined(RDNA4) #else #pragma unroll for (int l = 0; l < t.ne; ++l) { @@ -639,12 +733,48 @@ namespace ggml_cuda_mma { : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#endif // defined(RDNA4) #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) { +#ifdef AMPERE_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), 
"r"(Bxi[1]), "r"(Bxi[3])); +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast(D.x[0]); + const bf16x8_t& a_frag = reinterpret_cast(A.x[0]); + const bf16x8_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); +#endif // defined(RDNA4) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE + } + static __device__ __forceinline__ void mma( tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { #if defined(AMD_MFMA_AVAILABLE) diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 69a60aceb82b7..2f9f7822b448d 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -154,9 +154,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const case GGML_TYPE_F32: return ampere_mma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc); + return ampere_mma_available(cc) || amd_wmma_available(cc); default: return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 45724e0911ec8..a5bf5eec3e29b 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -19,29 +19,15 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id); -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( +template +static __device__ __forceinline__ void mul_mat_f_impl( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - - if (!I_16_supported && !I_32_supported) { - NO_DEVICE_CODE; - return; - } - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. - - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; - +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; constexpr int ntA = rows_per_block / tile_A::I; @@ -161,11 +147,31 @@ static __global__ void mul_mat_f( if constexpr (!has_ids) { const float2 tmp = j < cols_per_block ? 
y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); +#if !defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; +#else + if constexpr (std::is_same::value) { + tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp); + } else if constexpr (std::is_same::value) { + tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp); + } else { + static_assert(0, "unsupported type"); + } +#endif // !defined(GGML_USE_HIP) } else { const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); +#if !defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; +#else + if constexpr (std::is_same::value) { + tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp); + } else if constexpr (std::is_same::value) { + tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp); + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#endif // !defined(GGML_USE_HIP) } } } else { @@ -239,35 +245,66 @@ static __global__ void mul_mat_f( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } -//This kernel is for larger batch sizes of mul_mat_id -template +template __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f_ids( - const T * __restrict__ x, const float * __restrict__ y, - const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, - const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const uint3 sis1_fd, const uint3 nch_fd) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + typedef tile<16, 8, T> tile_A; + typedef tile<16, 8, T> tile_B; + typedef tile<16, 16, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); + constexpr bool supported = I_16_supported || I_32_supported; - if (!I_16_supported && !I_32_supported) { - NO_DEVICE_CODE; - return; - } - - constexpr int I_preferred = I_16_supported ? 
16 : 32; // For Turing MMA both work butr 16 is ~1% faster. + constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. typedef tile tile_A; typedef tile<8, 8, T> tile_B; typedef tile tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (supported) { + mul_mat_f_impl ( + x, y, ids, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst + ); + } else { + NO_DEVICE_CODE; + return; + } +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +} +//This kernel is for larger batch sizes of mul_mat_id +template +static __device__ __forceinline__ void mul_mat_f_ids_impl( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; constexpr int ntA = rows_per_block / tile_A::I; @@ -408,7 +445,17 @@ static __global__ void mul_mat_f_ids( #pragma unroll for (int j0 = 0; j0 < tile_B::I; ++j0) { const float2 tmp = vals_buf[curr_buf][j0]; +#if !defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; +#else + if constexpr (std::is_same::value) { + tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp); + } else if constexpr (std::is_same::value) { + tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp); + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#endif // !defined(GGML_USE_HIP) } if (itB + 1 < ntB) { @@ -492,7 +539,57 @@ static __global__ void mul_mat_f_ids( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +} + +//This kernel is for larger batch sizes of mul_mat_id +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f_ids( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { +#if 
(!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + typedef tile<16, 8, T> tile_A; + typedef tile<16, 8, T> tile_B; + typedef tile<16, 16, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else + constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); + constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); + constexpr bool supported = I_16_supported || I_32_supported; + + constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. + + typedef tile tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (supported) { + mul_mat_f_ids_impl ( + x, y, + ids_src_compact, ids_dst_compact, + expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, + sis1_fd, nch_fd + ); + } else { + NO_DEVICE_CODE; + return; + } +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } template From d564a35214f2f34de58babd13d7b54f080ab5748 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Fri, 7 Nov 2025 19:27:36 +0800 Subject: [PATCH 2/6] align the padding for rdna4 --- ggml/src/ggml-cuda/mmf.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index a5bf5eec3e29b..25c4ba445f7f8 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -651,7 +651,11 @@ void mul_mat_f_cuda( cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A_16; typedef tile<32, 8, T> tile_A_32; +#if defined(AMD_WMMA_AVAILABLE) + typedef tile<16, 8, T> tile_B; +#else typedef tile< 8, 8, T> tile_B; +#endif // defined(AMD_WMMA_AVAILABLE) GGML_ASSERT(ncols_x % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); From bbee5feba17d77cf7c31448657b76f22eef4a714 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Sun, 9 Nov 2025 16:28:11 +0800 Subject: [PATCH 3/6] forbit mul_mat_f for rdna4 --- ggml/src/ggml-cuda/mmf.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index be2ad1c6b65f9..5c51a22256a48 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } else { - if (src1_ncols > 16) { + if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) { return false; } } From fd18344cf15c479a488da9ae951d5cb3c6db56b5 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Tue, 11 Nov 2025 16:27:27 +0800 Subject: [PATCH 4/6] fix as comment --- ggml/src/ggml-cuda/common.cuh | 2 +- ggml/src/ggml-cuda/convert.cuh | 4 ++++ ggml/src/ggml-cuda/mma.cuh | 42 +++++++++++++--------------------- ggml/src/ggml-cuda/mmf.cuh | 35 +++++++--------------------- 4 files changed, 29 insertions(+), 54 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index fddf4ab5923b1..962a35cd10dff 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -224,7 +224,7 @@ static const char * 
cu_get_error_str(CUresult err) { #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) -#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA) +#if defined(GGML_USE_HIP) && defined(RDNA4) #define AMD_WMMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 8a5e08ef667e0..1e30d5e04fe74 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -39,6 +39,10 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __float22bfloat162_rn(x); } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index a481de01b4f3e..c490a77b6f41e 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -87,7 +87,8 @@ namespace ggml_cuda_mma { if constexpr (I == 16 && J == 16) { return 8 * (threadIdx.x / 16) + l; } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; + return -1; } } @@ -95,7 +96,8 @@ namespace ggml_cuda_mma { if constexpr (I == 16 && J == 16) { return threadIdx.x % 16; } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; + return -1; } } #else @@ -263,7 +265,6 @@ namespace ggml_cuda_mma { } } #elif defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA4) static constexpr int ne = I * J / 32; half2 x[ne] = {{0.0f, 0.0f}}; @@ -276,7 +277,8 @@ namespace ggml_cuda_mma { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; + return -1; } } @@ -284,10 +286,10 @@ namespace ggml_cuda_mma { if constexpr (I == 16 && J == 8) { return 4 * (threadIdx.x / 16) + l; } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; + return -1; } } -#endif // defined(RDNA4) #else static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -339,7 +341,6 @@ namespace ggml_cuda_mma { static constexpr int J = J_; #if defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA4) static constexpr int ne = I * J / 32; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; @@ -352,7 +353,8 @@ namespace ggml_cuda_mma { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; + return -1; } } @@ -360,10 +362,10 @@ namespace ggml_cuda_mma { if constexpr (I == 16 && J == 8) { return 4 * (threadIdx.x / 16) + l; } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; + return -1; } } -#endif // defined(RDNA4) #else static constexpr int ne = I * J / WARP_SIZE; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; @@ -435,18 +437,10 @@ namespace ggml_cuda_mma { xi[0] = xs[0]; } #elif defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA4) - // Special tile size to load <16, 8> as <16, 16> for half2 and __hip_bfloat162 - if constexpr (I == 16 && J == 8 && (std::is_same::value || std::is_same::value)) { - constexpr int RDNA4_WMMA_MEM_N = 4; - using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) int32_t; - 
reinterpret_cast(t.x[0]) = reinterpret_cast(xs0[t.get_i(0) * stride + t.get_j(0)]); - } else { - constexpr int RDNA4_WMMA_MEM_N = 8; - using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) T; - reinterpret_cast(t.x[0]) = reinterpret_cast(xs0[t.get_i(0) * stride + t.get_j(0)]); - } -#endif // defined(RDNA4) + constexpr int nbytes = sizeof(t.x); + // Special case for RDNA3 fp16 and bf16 wmma, the size is 32 bytes. + constexpr int alignment = nbytes > ggml_cuda_get_max_cpy_bytes() ? ggml_cuda_get_max_cpy_bytes() : nbytes; + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); #else #pragma unroll for (int l = 0; l < t.ne; ++l) { @@ -734,14 +728,12 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #elif defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA4) using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; using floatx8_t = __attribute__((ext_vector_type(8))) float; floatx8_t& acc_frag = reinterpret_cast(D.x[0]); const halfx8_t& a_frag = reinterpret_cast(A.x[0]); const halfx8_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); -#endif // defined(RDNA4) #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -761,14 +753,12 @@ namespace ggml_cuda_mma { : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); #elif defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA4) using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; using floatx8_t = __attribute__((ext_vector_type(8))) float; floatx8_t& acc_frag = reinterpret_cast(D.x[0]); const bf16x8_t& a_frag = reinterpret_cast(A.x[0]); const bf16x8_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); -#endif // defined(RDNA4) #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 25c4ba445f7f8..1766470440639 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -2,6 +2,7 @@ #include "mma.cuh" #include "common.cuh" +#include "convert.cuh" using namespace ggml_cuda_mma; @@ -150,13 +151,7 @@ static __device__ __forceinline__ void mul_mat_f_impl( #if !defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; #else - if constexpr (std::is_same::value) { - tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp); - } else if constexpr (std::is_same::value) { - tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp); - } else { - static_assert(0, "unsupported type"); - } + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); #endif // !defined(GGML_USE_HIP) } else { const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; @@ -164,13 +159,7 @@ static __device__ __forceinline__ void mul_mat_f_impl( #if !defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; #else - if constexpr (std::is_same::value) { - tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp); - } else if constexpr (std::is_same::value) { - tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp); - } else { - static_assert(std::is_same_v, "unsupported type"); - } + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); #endif // !defined(GGML_USE_HIP) } } @@ -448,13 +437,7 @@ static __device__ __forceinline__ void mul_mat_f_ids_impl( #if 
!defined(GGML_USE_HIP) tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; #else - if constexpr (std::is_same::value) { - tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp); - } else if constexpr (std::is_same::value) { - tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp); - } else { - static_assert(std::is_same_v, "unsupported type"); - } + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); #endif // !defined(GGML_USE_HIP) } @@ -651,11 +634,8 @@ void mul_mat_f_cuda( cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A_16; typedef tile<32, 8, T> tile_A_32; -#if defined(AMD_WMMA_AVAILABLE) - typedef tile<16, 8, T> tile_B; -#else - typedef tile< 8, 8, T> tile_B; -#endif // defined(AMD_WMMA_AVAILABLE) + typedef tile<16, 8, T> tile_B_16; + typedef tile< 8, 8, T> tile_B_8; GGML_ASSERT(ncols_x % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); @@ -682,7 +662,8 @@ void mul_mat_f_cuda( constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; From 7a09e22e352f4b21405aa10b007e5707f6df23d9 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Tue, 11 Nov 2025 17:18:01 +0800 Subject: [PATCH 5/6] remove device kernels --- ggml/src/ggml-cuda/mmf.cuh | 131 +++++++++++++------------------------ 1 file changed, 46 insertions(+), 85 deletions(-) diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 1766470440639..0ee7a76cf5a58 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -20,15 +20,43 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id); -template -static __device__ __forceinline__ void mul_mat_f_impl( +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + // Special case for tf32, just dummy mma layout as wmma doesn't support it. + constexpr int tile_B_I = std::is_same_v ? 8 : 16; + constexpr int tile_C_J = std::is_same_v ? 
8 : 16; + typedef tile<16, 8, T> tile_A; + typedef tile tile_B; + typedef tile<16, tile_C_J, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else + constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); + constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); + constexpr bool supported = I_16_supported || I_32_supported; + + constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. + + typedef tile tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if (!supported) { + NO_DEVICE_CODE; + return; + } + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; constexpr int ntA = rows_per_block / tile_A::I; @@ -237,19 +265,25 @@ static __device__ __forceinline__ void mul_mat_f_impl( #endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } -template +//This kernel is for larger batch sizes of mul_mat_id +template __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, +static __global__ void mul_mat_f_ids( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) - typedef tile<16, 8, T> tile_A; - typedef tile<16, 8, T> tile_B; - typedef tile<16, 16, float> tile_C; + // Special case for tf32, just dummy mma layout as wmma doesn't support it. + constexpr int tile_B_I = std::is_same_v ? 8 : 16; + constexpr int tile_C_J = std::is_same_v ? 
8 : 16; + typedef tile<16, 8, T> tile_A; + typedef tile tile_B; + typedef tile<16, tile_C_J, float> tile_C; constexpr bool a_supported = tile_A::supported(); constexpr bool b_supported = tile_B::supported(); @@ -266,34 +300,11 @@ static __global__ void mul_mat_f( typedef tile<8, 8, T> tile_B; typedef tile tile_C; #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (supported) { - mul_mat_f_impl ( - x, y, ids, dst, - ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, - stride_col_id, stride_row_id, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst - ); - } else { + if (!supported) { NO_DEVICE_CODE; return; } -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) -} -//This kernel is for larger batch sizes of mul_mat_id -template -static __device__ __forceinline__ void mul_mat_f_ids_impl( - const T * __restrict__ x, const float * __restrict__ y, - const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, - const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, - const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const uint3 sis1_fd, const uint3 nch_fd) { -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; constexpr int ntA = rows_per_block / tile_A::I; @@ -525,56 +536,6 @@ static __device__ __forceinline__ void mul_mat_f_ids_impl( #endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } -//This kernel is for larger batch sizes of mul_mat_id -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f_ids( - const T * __restrict__ x, const float * __restrict__ y, - const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, - const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, - const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const uint3 sis1_fd, const uint3 nch_fd) { -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) -#if defined(AMD_WMMA_AVAILABLE) - typedef tile<16, 8, T> tile_A; - typedef tile<16, 8, T> tile_B; - typedef tile<16, 16, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; -#else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int 
I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. - - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; -#endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (supported) { - mul_mat_f_ids_impl ( - x, y, - ids_src_compact, ids_dst_compact, - expert_bounds, dst, - ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, - sis1_fd, nch_fd - ); - } else { - NO_DEVICE_CODE; - return; - } -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) -} - template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, From c65dd59e2672b41797161d051e7ebc84212db3b0 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Tue, 11 Nov 2025 17:30:13 +0800 Subject: [PATCH 6/6] add constexpr for early return --- ggml/src/ggml-cuda/mmf.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 0ee7a76cf5a58..6d663c376db1d 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -52,7 +52,7 @@ static __global__ void mul_mat_f( typedef tile<8, 8, T> tile_B; typedef tile tile_C; #endif // defined(AMD_WMMA_AVAILABLE) - if (!supported) { + if constexpr (!supported) { NO_DEVICE_CODE; return; } @@ -300,7 +300,7 @@ static __global__ void mul_mat_f_ids( typedef tile<8, 8, T> tile_B; typedef tile tile_C; #endif // defined(AMD_WMMA_AVAILABLE) - if (!supported) { + if constexpr (!supported) { NO_DEVICE_CODE; return; }