diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cu b/ggml/src/ggml-cuda/conv2d-transpose.cu
index 03224e404d32d..c0fc34f0c5ab1 100644
--- a/ggml/src/ggml-cuda/conv2d-transpose.cu
+++ b/ggml/src/ggml-cuda/conv2d-transpose.cu
@@ -1,12 +1,36 @@
-#include <algorithm>
-
 #include "conv2d-transpose.cuh"
-#include "ggml.h"
+#include "convert.cuh"
+
+#include <algorithm>
 
-__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
-                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
-                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
-                                        const int c_in, const int c_out, const int batches) {
+struct conv2d_transpose_params {
+    const int in_w;
+    const int in_h;
+    const int out_w;
+    const int out_h;
+    const int kernel_w;
+    const int kernel_h;
+    const int stride;
+    const int c_in;
+    const int c_out;
+    const int batches;
+    const int total;
+};
+
+template <typename T>
+static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
+                                               const T * __restrict__ kernel,
+                                               float * __restrict__ output,
+                                               const int in_w,
+                                               const int in_h,
+                                               const int out_w,
+                                               const int out_h,
+                                               const int kernel_w,
+                                               const int kernel_h,
+                                               const int stride,
+                                               const int c_in,
+                                               const int c_out,
+                                               const int batches) {
     const int global_idx     = blockIdx.x * blockDim.x + threadIdx.x;
     const int total_elements = out_w * out_h * c_out * batches;
 
@@ -26,24 +50,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const
     for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
         for (int kh = 0; kh < kernel_h; ++kh) {
             int in_y = out_y_idx - kh;
-            if (in_y < 0 || in_y % stride) continue;
+            if (in_y < 0 || in_y % stride) {
+                continue;
+            }
             in_y /= stride;
-            if (in_y >= in_h) continue;
+            if (in_y >= in_h) {
+                continue;
+            }
             for (int kw = 0; kw < kernel_w; ++kw) {
                 int in_x = out_x_idx - kw;
-                if (in_x < 0 || in_x % stride) continue;
+                if (in_x < 0 || in_x % stride) {
+                    continue;
+                }
                 in_x /= stride;
-                if (in_x >= in_w) continue;
+                if (in_x >= in_w) {
+                    continue;
+                }
 
                 const int input_idx =
                     (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
                 const int kernel_idx =
                     (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
 
                 float input_val = input[input_idx];
-                half  kern_val  = kernel[kernel_idx];
+                T     kern_val  = kernel[kernel_idx];
 
-                accumulator += input_val * (float) kern_val;
+                accumulator += input_val * ggml_cuda_cast<float>(kern_val);
             }
         }
     }
@@ -51,16 +83,44 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const
     output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
 }
 
+template <typename T>
+static void conv2d_transpose_cuda(const float *                   input,
+                                  const T *                       kernel,
+                                  float *                         output,
+                                  const conv2d_transpose_params & params,
+                                  cudaStream_t                    st) {
+    const int blocks = (params.total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
+    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+        input, kernel, output, params.in_w, params.in_h, params.out_w, params.out_h, params.kernel_w, params.kernel_h,
+        params.stride, params.c_in, params.c_out, params.batches);
+}
+
+static void conv2d_transpose_cuda_f16(const float *                   input,
+                                      const half *                    kernel,
+                                      float *                         output,
+                                      const conv2d_transpose_params & params,
+                                      cudaStream_t                    st) {
+    conv2d_transpose_cuda<half>(input, kernel, output, params, st);
+}
+
+static void conv2d_transpose_cuda_f32(const float *                   input,
+                                      const float *                   kernel,
+                                      float *                         output,
+                                      const conv2d_transpose_params & params,
+                                      cudaStream_t                    st) {
+    conv2d_transpose_cuda<float>(input, kernel, output, params, st);
+}
 //input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
 void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * kernel = dst->src[0];
     const ggml_tensor * input  = dst->src[1];
 
-    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
 
     const float * input_data  = (const float *) input->data;
     float *       output_data = (float *) dst->data;
-    const half *  kernel_data = (const half *) kernel->data;
+    const void *  kernel_data = kernel->data;
 
     const int input_w = input->ne[0];
     const int input_h = input->ne[1];
@@ -83,9 +143,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     const int total  = (output_w * output_h * channels_out * batches);
-    const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
+    conv2d_transpose_params params = { input_w, input_h,     output_w,     output_h, kernel_w, kernel_h,
+                                       stride,  channels_in, channels_out, batches,  total };
 
-    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
-        input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
-        channels_in, channels_out, batches);
+    if (kernel->type == GGML_TYPE_F16) {
+        conv2d_transpose_cuda_f16(input_data, (const half *) kernel_data, output_data, params, st);
+    } else {
+        conv2d_transpose_cuda_f32(input_data, (const float *) kernel_data, output_data, params, st);
+    }
 }
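The change above follows the usual ggml-cuda idiom: template the kernel on the weight type, convert each weight element to f32 at the point of use, and pick the instantiation on the host from the runtime tensor type. A minimal standalone sketch of that idiom, with illustrative names only (axpy_kernel/axpy_cuda are not ggml APIs):

#include <cuda_fp16.h>

// Toy kernel templated on the weight type; accumulation stays in f32.
template <typename T>
static __global__ void axpy_kernel(const float * x, const T * w, float * y, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] += x[i] * (float) w[i]; // per-element convert, as ggml_cuda_cast<float> does above
    }
}

// Host-side dispatch on the runtime weight type, mirroring ggml_cuda_conv_2d_transpose_p0.
static void axpy_cuda(const float * x, const void * w, const bool w_is_f16, float * y, const int n, cudaStream_t st) {
    const int blocks = (n + 255) / 256;
    if (w_is_f16) {
        axpy_kernel<half> <<<blocks, 256, 0, st>>>(x, (const half  *) w, y, n);
    } else {
        axpy_kernel<float><<<blocks, 256, 0, st>>>(x, (const float *) w, y, n);
    }
}

The wrapper pair in the patch (conv2d_transpose_cuda_f16/_f32) pins the template argument in the same way, which keeps the launch configuration in a single place.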
diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cuh b/ggml/src/ggml-cuda/conv2d-transpose.cuh
index c9430b2485021..72889c5f0fa89 100644
--- a/ggml/src/ggml-cuda/conv2d-transpose.cuh
+++ b/ggml/src/ggml-cuda/conv2d-transpose.cuh
@@ -1,4 +1,5 @@
 #include "common.cuh"
 
 #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
+
 void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index bffd60f386ae2..70165a3f39842 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4306,28 +4306,33 @@ struct test_conv_transpose_1d : public test_case {
 
 // GGML_OP_CONV_TRANSPOSE_2D
 struct test_conv_transpose_2d : public test_case {
+    // Dimensions
     const std::array<int64_t, 4> ne_input;
     const std::array<int64_t, 4> ne_kernel;
     const int stride;
+    // Types
+    const ggml_type type_kernel;
 
     std::string vars() override {
-        return VARS_TO_STR3(ne_input, ne_kernel, stride);
+        return VARS_TO_STR4(type_kernel, ne_input, ne_kernel, stride);
    }
 
     double max_nmse_err() override {
         return 5e-4; // The default 1e-7 is too small for Vulkan.
     }
 
-    test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
-                           std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
-                           int stride = 1)
-        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
+    test_conv_transpose_2d(
+        std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+        std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
+        int stride = 1,
+        ggml_type type_kernel = GGML_TYPE_F16
+    ) : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), type_kernel(type_kernel) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
         ggml_set_name(input, "input");
 
-        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
         ggml_set_name(kernel, "kernel");
 
         ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
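For reference, build_graph above is the only place the kernel type enters the graph. An equivalent standalone construction with an f32 kernel would look roughly like the sketch below; it assumes an already-initialized ggml_context, and backend allocation/compute are omitted. The shapes match the stride-2 eval case further down.

#include "ggml.h"

// input is (W, H, C_in, N), kernel is (W, H, C_out, C_in), as in the CUDA source.
static struct ggml_tensor * build_conv_transpose_2d_f32(struct ggml_context * ctx) {
    const int64_t ne_input[4]  = {10, 10, 9, 1};
    const int64_t ne_kernel[4] = { 3,  3, 1, 9};

    struct ggml_tensor * input  = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input);
    struct ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel); // GGML_TYPE_F16 also accepted
    return ggml_conv_transpose_2d_p0(ctx, kernel, input, /*stride =*/ 2);
}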
@@ -6558,8 +6563,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
+    }
 
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
@@ -7484,9 +7491,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
 
-    test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
+    }
 
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
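Finally, for anyone checking the numerics: both instantiations of conv2d_transpose_kernel compute the same gather, where output pixel (out_x, out_y) collects every (input pixel, kernel tap) pair that a scatter-style transposed convolution would have written to it. A scalar reference for one output element, as a hedged sketch (single batch, flattened indexing as in the kernel; conv_transpose_2d_ref is an illustrative name, not repo code):

// Contribution rule: input (in_x, in_y) with tap (kw, kh) lands on the output at
// out_x = in_x*stride + kw, out_y = in_y*stride + kh; invert and test divisibility.
static float conv_transpose_2d_ref(const float * in, const float * w,
                                   const int out_x, const int out_y, const int c_idx,
                                   const int in_w, const int in_h, const int c_in, const int c_out,
                                   const int kernel_w, const int kernel_h, const int stride) {
    float acc = 0.0f;
    for (int c = 0; c < c_in; ++c) {
        for (int kh = 0; kh < kernel_h; ++kh) {
            const int y = out_y - kh;
            if (y < 0 || y % stride || y / stride >= in_h) {
                continue;
            }
            for (int kw = 0; kw < kernel_w; ++kw) {
                const int x = out_x - kw;
                if (x < 0 || x % stride || x / stride >= in_w) {
                    continue;
                }
                acc += in[(in_w * in_h) * c + in_w * (y / stride) + x / stride] *
                       w[(kernel_h * kernel_w * c_out) * c + (kernel_h * kernel_w) * c_idx + kernel_w * kh + kw];
            }
        }
    }
    return acc;
}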