101 changes: 82 additions & 19 deletions ggml/src/ggml-cuda/conv2d-transpose.cu
@@ -1,12 +1,36 @@
#include <algorithm>

#include "conv2d-transpose.cuh"
#include "ggml.h"
#include "convert.cuh"

#include <algorithm>

__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
const int out_h, const int kernel_w, const int kernel_h, const int stride,
const int c_in, const int c_out, const int batches) {
struct conv2d_transpose_params {
const int in_w;
const int in_h;
const int out_w;
const int out_h;
const int kernel_w;
const int kernel_h;
const int stride;
const int c_in;
const int c_out;
const int batches;
const int total;
};

template <typename T>
Collaborator: probably a better name for this should be kernel_t
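For illustration only, a sketch of how the renamed template parameter might read (hypothetical; the diff itself keeps the name T):

// Hypothetical rename of the template parameter to kernel_t, as suggested above.
template <typename kernel_t>
static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
                                               const kernel_t * __restrict__ kernel,
                                               float * __restrict__ output,
                                               const int in_w, const int in_h, const int out_w, const int out_h,
                                               const int kernel_w, const int kernel_h, const int stride,
                                               const int c_in, const int c_out, const int batches);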

static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
const T * __restrict__ kernel,
float * __restrict__ output,
const int in_w,
const int in_h,
const int out_w,
const int out_h,
const int kernel_w,
const int kernel_h,
const int stride,
const int c_in,
const int c_out,
const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;

const int total_elements = out_w * out_h * c_out * batches;
@@ -26,41 +50,77 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
for (int kh = 0; kh < kernel_h; ++kh) {
int in_y = out_y_idx - kh;
if (in_y < 0 || in_y % stride) continue;
if (in_y < 0 || in_y % stride) {
continue;
}
in_y /= stride;
if (in_y >= in_h) continue;
if (in_y >= in_h) {
continue;
}

for (int kw = 0; kw < kernel_w; ++kw) {
int in_x = out_x_idx - kw;
if (in_x < 0 || in_x % stride) continue;
if (in_x < 0 || in_x % stride) {
continue;
}
in_x /= stride;
if (in_x >= in_w) continue;
if (in_x >= in_w) {
continue;
}

const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
const int kernel_idx =
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;

float input_val = input[input_idx];
half kern_val = kernel[kernel_idx];
T kern_val = kernel[kernel_idx];

accumulator += input_val * (float) kern_val;
accumulator += input_val * ggml_cuda_cast<float>(kern_val);
}
}
}

output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
}

template <typename T>
static void conv2d_transpose_cuda(const float * input,
const T * kernel,
float * output,
const conv2d_transpose_params & params,
cudaStream_t st) {
const int blocks = (params.total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
conv2d_transpose_kernel<T><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
input, kernel, output, params.in_w, params.in_h, params.out_w, params.out_h, params.kernel_w, params.kernel_h,
params.stride, params.c_in, params.c_out, params.batches);
}

static void conv2d_transpose_cuda_f16(const float * input,
const half * kernel,
float * output,
const conv2d_transpose_params & params,
cudaStream_t st) {
conv2d_transpose_cuda<half>(input, kernel, output, params, st);
}

static void conv2d_transpose_cuda_f32(const float * input,
const float * kernel,
float * output,
const conv2d_transpose_params & params,
cudaStream_t st) {
conv2d_transpose_cuda<float>(input, kernel, output, params, st);
}
//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];

GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

const float * input_data = (const float *) input->data;
float * output_data = (float *) dst->data;
const half * kernel_data = (const half *) kernel->data;
const void * kernel_data = kernel->data;

const int input_w = input->ne[0];
const int input_h = input->ne[1];
@@ -83,9 +143,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(ggml_is_contiguous(dst));

const int total = (output_w * output_h * channels_out * batches);
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
conv2d_transpose_params params = { input_w, input_h, output_w, output_h, kernel_w, kernel_h,
stride, channels_in, channels_out, batches, total };

conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
channels_in, channels_out, batches);
if (kernel->type == GGML_TYPE_F16) {
conv2d_transpose_cuda_f16(input_data, (const half *) kernel_data, output_data, params, st);
Collaborator: you don't need separate cuda_f16 and cuda_f32 functions here, you can dispatch directly to conv2d_transpose_cuda<type> and remove those two functions

Author: I referred to conv2d.cu for the current dispatching manner, and I thought there was some convention in llama.cpp 😂

template <typename T>
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
}
static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
}
static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
}

if (kernel->type == GGML_TYPE_F16) {
conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
} else {
conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
}
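
A minimal sketch of the direct dispatch suggested above (illustrative only; it assumes the conv2d_transpose_params struct and the kernel->type checks already present in this diff):

// Hypothetical simplification: call the template directly instead of going
// through the conv2d_transpose_cuda_f16 / conv2d_transpose_cuda_f32 wrappers.
if (kernel->type == GGML_TYPE_F16) {
    conv2d_transpose_cuda<half>(input_data, (const half *) kernel_data, output_data, params, st);
} else {
    conv2d_transpose_cuda<float>(input_data, (const float *) kernel_data, output_data, params, st);
}

Since the wrappers are thin forwarding functions, the behavior is the same either way; the question is only whether the named per-type entry points are worth keeping for consistency with conv2d.cu.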

} else {
conv2d_transpose_cuda_f32(input_data, (const float *) kernel_data, output_data, params, st);
}
}
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/conv2d-transpose.cuh
@@ -1,4 +1,5 @@
#include "common.cuh"

#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256

void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
33 changes: 22 additions & 11 deletions tests/test-backend-ops.cpp
@@ -4306,28 +4306,33 @@ struct test_conv_transpose_1d : public test_case {

// GGML_OP_CONV_TRANSPOSE_2D
struct test_conv_transpose_2d : public test_case {
// Dimensions
const std::array<int64_t, 4> ne_input;
const std::array<int64_t, 4> ne_kernel;
const int stride;
// Types
const ggml_type type_kernel;

std::string vars() override {
return VARS_TO_STR3(ne_input, ne_kernel, stride);
return VARS_TO_STR4(type_kernel, ne_input, ne_kernel, stride);
}

double max_nmse_err() override {
return 5e-4; // The default 1e-7 is too small for Vulkan.
}

test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
int stride = 1)
: ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
test_conv_transpose_2d(
std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
int stride = 1,
ggml_type type_kernel = GGML_TYPE_F16
) : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), type_kernel(type_kernel) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
ggml_set_name(input, "input");

ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
ggml_set_name(kernel, "kernel");

ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
@@ -6558,8 +6563,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
// for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (ggml_type kernel_type : {GGML_TYPE_F16,}) {
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
}

test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
@@ -7484,9 +7492,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));

test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
// for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (ggml_type kernel_type : {GGML_TYPE_F16,}) {
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1, kernel_type));
test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1, kernel_type));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
}

test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
