 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
+#define VK_VENDOR_ID_APPLE 0x106b
 #define VK_VENDOR_ID_INTEL 0x8086
 #define VK_VENDOR_ID_NVIDIA 0x10de
 
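Note (not part of the patch): these macros are PCI vendor IDs, matched against the vendorID reported by the Vulkan physical device. A minimal sketch of how that value is read with the Vulkan-Hpp bindings this file already uses; the helper name here is hypothetical:

    #include <vulkan/vulkan.hpp>

    // Returns the PCI vendor ID of a physical device, e.g. 0x1002 (AMD),
    // 0x106b (Apple), 0x8086 (Intel), 0x10de (NVIDIA).
    static uint32_t get_device_vendor_id(vk::PhysicalDevice dev) {
        vk::PhysicalDeviceProperties props = dev.getProperties();
        return props.vendorID;
    }
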
@@ -2034,18 +2035,100 @@ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ct
     return ctx->pipeline_matmul_f32_aligned_l.align;
 }
 
+static vk_pipeline* ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
+    if (bit16_x && bit16_y) {
+        if (m <= 32 || n <= 32) {
+#ifdef GGML_VULKAN_DEBUG
+            std::cerr << " S" << std::endl;
+#endif
+            return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
+        }
+#ifdef GGML_VULKAN_DEBUG
+        std::cerr << " M" << std::endl;
+#endif
+        return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
+    }
+    if (bit16_x && !bit16_y) {
+        if (m <= 32 || n <= 32) {
+#ifdef GGML_VULKAN_DEBUG
+            std::cerr << " S" << std::endl;
+#endif
+            return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
+        }
+#ifdef GGML_VULKAN_DEBUG
+        std::cerr << " M" << std::endl;
+#endif
+        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
+    }
+    if (!bit16_x && bit16_y) {
+        GGML_ASSERT(false);
+    }
+
+    if (m <= 32 || n <= 32) {
+#ifdef GGML_VULKAN_DEBUG
+        std::cerr << " S" << std::endl;
+#endif
+        return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
+    }
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << " M" << std::endl;
+#endif
+    return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
+}
+
+static vk_pipeline* ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << " M" << std::endl;
+#endif
+    if (bit16_x && bit16_y) {
+        return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
+    }
+    if (bit16_x && !bit16_y) {
+        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
+    }
+    if (!bit16_x && bit16_y) {
+        GGML_ASSERT(false);
+    }
+    return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
+}
+
+static vk_pipeline* ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << " S" << std::endl;
+#endif
+    if (bit16_x && bit16_y) {
+        return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
+    }
+    if (bit16_x && !bit16_y) {
+        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
+    }
+    if (!bit16_x && bit16_y) {
+        GGML_ASSERT(false);
+    }
+    return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
+}
+
 static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
 #endif
+    switch (ctx->device.lock()->vendor_id) {
+    case VK_VENDOR_ID_AMD:
+        return ggml_vk_guess_matmul_pipeline_amd(ctx, bit16_x, bit16_y, m, n, aligned);
+    case VK_VENDOR_ID_APPLE:
+        return ggml_vk_guess_matmul_pipeline_apple(ctx, bit16_x, bit16_y, aligned);
+    case VK_VENDOR_ID_INTEL:
+        return ggml_vk_guess_matmul_pipeline_intel(ctx, bit16_x, bit16_y, aligned);
+    }
+
     if (bit16_x && bit16_y) {
-        if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) {
+        if (m <= 32 || n <= 32) {
 #ifdef GGML_VULKAN_DEBUG
             std::cerr << " S" << std::endl;
 #endif
             return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
         }
-        if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
+        if (m <= 64 || n <= 64) {
 #ifdef GGML_VULKAN_DEBUG
             std::cerr << " M" << std::endl;
 #endif
@@ -2057,13 +2140,13 @@ static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
         return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l;
     }
     if (bit16_x && !bit16_y) {
-        if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) {
+        if (m <= 32 || n <= 32) {
 #ifdef GGML_VULKAN_DEBUG
             std::cerr << " S" << std::endl;
 #endif
             return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
         }
-        if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
+        if (m <= 64 || n <= 64) {
 #ifdef GGML_VULKAN_DEBUG
             std::cerr << " M" << std::endl;
 #endif
@@ -2078,13 +2161,13 @@ static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
         GGML_ASSERT(false);
     }
 
-    if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) {
+    if (m <= 32 || n <= 32) {
 #ifdef GGML_VULKAN_DEBUG
         std::cerr << " S" << std::endl;
 #endif
         return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
     }
-    if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
+    if (m <= 64 || n <= 64) {
 #ifdef GGML_VULKAN_DEBUG
         std::cerr << " M" << std::endl;
 #endif
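Review note, not part of the diff: the new per-vendor helpers reduce the choice to a tile size (small/medium/large pipeline variants), while the generic fallback now keys purely on m and n. An illustrative recap of that decision tree, with simplified, hypothetical names:

    enum class tile_size { S, M, L };

    // Simplified restatement of the selection in this patch: Apple always takes
    // the medium pipelines, Intel the small ones, AMD small or medium depending
    // on m/n, and everything else falls through to the m/n thresholds.
    static tile_size pick_tile(uint32_t vendor_id, int m, int n) {
        switch (vendor_id) {
            case 0x106b: return tile_size::M;                                        // Apple
            case 0x8086: return tile_size::S;                                        // Intel
            case 0x1002: return (m <= 32 || n <= 32) ? tile_size::S : tile_size::M;  // AMD
            default:     return (m <= 32 || n <= 32) ? tile_size::S
                              : (m <= 64 || n <= 64) ? tile_size::M : tile_size::L;
        }
    }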