-
Notifications
You must be signed in to change notification settings - Fork 13.7k
CUDA: support F32 kernel type for CONV_TRANSPOSE_2D
#17094
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,12 +1,36 @@ | ||||||||||||||||||||||||||||||||||||||
| #include <algorithm> | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #include "conv2d-transpose.cuh" | ||||||||||||||||||||||||||||||||||||||
| #include "ggml.h" | ||||||||||||||||||||||||||||||||||||||
| #include "convert.cuh" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #include <algorithm> | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel, | ||||||||||||||||||||||||||||||||||||||
| float * __restrict__ output, const int in_w, const int in_h, const int out_w, | ||||||||||||||||||||||||||||||||||||||
| const int out_h, const int kernel_w, const int kernel_h, const int stride, | ||||||||||||||||||||||||||||||||||||||
| const int c_in, const int c_out, const int batches) { | ||||||||||||||||||||||||||||||||||||||
| struct conv2d_transpose_params { | ||||||||||||||||||||||||||||||||||||||
| const int in_w; | ||||||||||||||||||||||||||||||||||||||
| const int in_h; | ||||||||||||||||||||||||||||||||||||||
| const int out_w; | ||||||||||||||||||||||||||||||||||||||
| const int out_h; | ||||||||||||||||||||||||||||||||||||||
| const int kernel_w; | ||||||||||||||||||||||||||||||||||||||
| const int kernel_h; | ||||||||||||||||||||||||||||||||||||||
| const int stride; | ||||||||||||||||||||||||||||||||||||||
| const int c_in; | ||||||||||||||||||||||||||||||||||||||
| const int c_out; | ||||||||||||||||||||||||||||||||||||||
| const int batches; | ||||||||||||||||||||||||||||||||||||||
| const int total; | ||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||
| static __global__ void conv2d_transpose_kernel(const float * __restrict__ input, | ||||||||||||||||||||||||||||||||||||||
| const T * __restrict__ kernel, | ||||||||||||||||||||||||||||||||||||||
| float * __restrict__ output, | ||||||||||||||||||||||||||||||||||||||
| const int in_w, | ||||||||||||||||||||||||||||||||||||||
| const int in_h, | ||||||||||||||||||||||||||||||||||||||
| const int out_w, | ||||||||||||||||||||||||||||||||||||||
| const int out_h, | ||||||||||||||||||||||||||||||||||||||
| const int kernel_w, | ||||||||||||||||||||||||||||||||||||||
| const int kernel_h, | ||||||||||||||||||||||||||||||||||||||
| const int stride, | ||||||||||||||||||||||||||||||||||||||
| const int c_in, | ||||||||||||||||||||||||||||||||||||||
| const int c_out, | ||||||||||||||||||||||||||||||||||||||
| const int batches) { | ||||||||||||||||||||||||||||||||||||||
| const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const int total_elements = out_w * out_h * c_out * batches; | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -26,41 +50,77 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const | |||||||||||||||||||||||||||||||||||||
| for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) { | ||||||||||||||||||||||||||||||||||||||
| for (int kh = 0; kh < kernel_h; ++kh) { | ||||||||||||||||||||||||||||||||||||||
| int in_y = out_y_idx - kh; | ||||||||||||||||||||||||||||||||||||||
| if (in_y < 0 || in_y % stride) continue; | ||||||||||||||||||||||||||||||||||||||
| if (in_y < 0 || in_y % stride) { | ||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| in_y /= stride; | ||||||||||||||||||||||||||||||||||||||
| if (in_y >= in_h) continue; | ||||||||||||||||||||||||||||||||||||||
| if (in_y >= in_h) { | ||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for (int kw = 0; kw < kernel_w; ++kw) { | ||||||||||||||||||||||||||||||||||||||
| int in_x = out_x_idx - kw; | ||||||||||||||||||||||||||||||||||||||
| if (in_x < 0 || in_x % stride) continue; | ||||||||||||||||||||||||||||||||||||||
| if (in_x < 0 || in_x % stride) { | ||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| in_x /= stride; | ||||||||||||||||||||||||||||||||||||||
| if (in_x >= in_w) continue; | ||||||||||||||||||||||||||||||||||||||
| if (in_x >= in_w) { | ||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x; | ||||||||||||||||||||||||||||||||||||||
| const int kernel_idx = | ||||||||||||||||||||||||||||||||||||||
| (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float input_val = input[input_idx]; | ||||||||||||||||||||||||||||||||||||||
| half kern_val = kernel[kernel_idx]; | ||||||||||||||||||||||||||||||||||||||
| T kern_val = kernel[kernel_idx]; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| accumulator += input_val * (float) kern_val; | ||||||||||||||||||||||||||||||||||||||
| accumulator += input_val * ggml_cuda_cast<float>(kern_val); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||
| static void conv2d_transpose_cuda(const float * input, | ||||||||||||||||||||||||||||||||||||||
| const T * kernel, | ||||||||||||||||||||||||||||||||||||||
| float * output, | ||||||||||||||||||||||||||||||||||||||
| const conv2d_transpose_params & params, | ||||||||||||||||||||||||||||||||||||||
| cudaStream_t st) { | ||||||||||||||||||||||||||||||||||||||
| const int blocks = (params.total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE; | ||||||||||||||||||||||||||||||||||||||
| conv2d_transpose_kernel<T><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>( | ||||||||||||||||||||||||||||||||||||||
| input, kernel, output, params.in_w, params.in_h, params.out_w, params.out_h, params.kernel_w, params.kernel_h, | ||||||||||||||||||||||||||||||||||||||
| params.stride, params.c_in, params.c_out, params.batches); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| static void conv2d_transpose_cuda_f16(const float * input, | ||||||||||||||||||||||||||||||||||||||
| const half * kernel, | ||||||||||||||||||||||||||||||||||||||
| float * output, | ||||||||||||||||||||||||||||||||||||||
| const conv2d_transpose_params & params, | ||||||||||||||||||||||||||||||||||||||
| cudaStream_t st) { | ||||||||||||||||||||||||||||||||||||||
| conv2d_transpose_cuda<half>(input, kernel, output, params, st); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| static void conv2d_transpose_cuda_f32(const float * input, | ||||||||||||||||||||||||||||||||||||||
| const float * kernel, | ||||||||||||||||||||||||||||||||||||||
| float * output, | ||||||||||||||||||||||||||||||||||||||
| const conv2d_transpose_params & params, | ||||||||||||||||||||||||||||||||||||||
| cudaStream_t st) { | ||||||||||||||||||||||||||||||||||||||
| conv2d_transpose_cuda<float>(input, kernel, output, params, st); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| //input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in) | ||||||||||||||||||||||||||||||||||||||
| void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||||||||||||||||||||||||||||||||||||
| const ggml_tensor * kernel = dst->src[0]; | ||||||||||||||||||||||||||||||||||||||
| const ggml_tensor * input = dst->src[1]; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); | ||||||||||||||||||||||||||||||||||||||
| GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); | ||||||||||||||||||||||||||||||||||||||
| GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const float * input_data = (const float *) input->data; | ||||||||||||||||||||||||||||||||||||||
| float * output_data = (float *) dst->data; | ||||||||||||||||||||||||||||||||||||||
| const half * kernel_data = (const half *) kernel->data; | ||||||||||||||||||||||||||||||||||||||
| const void * kernel_data = kernel->data; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const int input_w = input->ne[0]; | ||||||||||||||||||||||||||||||||||||||
| const int input_h = input->ne[1]; | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -83,9 +143,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor | |||||||||||||||||||||||||||||||||||||
| GGML_ASSERT(ggml_is_contiguous(dst)); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const int total = (output_w * output_h * channels_out * batches); | ||||||||||||||||||||||||||||||||||||||
| const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE; | ||||||||||||||||||||||||||||||||||||||
| conv2d_transpose_params params = { input_w, input_h, output_w, output_h, kernel_w, kernel_h, | ||||||||||||||||||||||||||||||||||||||
| stride, channels_in, channels_out, batches, total }; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>( | ||||||||||||||||||||||||||||||||||||||
| input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride, | ||||||||||||||||||||||||||||||||||||||
| channels_in, channels_out, batches); | ||||||||||||||||||||||||||||||||||||||
| if (kernel->type == GGML_TYPE_F16) { | ||||||||||||||||||||||||||||||||||||||
| conv2d_transpose_cuda_f16(input_data, (const half *) kernel_data, output_data, params, st); | ||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you don't need separate cuda_f16 and cuda_f32 functions here, you can straight away dispatch here to
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I referred to llama.cpp/ggml/src/ggml-cuda/conv2d.cu Lines 108 to 120 in 8e878f0
llama.cpp/ggml/src/ggml-cuda/conv2d.cu Lines 161 to 165 in 8e878f0
|
||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||
| conv2d_transpose_cuda_f32(input_data, (const float *) kernel_data, output_data, params, st); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| #include "common.cuh" | ||
|
|
||
| #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256 | ||
|
|
||
| void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably a better name for this should be
kernel_t