diff --git a/classification/ops_dcnv3/setup.py b/classification/ops_dcnv3/setup.py
index 72f112b1..7951c67f 100644
--- a/classification/ops_dcnv3/setup.py
+++ b/classification/ops_dcnv3/setup.py
@@ -24,7 +24,7 @@ def get_extensions():
     sources = main_file + source_cpu
     extension = CppExtension
-    extra_compile_args = {'cxx': []}
+    extra_compile_args = {'cxx': ['-O2']}
     define_macros = []
 
     if torch.cuda.is_available() and CUDA_HOME is not None:
@@ -37,8 +37,9 @@ def get_extensions():
             # "-D__CUDA_NO_HALF_CONVERSIONS__",
             # "-D__CUDA_NO_HALF2_OPERATORS__",
         ]
+        print("CUDA is available, building with CUDA support")
     else:
-        raise NotImplementedError('Cuda is not availabel')
+        print("CUDA is not available, building CPU-only version")
 
     sources = [os.path.join(extensions_dir, s) for s in sources]
     include_dirs = [extensions_dir]
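With this setup.py change, a machine without CUDA now produces an optimized CPU-only build instead of failing outright. A quick smoke test of the resulting extension might look like the sketch below; it assumes the module is importable as `DCNv3` and exposes `dcnv3_forward` with the argument order declared in `src/dcnv3.h` (the pybind bindings are not part of this diff, so both names are assumptions).

```python
# Hedged sketch: smoke-test a CPU-only build of the extension.
# Assumes `import DCNv3` works and dcnv3_forward follows the signature in src/dcnv3.h.
import torch
import DCNv3

x = torch.randn(2, 8, 8, 16).contiguous()   # NHWC: (batch, H, W, group * group_channels)
offset = torch.zeros(2, 8, 8, 18)           # group=1, 3x3 kernel -> 9 points, 2 offsets each
mask = torch.full((2, 8, 8, 9), 1.0 / 9)    # one weight per sampling point

out = DCNv3.dcnv3_forward(
    x, offset, mask,
    3, 3,      # kernel_h, kernel_w
    1, 1,      # stride_h, stride_w
    1, 1,      # pad_h, pad_w
    1, 1,      # dilation_h, dilation_w
    1, 16,     # group, group_channels
    1.0,       # offset_scale
    2,         # im2col_step
    0)         # remove_center
print(out.shape)  # expected: torch.Size([2, 8, 8, 16])
```

With zero offsets and uniform mask weights, each interior output pixel is simply the 3x3 mean of the corresponding input window, which is an easy value to eyeball.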
diff --git a/classification/ops_dcnv3/src/cpu/dcnv3_cpu.cpp b/classification/ops_dcnv3/src/cpu/dcnv3_cpu.cpp
index a3bddc18..a5b8ec29 100644
--- a/classification/ops_dcnv3/src/cpu/dcnv3_cpu.cpp
+++ b/classification/ops_dcnv3/src/cpu/dcnv3_cpu.cpp
@@ -1,18 +1,24 @@
 /*!
 **************************************************************************************************
-* InternImage
-* Copyright (c) 2022 OpenGVLab
-* Licensed under The MIT License [see LICENSE for details]
-**************************************************************************************************
-* Modified from
-*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+* MIT License
+* Copyright (c) 2024 hxaxd
+* See the LICENSE file for details
 **************************************************************************************************
 */
 
 #include <vector>
 
 #include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+
+#include "cpu/dcnv3_im2col_cpu.h"
+#include "dcnv3_cpu.h"
+#include "dcnv3_im2col_cpu.h"
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
 
 at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
                              const at::Tensor &mask, const int kernel_h,
@@ -21,8 +27,63 @@ at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
                              const int pad_w, const int dilation_h,
                              const int dilation_w, const int group,
                              const int group_channels, const float offset_scale,
-                             const int im2col_step) {
-    AT_ERROR("Not implement on cpu");
+                             const int im2col_step, const int remove_center) {
+    AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
+    AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
+    AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
+    AT_ASSERTM(input.is_cpu(), "input must be a CPU tensor");
+    AT_ASSERTM(offset.is_cpu(), "offset must be a CPU tensor");
+    AT_ASSERTM(mask.is_cpu(), "mask must be a CPU tensor");
+
+    const int batch = input.size(0);
+    const int height_in = input.size(1);
+    const int width_in = input.size(2);
+    const int channels = input.size(3);
+    const int height_out =
+        (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
+        1;
+    const int width_out =
+        (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
+        1;
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0,
+               "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+    AT_ASSERTM(
+        channels == (group * group_channels),
+        "Input channels and group times group channels wont match: (%d vs %d).",
+        channels, group * group_channels);
+
+    auto output =
+        at::zeros({batch, height_out, width_out, group * group_channels},
+                  input.options());
+
+    const int batch_n = im2col_step_;
+    auto output_n = output.view({batch / batch_n, batch_n, height_out,
+                                 width_out, group * group_channels});
+
+    const int num_batches = batch / im2col_step_;
+
+    at::parallel_for(0, num_batches, 0, [&](int64_t start_n, int64_t end_n) {
+        for (int64_t n = start_n; n < end_n; ++n) {
+            auto input_slice = input.narrow(0, n * batch_n, batch_n);
+            auto offset_slice = offset.narrow(0, n * batch_n, batch_n);
+            auto mask_slice = mask.narrow(0, n * batch_n, batch_n);
+            auto columns = output_n.select(0, n);
+
+            AT_DISPATCH_FLOATING_TYPES(
+                input.scalar_type(), "dcnv3_cpu_forward", [&] {
+                    dcnv3::cpu::dcnv3_im2col_cpu(
+                        input_slice, offset_slice, mask_slice, columns,
+                        kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
+                        dilation_h, dilation_w, group, group_channels,
+                        batch_n, height_in, width_in, height_out, width_out,
+                        offset_scale, remove_center);
+                });
+        }
+    });
+
+    return output;
 }
 
 std::vector<at::Tensor>
@@ -32,6 +93,72 @@ dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
                    const int pad_h, const int pad_w, const int dilation_h,
                    const int dilation_w, const int group,
                    const int group_channels, const float offset_scale,
-                   const at::Tensor &grad_output, const int im2col_step) {
-    AT_ERROR("Not implement on cpu");
+                   const at::Tensor &grad_output, const int im2col_step, const int remove_center) {
+
+    AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
+    AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
+    AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
+    AT_ASSERTM(grad_output.is_contiguous(),
+               "grad_output tensor has to be contiguous");
+    AT_ASSERTM(input.is_cpu(), "input must be a CPU tensor");
+    AT_ASSERTM(offset.is_cpu(), "offset must be a CPU tensor");
+    AT_ASSERTM(mask.is_cpu(), "mask must be a CPU tensor");
+    AT_ASSERTM(grad_output.is_cpu(), "grad_output must be a CPU tensor");
+
+    const int batch = input.size(0);
+    const int height_in = input.size(1);
+    const int width_in = input.size(2);
+    const int channels = input.size(3);
+    const int height_out =
+        (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
+        1;
+    const int width_out =
+        (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
+        1;
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0,
+               "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+    AT_ASSERTM(
+        channels == (group * group_channels),
+        "Input channels and group times group channels wont match: (%d vs %d).",
+        channels, group * group_channels);
+
+    auto grad_input = at::zeros_like(input);
+    auto grad_offset = at::zeros_like(offset);
+    auto grad_mask = at::zeros_like(mask);
+
+    const int batch_n = im2col_step_;
+
+    auto grad_output_n =
+        grad_output.view({batch / im2col_step_, batch_n, height_out * width_out,
+                          group, group_channels});
+
+    const int num_batches = batch / im2col_step_;
+
+    at::parallel_for(0, num_batches, 0, [&](int64_t start_n, int64_t end_n) {
+        for (int64_t n = start_n; n < end_n; ++n) {
+            auto input_slice = input.narrow(0, n * batch_n, batch_n);
+            auto offset_slice = offset.narrow(0, n * batch_n, batch_n);
+            auto mask_slice = mask.narrow(0, n * batch_n, batch_n);
+            auto grad_output_g = grad_output_n.select(0, n);
+
+            auto grad_input_slice = grad_input.narrow(0, n * batch_n, batch_n);
+            auto grad_offset_slice = grad_offset.narrow(0, n * batch_n, batch_n);
+            auto grad_mask_slice = grad_mask.narrow(0, n * batch_n, batch_n);
+
+            AT_DISPATCH_FLOATING_TYPES(
+                input.scalar_type(), "dcnv3_cpu_backward", [&] {
+                    dcnv3::cpu::dcnv3_col2im_cpu(
+                        grad_output_g, input_slice, offset_slice, mask_slice,
+                        grad_input_slice, grad_offset_slice, grad_mask_slice,
+                        kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
+                        dilation_h, dilation_w, group, group_channels, batch_n,
+                        height_in, width_in, height_out, width_out,
+                        offset_scale, remove_center);
+                });
+        }
+    });
+
+    return {grad_input, grad_offset, grad_mask};
 }
diff --git a/classification/ops_dcnv3/src/cpu/dcnv3_cpu.h b/classification/ops_dcnv3/src/cpu/dcnv3_cpu.h
index d457bcbd..bbf219e5 100644
--- a/classification/ops_dcnv3/src/cpu/dcnv3_cpu.h
+++ b/classification/ops_dcnv3/src/cpu/dcnv3_cpu.h
@@ -1,11 +1,8 @@
 /*!
 **************************************************************************************************
-* InternImage
-* Copyright (c) 2022 OpenGVLab
-* Licensed under The MIT License [see LICENSE for details]
-**************************************************************************************************
-* Modified from
-*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+* MIT License
+* Copyright (c) 2024 hxaxd
+* See the LICENSE file for details
 **************************************************************************************************
 */
 
@@ -19,7 +16,7 @@ at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
                              const int pad_w, const int dilation_h,
                              const int dilation_w, const int group,
                              const int group_channels, const float offset_scale,
-                             const int im2col_step);
+                             const int im2col_step, const int remove_center);
 
 std::vector<at::Tensor>
 dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
@@ -28,4 +25,4 @@ dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
                    const int pad_h, const int pad_w, const int dilation_h,
                    const int dilation_w, const int group,
                    const int group_channels, const float offset_scale,
-                   const at::Tensor &grad_output, const int im2col_step);
+                   const at::Tensor &grad_output, const int im2col_step, const int remove_center);
diff --git a/classification/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp b/classification/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp
new file mode 100644
index 00000000..330756b4
--- /dev/null
+++ b/classification/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp
@@ -0,0 +1,297 @@
+/*!
+**************************************************************************************************
+* MIT License
+* Copyright (c) 2024 hxaxd
+* Optimized version: high-precision opmath + inline bilinear sampling (no hard-coded unrolling, to avoid bugs)
+**************************************************************************************************
+*/
+
+#include "dcnv3_im2col_cpu.h"
+
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/OpMathType.h>
+#include <algorithm>
+#include <cmath>
+
+namespace dcnv3 {
+namespace cpu {
+
+template <typename scalar_t, typename opmath_t>
+inline opmath_t dcnv3_im2col_bilinear(
+    const scalar_t* bottom_data,
+    const int& height, const int& width,
+    const int& group, const int& group_channels,
+    const opmath_t& h, const opmath_t& w,
+    const int& g, const int& c
+) {
+    const int h_low = static_cast<int>(std::floor(h));
+    const int w_low = static_cast<int>(std::floor(w));
+    const int h_high = h_low + 1;
+    const int w_high = w_low + 1;
+
+    const opmath_t lh = h - h_low;
+    const opmath_t lw = w - w_low;
+    const opmath_t hh = opmath_t(1) - lh;
+    const opmath_t hw = opmath_t(1) - lw;
+
+    const int w_stride = group * group_channels;
+    const int h_stride = width * w_stride;
+    const int h_low_ptr_offset = h_low * h_stride;
+    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+    const int w_low_ptr_offset = w_low * w_stride;
+    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+    const int base_ptr = g * group_channels + c;
+
+    scalar_t v1 = scalar_t(0);
+    if (h_low >= 0 && w_low >= 0) {
+        v1 = bottom_data[h_low_ptr_offset + w_low_ptr_offset + base_ptr];
+    }
+    scalar_t v2 = scalar_t(0);
+    if (h_low >= 0 && w_high <= width - 1) {
+        v2 = bottom_data[h_low_ptr_offset + w_high_ptr_offset + base_ptr];
+    }
+    scalar_t v3 = scalar_t(0);
+    if (h_high <= height - 1 && w_low >= 0) {
+        v3 = bottom_data[h_high_ptr_offset + w_low_ptr_offset + base_ptr];
+    }
+    scalar_t v4 = scalar_t(0);
+    if (h_high <= height - 1 && w_high <= width - 1) {
+        v4 = bottom_data[h_high_ptr_offset + w_high_ptr_offset + base_ptr];
+    }
+
+    const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+    return w1 * static_cast<opmath_t>(v1) + w2 * static_cast<opmath_t>(v2) +
+           w3 * static_cast<opmath_t>(v3) + w4 * static_cast<opmath_t>(v4);
+}
+
+void dcnv3_im2col_cpu(
+    const at::Tensor& data_im, const at::Tensor& data_offset,
+    const at::Tensor& data_mask, at::Tensor& data_col,
+    const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w,
+    const int group, const int group_channels,
+    const int batch_n, const int height_in, const int width_in,
+    const int height_out, const int width_out,
+    const double offset_scale, const int remove_center
+) {
+    const int num_kernels = data_col.numel();
+
+    AT_DISPATCH_FLOATING_TYPES(data_im.scalar_type(), "dcnv3_im2col_cpu", [&] {
+        using opmath_t = at::opmath_type<scalar_t>;
+
+        const scalar_t* data_im_ptr = data_im.data_ptr<scalar_t>();
+        const scalar_t* data_offset_ptr = data_offset.data_ptr<scalar_t>();
+        const scalar_t* data_mask_ptr = data_mask.data_ptr<scalar_t>();
+        scalar_t* data_col_ptr = data_col.data_ptr<scalar_t>();
+
+        const int input_size = height_in * width_in;
+        const int kernel_size = kernel_h * kernel_w - remove_center;
+        const int qid_stride = group * group_channels;
+        const int center_h = kernel_h / 2;
+        const int center_w = kernel_w / 2;
+
+        for (int index = 0; index < num_kernels; ++index) {
+            int _temp = index;
+            const int c_col = _temp % group_channels;
+            _temp /= group_channels;
+            const int sampling_index = _temp;
+            const int g_col = _temp % group;
+            _temp /= group;
+            const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
+                             (_temp % width_out) * stride_w;
+            _temp /= width_out;
+            const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
+                             (_temp % height_out) * stride_h;
+            _temp /= height_out;
+            const int b_col = _temp;
+
+            scalar_t* data_col_ptr_current = data_col_ptr + index;
+            int data_weight_ptr = sampling_index * kernel_size;
+            int data_loc_w_ptr = data_weight_ptr << 1;
+            const scalar_t* data_im_ptr_current =
+                data_im_ptr + b_col * input_size * qid_stride;
+
+            const opmath_t p0_w_ = static_cast<opmath_t>(p0_w) -
+                ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
+            const opmath_t p0_h_ = static_cast<opmath_t>(p0_h) -
+                ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
+
+            opmath_t col = opmath_t(0);
+
+            for (int i = 0; i < kernel_w; ++i) {
+                for (int j = 0; j < kernel_h; ++j) {
+                    if (i == center_w && j == center_h && remove_center) {
+                        continue;
+                    }
+                    const opmath_t offset_w =
+                        static_cast<opmath_t>(data_offset_ptr[data_loc_w_ptr]);
+                    const opmath_t offset_h =
+                        static_cast<opmath_t>(data_offset_ptr[data_loc_w_ptr + 1]);
+                    const opmath_t loc_w =
+                        p0_w_ + (static_cast<opmath_t>(i) * dilation_w + offset_w) * offset_scale;
+                    const opmath_t loc_h =
+                        p0_h_ + (static_cast<opmath_t>(j) * dilation_h + offset_h) * offset_scale;
+                    const opmath_t weight =
+                        static_cast<opmath_t>(data_mask_ptr[data_weight_ptr]);
+
+                    if (loc_h > -1 && loc_w > -1 && loc_h < height_in && loc_w < width_in) {
+                        col += dcnv3_im2col_bilinear(
+                                   data_im_ptr_current, height_in, width_in, group,
+                                   group_channels, loc_h, loc_w, g_col, c_col) *
+                               weight;
+                    }
+                    data_weight_ptr += 1;
+                    data_loc_w_ptr += 2;
+                }
+            }
+
+            *data_col_ptr_current = static_cast<scalar_t>(col);
+        }
+    });
+}
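For reference, the per-point bilinear gather that `dcnv3_im2col_cpu` performs can be written in a few lines of NumPy. This is an illustrative sketch, not part of the patch: it mirrors `dcnv3_im2col_bilinear` for a single channels-last image, with out-of-bounds corners contributing zero, and is handy for spot-checking individual output values.

```python
# Hedged sketch: NumPy mirror of dcnv3_im2col_bilinear for one (H, W, group*group_channels) image.
import numpy as np

def bilinear_sample(img, h, w, g, c, group_channels):
    H, W, _ = img.shape
    h0, w0 = int(np.floor(h)), int(np.floor(w))
    h1, w1 = h0 + 1, w0 + 1
    lh, lw = h - h0, w - w0            # fractional parts
    hh, hw = 1.0 - lh, 1.0 - lw        # complementary weights
    ch = g * group_channels + c        # channel within the flattened (group, group_channels) axis

    def at(y, x):                      # zero outside the image, like the C++ bounds checks
        return img[y, x, ch] if 0 <= y < H and 0 <= x < W else 0.0

    return (hh * hw * at(h0, w0) + hh * lw * at(h0, w1) +
            lh * hw * at(h1, w0) + lh * lw * at(h1, w1))

img = np.arange(2 * 3 * 4, dtype=np.float64).reshape(2, 3, 4)
print(bilinear_sample(img, 0.5, 1.25, g=0, c=2, group_channels=4))  # 13.0
```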
+
+void dcnv3_col2im_cpu(
+    const at::Tensor& grad_col, const at::Tensor& data_im,
+    const at::Tensor& data_offset, const at::Tensor& data_mask,
+    at::Tensor& grad_im, at::Tensor& grad_offset, at::Tensor& grad_mask,
+    const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w,
+    const int group, const int group_channels,
+    const int batch_n, const int height_in, const int width_in,
+    const int height_out, const int width_out,
+    const double offset_scale, const int remove_center
+) {
+    const int num_kernels = grad_col.numel();
+
+    AT_DISPATCH_FLOATING_TYPES(grad_col.scalar_type(), "dcnv3_col2im_cpu", [&] {
+        using opmath_t = at::opmath_type<scalar_t>;
+
+        const scalar_t* grad_col_ptr = grad_col.data_ptr<scalar_t>();
+        const scalar_t* data_im_ptr = data_im.data_ptr<scalar_t>();
+        const scalar_t* data_offset_ptr = data_offset.data_ptr<scalar_t>();
+        const scalar_t* data_mask_ptr = data_mask.data_ptr<scalar_t>();
+
+        scalar_t* grad_im_ptr = grad_im.data_ptr<scalar_t>();
+        scalar_t* grad_offset_ptr = grad_offset.data_ptr<scalar_t>();
+        scalar_t* grad_mask_ptr = grad_mask.data_ptr<scalar_t>();
+
+        const int input_size = height_in * width_in;
+        const int kernel_size = kernel_h * kernel_w - remove_center;
+        const int qid_stride = group * group_channels;
+        const int center_h = kernel_h / 2;
+        const int center_w = kernel_w / 2;
+
+        for (int index = 0; index < num_kernels; ++index) {
+            int _temp = index;
+            const int c_col = _temp % group_channels;
+            _temp /= group_channels;
+            const int sampling_index = _temp;
+            const int g_col = _temp % group;
+            _temp /= group;
+            const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
+                             (_temp % width_out) * stride_w;
+            _temp /= width_out;
+            const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
+                             (_temp % height_out) * stride_h;
+            _temp /= height_out;
+            const int b_col = _temp;
+
+            const opmath_t top_grad = static_cast<opmath_t>(grad_col_ptr[index]);
+            int data_weight_ptr = sampling_index * kernel_size;
+            int data_loc_w_ptr = data_weight_ptr << 1;
+            const int im_ptr_offset = b_col * input_size * qid_stride;
+            const scalar_t* data_im_ptr_current = data_im_ptr + im_ptr_offset;
+            scalar_t* grad_im_ptr_current = grad_im_ptr + im_ptr_offset;
+
+            const opmath_t p0_w_ = static_cast<opmath_t>(p0_w) -
+                ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
+            const opmath_t p0_h_ = static_cast<opmath_t>(p0_h) -
+                ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
+
+            for (int i = 0; i < kernel_w; ++i) {
+                for (int j = 0; j < kernel_h; ++j) {
+                    if (i == center_w && j == center_h && remove_center) {
+                        continue;
+                    }
+
+                    const opmath_t offset_w =
+                        static_cast<opmath_t>(data_offset_ptr[data_loc_w_ptr]);
+                    const opmath_t offset_h =
+                        static_cast<opmath_t>(data_offset_ptr[data_loc_w_ptr + 1]);
+                    const opmath_t loc_w =
+                        p0_w_ + (static_cast<opmath_t>(i) * dilation_w + offset_w) * offset_scale;
+                    const opmath_t loc_h =
+                        p0_h_ + (static_cast<opmath_t>(j) * dilation_h + offset_h) * offset_scale;
+                    const opmath_t weight =
+                        static_cast<opmath_t>(data_mask_ptr[data_weight_ptr]);
+
+                    if (loc_h > -1 && loc_w > -1 && loc_h < height_in && loc_w < width_in) {
+                        const int h_low = static_cast<int>(std::floor(loc_h));
+                        const int w_low = static_cast<int>(std::floor(loc_w));
+                        const int h_high = h_low + 1;
+                        const int w_high = w_low + 1;
+
+                        const opmath_t lh = loc_h - h_low;
+                        const opmath_t lw = loc_w - w_low;
+                        const opmath_t hh = opmath_t(1) - lh;
+                        const opmath_t hw_ = opmath_t(1) - lw;
+
+                        const int w_stride = group * group_channels;
+                        const int h_stride = width_in * w_stride;
+                        const int h_low_ptr_offset = h_low * h_stride;
+                        const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+                        const int w_low_ptr_offset = w_low * w_stride;
+                        const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+                        const int base_ptr = g_col * group_channels + c_col;
+
+                        const opmath_t w1 = hh * hw_;
+                        const opmath_t w2 = hh * lw;
+                        const opmath_t w3 = lh * hw_;
+                        const opmath_t w4 = lh * lw;
+                        const opmath_t top_grad_im = top_grad * weight;
+
+                        opmath_t grad_h_weight = opmath_t(0);
+                        opmath_t grad_w_weight = opmath_t(0);
+
+                        scalar_t v1 = scalar_t(0);
+                        if (h_low >= 0 && w_low >= 0) {
+                            const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+                            v1 = data_im_ptr_current[ptr1];
+                            grad_h_weight -= hw_ * static_cast<opmath_t>(v1);
+                            grad_w_weight -= hh * static_cast<opmath_t>(v1);
+                            grad_im_ptr_current[ptr1] += static_cast<scalar_t>(w1 * top_grad_im);
+                        }
+                        scalar_t v2 = scalar_t(0);
+                        if (h_low >= 0 && w_high <= width_in - 1) {
+                            const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+                            v2 = data_im_ptr_current[ptr2];
+                            grad_h_weight -= lw * static_cast<opmath_t>(v2);
+                            grad_w_weight += hh * static_cast<opmath_t>(v2);
+                            grad_im_ptr_current[ptr2] += static_cast<scalar_t>(w2 * top_grad_im);
+                        }
+                        scalar_t v3 = scalar_t(0);
+                        if (h_high <= height_in - 1 && w_low >= 0) {
+                            const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+                            v3 = data_im_ptr_current[ptr3];
+                            grad_h_weight += hw_ * static_cast<opmath_t>(v3);
+                            grad_w_weight -= lh * static_cast<opmath_t>(v3);
+                            grad_im_ptr_current[ptr3] += static_cast<scalar_t>(w3 * top_grad_im);
+                        }
+                        scalar_t v4 = scalar_t(0);
+                        if (h_high <= height_in - 1 && w_high <= width_in - 1) {
+                            const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+                            v4 = data_im_ptr_current[ptr4];
+                            grad_h_weight += lw * static_cast<opmath_t>(v4);
+                            grad_w_weight += lh * static_cast<opmath_t>(v4);
+                            grad_im_ptr_current[ptr4] += static_cast<scalar_t>(w4 * top_grad_im);
+                        }
+
+                        const opmath_t val =
+                            w1 * static_cast<opmath_t>(v1) + w2 * static_cast<opmath_t>(v2) +
+                            w3 * static_cast<opmath_t>(v3) + w4 * static_cast<opmath_t>(v4);
+
+                        grad_mask_ptr[data_weight_ptr] += static_cast<scalar_t>(top_grad * val);
+                        grad_offset_ptr[data_loc_w_ptr] +=
+                            static_cast<scalar_t>(offset_scale * grad_w_weight * top_grad_im);
+                        grad_offset_ptr[data_loc_w_ptr + 1] +=
+                            static_cast<scalar_t>(offset_scale * grad_h_weight * top_grad_im);
+                    }
+                    data_weight_ptr += 1;
+                    data_loc_w_ptr += 2;
+                }
+            }
+        }
+    });
+}
+
+} // namespace cpu
+} // namespace dcnv3
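Both kernels above decode a flat index over the column/output tensor with the same sequence of modulo and divide steps; the short Python replay below makes the assumed (batch_n, height_out, width_out, group, group_channels) layout explicit (the names here are illustrative only, not part of the patch).

```python
# Hedged sketch: the flat-index decomposition used by dcnv3_im2col_cpu / dcnv3_col2im_cpu.
def decompose(index, group_channels, group, width_out, height_out):
    c_col = index % group_channels
    sampling_index = index // group_channels   # data_weight_ptr = sampling_index * kernel_size in the C++
    rest = sampling_index
    g_col = rest % group;      rest //= group
    w_out = rest % width_out;  rest //= width_out
    h_out = rest % height_out; rest //= height_out
    b_col = rest
    return b_col, h_out, w_out, g_col, c_col

# e.g. group_channels=4, group=2, width_out=3, height_out=3 (72 elements per image):
print(decompose(100, 4, 2, 3, 3))  # -> (1, 1, 0, 1, 0)
```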
diff --git a/classification/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h b/classification/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h
new file mode 100644
index 00000000..479df187
--- /dev/null
+++ b/classification/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h
@@ -0,0 +1,68 @@
+/*
+**************************************************************************************************
+* MIT License
+* Copyright (c) 2024 hxaxd
+* See the LICENSE file for details
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <vector>
+
+namespace dcnv3 {
+namespace cpu {
+
+template <typename scalar_t>
+using opmath_t = double;
+
+template <typename scalar_t>
+scalar_t dcnv3_im2col_bilinear(const scalar_t* bottom_data,
+                               const int& height, const int& width,
+                               const int& group,
+                               const int& group_channels,
+                               const double& h, const double& w,
+                               const int& g, const int& c);
+
+template <typename scalar_t>
+void dcnv3_col2im_bilinear(
+    const scalar_t* bottom_data, const int& height, const int& width,
+    const int& nheads, const int& group_channels, const double& h,
+    const double& w, const int& m, const int& c, const double offset_scale,
+    const double& top_grad, const double& mask, double* grad_im,
+    double* grad_offset, double* grad_mask);
+
+template <typename scalar_t>
+void dcnv3_col2im_bilinear_gm(
+    const scalar_t* bottom_data, const int& height, const int& width,
+    const int& nheads, const int& group_channels, const double& h,
+    const double& w, const int& m, const int& c, const double offset_scale,
+    const double& top_grad, const double& mask, double* grad_im,
+    double* grad_offset, double* grad_mask);
+
+void dcnv3_im2col_cpu(
+    const at::Tensor& data_im, const at::Tensor& data_offset,
+    const at::Tensor& data_mask, at::Tensor& data_col,
+    const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w,
+    const int group, const int group_channels,
+    const int batch_n, const int height_in, const int width_in,
+    const int height_out, const int width_out,
+    const double offset_scale, const int remove_center);
+
+void dcnv3_col2im_cpu(
+    const at::Tensor& grad_col, const at::Tensor& data_im,
+    const at::Tensor& data_offset, const at::Tensor& data_mask,
+    at::Tensor& grad_im, at::Tensor& grad_offset, at::Tensor& grad_mask,
+    const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w,
+    const int group, const int group_channels,
+    const int batch_n, const int height_in, const int width_in,
+    const int height_out, const int width_out,
+    const double offset_scale, const int remove_center);
+
+} // namespace cpu
+} // namespace dcnv3
\ No newline at end of file
diff --git a/classification/ops_dcnv3/src/dcnv3.h b/classification/ops_dcnv3/src/dcnv3.h
index ce4500fa..da9a9c0b 100644
--- a/classification/ops_dcnv3/src/dcnv3.h
+++ b/classification/ops_dcnv3/src/dcnv3.h
@@ -34,7 +34,10 @@ at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset,
         AT_ERROR("Not compiled with GPU support");
 #endif
     }
-    AT_ERROR("Not implemented on the CPU");
+    return dcnv3_cpu_forward(input, offset, mask, kernel_h, kernel_w,
+                             stride_h, stride_w, pad_h, pad_w, dilation_h,
+                             dilation_w, group, group_channels,
+                             offset_scale, im2col_step, remove_center);
 }
 
 std::vector<at::Tensor>
@@ -55,5 +58,8 @@ dcnv3_backward(const at::Tensor &input, const at::Tensor &offset, AT_ERROR("Not compiled with GPU support"); #endif } - AT_ERROR("Not implemented on the CPU"); + return dcnv3_cpu_backward(input, offset, mask, kernel_h, kernel_w, + stride_h, stride_w, pad_h, pad_w, dilation_h, + dilation_w, group, group_channels, + offset_scale, grad_output, im2col_step, remove_center); } diff --git a/detection/ops_dcnv3/setup.py b/detection/ops_dcnv3/setup.py index e06f2ea9..7951c67f 100644 --- a/detection/ops_dcnv3/setup.py +++ b/detection/ops_dcnv3/setup.py @@ -24,7 +24,7 @@ def get_extensions(): sources = main_file + source_cpu extension = CppExtension - extra_compile_args = {'cxx': []} + extra_compile_args = {'cxx': ['-O2']} define_macros = [] if torch.cuda.is_available() and CUDA_HOME is not None: @@ -37,8 +37,9 @@ def get_extensions(): # "-D__CUDA_NO_HALF_CONVERSIONS__", # "-D__CUDA_NO_HALF2_OPERATORS__", ] + print("CUDA is available, building with CUDA support") else: - raise NotImplementedError('Cuda is not availabel') + print("CUDA is not available, building CPU-only version") sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] @@ -56,7 +57,7 @@ def get_extensions(): setup( name='DCNv3', - version='1.0', + version='1.1', author='InternImage', url='https://github.com/OpenGVLab/InternImage', description= diff --git a/detection/ops_dcnv3/src/cpu/dcnv3_cpu.cpp b/detection/ops_dcnv3/src/cpu/dcnv3_cpu.cpp index a3bddc18..a5b8ec29 100644 --- a/detection/ops_dcnv3/src/cpu/dcnv3_cpu.cpp +++ b/detection/ops_dcnv3/src/cpu/dcnv3_cpu.cpp @@ -1,18 +1,24 @@ /*! ************************************************************************************************** -* InternImage -* Copyright (c) 2022 OpenGVLab -* Licensed under The MIT License [see LICENSE for details] -************************************************************************************************** -* Modified from -*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +* MIT 开源协议 +* 版权所有 (c) 2024 hxaxd +* 详见 LICENSE 文件 ************************************************************************************************** */ #include #include -#include +#include +#include + +#include "cpu/dcnv3_im2col_cpu.h" +#include "dcnv3_cpu.h" +#include "dcnv3_im2col_cpu.h" +#include +#include +#include +#include at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, const at::Tensor &mask, const int kernel_h, @@ -21,8 +27,63 @@ at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int group_channels, const float offset_scale, - const int im2col_step) { - AT_ERROR("Not implement on cpu"); + const int im2col_step, const int remove_center) { + AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); + AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); + AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); + AT_ASSERTM(input.is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(offset.is_cpu(), "offset must be a CPU tensor"); + AT_ASSERTM(mask.is_cpu(), "mask must be a CPU tensor"); + + const int batch = input.size(0); + const int height_in = input.size(1); + const int width_in = input.size(2); + const int channels = input.size(3); + const int height_out = + (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + + 1; + const int width_out = + (width_in + 2 * pad_w - (dilation_w 
* (kernel_w - 1) + 1)) / stride_w + + 1; + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, + "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + AT_ASSERTM( + channels == (group * group_channels), + "Input channels and group times group channels wont match: (%d vs %d).", + channels, group * group_channels); + + auto output = + at::zeros({batch, height_out, width_out, group * group_channels}, + input.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch / batch_n, batch_n, height_out, + width_out, group * group_channels}); + + const int num_batches = batch / im2col_step_; + + at::parallel_for(0, num_batches, 0, [&](int64_t start_n, int64_t end_n) { + for (int64_t n = start_n; n < end_n; ++n) { + auto input_slice = input.narrow(0, n * batch_n, batch_n); + auto offset_slice = offset.narrow(0, n * batch_n, batch_n); + auto mask_slice = mask.narrow(0, n * batch_n, batch_n); + auto columns = output_n.select(0, n); + + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dcnv3_cpu_forward", [&] { + dcnv3::cpu::dcnv3_im2col_cpu( + input_slice, offset_slice, mask_slice, columns, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, + batch_n, height_in, width_in, height_out, width_out, + offset_scale, remove_center); + }); + } + }); + + return output; } std::vector @@ -32,6 +93,72 @@ dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int group_channels, const float offset_scale, - const at::Tensor &grad_output, const int im2col_step) { - AT_ERROR("Not implement on cpu"); + const at::Tensor &grad_output, const int im2col_step, const int remove_center) { + + AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); + AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); + AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), + "grad_output tensor has to be contiguous"); + AT_ASSERTM(input.is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(offset.is_cpu(), "offset must be a CPU tensor"); + AT_ASSERTM(mask.is_cpu(), "mask must be a CPU tensor"); + AT_ASSERTM(grad_output.is_cpu(), + "grad_output must be a CPU tensor"); + + const int batch = input.size(0); + const int height_in = input.size(1); + const int width_in = input.size(2); + const int channels = input.size(3); + const int height_out = + (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + + 1; + const int width_out = + (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + + 1; + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, + "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + AT_ASSERTM( + channels == (group * group_channels), + "Input channels and group times group channels wont match: (%d vs %d).", + channels, group * group_channels); + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + + const int batch_n = im2col_step_; + + auto grad_output_n = + grad_output.view({batch / im2col_step_, batch_n, height_out * width_out, + group, group_channels}); + + const int num_batches = batch / im2col_step_; + + at::parallel_for(0, num_batches, 0, [&](int64_t start_n, int64_t end_n) { + for (int64_t n = start_n; n < end_n; 
++n) { + auto input_slice = input.narrow(0, n * batch_n, batch_n); + auto offset_slice = offset.narrow(0, n * batch_n, batch_n); + auto mask_slice = mask.narrow(0, n * batch_n, batch_n); + auto grad_output_g = grad_output_n.select(0, n); + + auto grad_input_slice = grad_input.narrow(0, n * batch_n, batch_n); + auto grad_offset_slice = grad_offset.narrow(0, n * batch_n, batch_n); + auto grad_mask_slice = grad_mask.narrow(0, n * batch_n, batch_n); + + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dcnv3_cpu_backward", [&] { + dcnv3::cpu::dcnv3_col2im_cpu( + grad_output_g, input_slice, offset_slice, mask_slice, + grad_input_slice, grad_offset_slice, grad_mask_slice, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, batch_n, + height_in, width_in, height_out, width_out, offset_scale, remove_center); + }); + } + }); + + return {grad_input, grad_offset, grad_mask}; } diff --git a/detection/ops_dcnv3/src/cpu/dcnv3_cpu.h b/detection/ops_dcnv3/src/cpu/dcnv3_cpu.h index d457bcbd..bbf219e5 100644 --- a/detection/ops_dcnv3/src/cpu/dcnv3_cpu.h +++ b/detection/ops_dcnv3/src/cpu/dcnv3_cpu.h @@ -1,11 +1,8 @@ /*! ************************************************************************************************** -* InternImage -* Copyright (c) 2022 OpenGVLab -* Licensed under The MIT License [see LICENSE for details] -************************************************************************************************** -* Modified from -*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +* MIT 开源协议 +* 版权所有 (c) 2024 hxaxd +* 详见 LICENSE 文件 ************************************************************************************************** */ @@ -19,7 +16,7 @@ at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int group_channels, const float offset_scale, - const int im2col_step); + const int im2col_step, const int remove_center); std::vector dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, @@ -28,4 +25,4 @@ dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int group_channels, const float offset_scale, - const at::Tensor &grad_output, const int im2col_step); + const at::Tensor &grad_output, const int im2col_step, const int remove_center); diff --git a/detection/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp b/detection/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp new file mode 100644 index 00000000..330756b4 --- /dev/null +++ b/detection/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp @@ -0,0 +1,297 @@ +/*! 
+************************************************************************************************** +* MIT 开源协议 +* 版权所有 (c) 2024 hxaxd +* 优化版本:高精度 opmath + Inline Bilinear(无硬编码展开,避免 bug) +************************************************************************************************** +*/ + +#include "dcnv3_im2col_cpu.h" + +#include +#include +#include +#include +#include + +namespace dcnv3 { + namespace cpu { + + template + inline opmath_t dcnv3_im2col_bilinear( + const scalar_t* bottom_data, + const int& height, const int& width, + const int& group, const int& group_channels, + const opmath_t& h, const opmath_t& w, + const int& g, const int& c + ) { + const int h_low = static_cast(std::floor(h)); + const int w_low = static_cast(std::floor(w)); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const opmath_t lh = h - h_low; + const opmath_t lw = w - w_low; + const opmath_t hh = opmath_t(1) - lh; + const opmath_t hw = opmath_t(1) - lw; + + const int w_stride = group * group_channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = g * group_channels + c; + + scalar_t v1 = scalar_t(0); + if (h_low >= 0 && w_low >= 0) { + v1 = bottom_data[h_low_ptr_offset + w_low_ptr_offset + base_ptr]; + } + scalar_t v2 = scalar_t(0); + if (h_low >= 0 && w_high <= width - 1) { + v2 = bottom_data[h_low_ptr_offset + w_high_ptr_offset + base_ptr]; + } + scalar_t v3 = scalar_t(0); + if (h_high <= height - 1 && w_low >= 0) { + v3 = bottom_data[h_high_ptr_offset + w_low_ptr_offset + base_ptr]; + } + scalar_t v4 = scalar_t(0); + if (h_high <= height - 1 && w_high <= width - 1) { + v4 = bottom_data[h_high_ptr_offset + w_high_ptr_offset + base_ptr]; + } + + const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + return w1 * static_cast(v1) + w2 * static_cast(v2) + + w3 * static_cast(v3) + w4 * static_cast(v4); + } + + void dcnv3_im2col_cpu( + const at::Tensor& data_im, const at::Tensor& data_offset, + const at::Tensor& data_mask, at::Tensor& data_col, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const int batch_n, const int height_in, const int width_in, + const int height_out, const int width_out, + const double offset_scale, const int remove_center + ) { + const int num_kernels = data_col.numel(); + + AT_DISPATCH_FLOATING_TYPES(data_im.scalar_type(), "dcnv3_im2col_cpu", [&] { + using opmath_t = at::opmath_type; + + const scalar_t* data_im_ptr = data_im.data_ptr(); + const scalar_t* data_offset_ptr = data_offset.data_ptr(); + const scalar_t* data_mask_ptr = data_mask.data_ptr(); + scalar_t* data_col_ptr = data_col.data_ptr(); + + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w - remove_center; + const int qid_stride = group * group_channels; + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int index = 0; index < num_kernels; ++index) { + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * 
stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + scalar_t* data_col_ptr_current = data_col_ptr + index; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const scalar_t* data_im_ptr_current = data_im_ptr + b_col * input_size * qid_stride; + + const opmath_t p0_w_ = static_cast(p0_w) - + ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = static_cast(p0_h) - + ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + opmath_t col = opmath_t(0); + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + if (i == center_w && j == center_h && remove_center) { + continue; + } + const opmath_t offset_w = static_cast(data_offset_ptr[data_loc_w_ptr]); + const opmath_t offset_h = static_cast(data_offset_ptr[data_loc_w_ptr + 1]); + const opmath_t loc_w = p0_w_ + (static_cast(i) * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = p0_h_ + (static_cast(j) * dilation_h + offset_h) * offset_scale; + const opmath_t weight = static_cast(data_mask_ptr[data_weight_ptr]); + + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && loc_w < width_in) { + col += dcnv3_im2col_bilinear( + data_im_ptr_current, height_in, width_in, group, + group_channels, loc_h, loc_w, g_col, c_col) * weight; + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + + *data_col_ptr_current = static_cast(col); + } + }); + } + + void dcnv3_col2im_cpu( + const at::Tensor& grad_col, const at::Tensor& data_im, + const at::Tensor& data_offset, const at::Tensor& data_mask, + at::Tensor& grad_im, at::Tensor& grad_offset, at::Tensor& grad_mask, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const int batch_n, const int height_in, const int width_in, + const int height_out, const int width_out, + const double offset_scale, const int remove_center + ) { + const int num_kernels = grad_col.numel(); + + AT_DISPATCH_FLOATING_TYPES(grad_col.scalar_type(), "dcnv3_col2im_cpu", [&] { + using opmath_t = at::opmath_type; + + const scalar_t* grad_col_ptr = grad_col.data_ptr(); + const scalar_t* data_im_ptr = data_im.data_ptr(); + const scalar_t* data_offset_ptr = data_offset.data_ptr(); + const scalar_t* data_mask_ptr = data_mask.data_ptr(); + + scalar_t* grad_im_ptr = grad_im.data_ptr(); + scalar_t* grad_offset_ptr = grad_offset.data_ptr(); + scalar_t* grad_mask_ptr = grad_mask.data_ptr(); + + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w - remove_center; + const int qid_stride = group * group_channels; + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int index = 0; index < num_kernels; ++index) { + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = static_cast(grad_col_ptr[index]); + int data_weight_ptr = sampling_index * kernel_size; + int 
data_loc_w_ptr = data_weight_ptr << 1; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t* data_im_ptr_current = data_im_ptr + im_ptr_offset; + scalar_t* grad_im_ptr_current = grad_im_ptr + im_ptr_offset; + + const opmath_t p0_w_ = static_cast(p0_w) - + ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = static_cast(p0_h) - + ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + if (i == center_w && j == center_h && remove_center) { + continue; + } + + const opmath_t offset_w = static_cast(data_offset_ptr[data_loc_w_ptr]); + const opmath_t offset_h = static_cast(data_offset_ptr[data_loc_w_ptr + 1]); + const opmath_t loc_w = p0_w_ + (static_cast(i) * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = p0_h_ + (static_cast(j) * dilation_h + offset_h) * offset_scale; + const opmath_t weight = static_cast(data_mask_ptr[data_weight_ptr]); + + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && loc_w < width_in) { + const int h_low = static_cast(std::floor(loc_h)); + const int w_low = static_cast(std::floor(loc_w)); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const opmath_t lh = loc_h - h_low; + const opmath_t lw = loc_w - w_low; + const opmath_t hh = opmath_t(1) - lh; + const opmath_t hw_ = opmath_t(1) - lw; + + const int w_stride = group * group_channels; + const int h_stride = width_in * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = g_col * group_channels + c_col; + + const opmath_t w1 = hh * hw_; + const opmath_t w2 = hh * lw; + const opmath_t w3 = lh * hw_; + const opmath_t w4 = lh * lw; + const opmath_t top_grad_im = top_grad * weight; + + opmath_t grad_h_weight = opmath_t(0); + opmath_t grad_w_weight = opmath_t(0); + + scalar_t v1 = scalar_t(0); + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = data_im_ptr_current[ptr1]; + grad_h_weight -= hw_ * static_cast(v1); + grad_w_weight -= hh * static_cast(v1); + grad_im_ptr_current[ptr1] += static_cast(w1 * top_grad_im); + } + scalar_t v2 = scalar_t(0); + if (h_low >= 0 && w_high <= width_in - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = data_im_ptr_current[ptr2]; + grad_h_weight -= lw * static_cast(v2); + grad_w_weight += hh * static_cast(v2); + grad_im_ptr_current[ptr2] += static_cast(w2 * top_grad_im); + } + scalar_t v3 = scalar_t(0); + if (h_high <= height_in - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = data_im_ptr_current[ptr3]; + grad_h_weight += hw_ * static_cast(v3); + grad_w_weight -= lh * static_cast(v3); + grad_im_ptr_current[ptr3] += static_cast(w3 * top_grad_im); + } + scalar_t v4 = scalar_t(0); + if (h_high <= height_in - 1 && w_high <= width_in - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = data_im_ptr_current[ptr4]; + grad_h_weight += lw * static_cast(v4); + grad_w_weight += lh * static_cast(v4); + grad_im_ptr_current[ptr4] += static_cast(w4 * top_grad_im); + } + + const opmath_t val = w1 * static_cast(v1) + w2 * static_cast(v2) + + w3 * static_cast(v3) + w4 * static_cast(v4); + + grad_mask_ptr[data_weight_ptr] += static_cast(top_grad * val); + grad_offset_ptr[data_loc_w_ptr] += 
static_cast(offset_scale * grad_w_weight * top_grad_im); + grad_offset_ptr[data_loc_w_ptr + 1] += static_cast(offset_scale * grad_h_weight * top_grad_im); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + } + }); + } + } +} diff --git a/detection/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h b/detection/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h new file mode 100644 index 00000000..479df187 --- /dev/null +++ b/detection/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h @@ -0,0 +1,68 @@ +/* +************************************************************************************************** +* MIT 开源协议 +* 版权所有 (c) 2024 hxaxd +* 详见 LICENSE 文件 +************************************************************************************************** +*/ + +#pragma once + +#include +#include + +namespace dcnv3 { + namespace cpu { + template + using opmath_t = double; + + template + scalar_t dcnv3_im2col_bilinear(const scalar_t* bottom_data, + const int& height, const int& width, + const int& group, + const int& group_channels, + const double& h, const double& w, + const int& g, const int& c); + + template + void dcnv3_col2im_bilinear( + const scalar_t* bottom_data, const int& height, const int& width, + const int& nheads, const int& group_channels, const double& h, + const double& w, const int& m, const int& c, const double offset_scale, + const double& top_grad, const double& mask, double* grad_im, + double* grad_offset, double* grad_mask); + + template + void dcnv3_col2im_bilinear_gm( + const scalar_t* bottom_data, const int& height, const int& width, + const int& nheads, const int& group_channels, const double& h, + const double& w, const int& m, const int& c, const double offset_scale, + const double& top_grad, const double& mask, double* grad_im, + double* grad_offset, double* grad_mask); + + void dcnv3_im2col_cpu( + const at::Tensor& data_im, const at::Tensor& data_offset, + const at::Tensor& data_mask, at::Tensor& data_col, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const int batch_n, const int height_in, const int width_in, + const int height_out, const int width_out, + const double offset_scale, const int remove_center); + + void dcnv3_col2im_cpu( + const at::Tensor& grad_col, const at::Tensor& data_im, + const at::Tensor& data_offset, const at::Tensor& data_mask, + at::Tensor& grad_im, at::Tensor& grad_offset, at::Tensor& grad_mask, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const int batch_n, const int height_in, const int width_in, + const int height_out, const int width_out, + const double offset_scale, const int remove_center); + } +} \ No newline at end of file diff --git a/detection/ops_dcnv3/src/dcnv3.h b/detection/ops_dcnv3/src/dcnv3.h index 029648e1..da9a9c0b 100644 --- a/detection/ops_dcnv3/src/dcnv3.h +++ b/detection/ops_dcnv3/src/dcnv3.h @@ -23,18 +23,21 @@ at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int group_channels, - const float offset_scale, const int im2col_step) { + const float offset_scale, const int im2col_step, const int remove_center) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return 
dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group, group_channels, - offset_scale, im2col_step); + offset_scale, im2col_step, remove_center); #else AT_ERROR("Not compiled with GPU support"); #endif } - AT_ERROR("Not implemented on the CPU"); + return dcnv3_cpu_forward(input, offset, mask, kernel_h, kernel_w, + stride_h, stride_w, pad_h, pad_w, dilation_h, + dilation_w, group, group_channels, + offset_scale, im2col_step, remove_center); } std::vector @@ -44,16 +47,19 @@ dcnv3_backward(const at::Tensor &input, const at::Tensor &offset, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int group_channels, const float offset_scale, const at::Tensor &grad_output, - const int im2col_step) { + const int im2col_step, const int remove_center) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group, group_channels, - offset_scale, grad_output, im2col_step); + offset_scale, grad_output, im2col_step, remove_center); #else AT_ERROR("Not compiled with GPU support"); #endif } - AT_ERROR("Not implemented on the CPU"); + return dcnv3_cpu_backward(input, offset, mask, kernel_h, kernel_w, + stride_h, stride_w, pad_h, pad_w, dilation_h, + dilation_w, group, group_channels, + offset_scale, grad_output, im2col_step, remove_center); } diff --git a/segmentation/ops_dcnv3/setup.py b/segmentation/ops_dcnv3/setup.py index e06f2ea9..7951c67f 100644 --- a/segmentation/ops_dcnv3/setup.py +++ b/segmentation/ops_dcnv3/setup.py @@ -24,7 +24,7 @@ def get_extensions(): sources = main_file + source_cpu extension = CppExtension - extra_compile_args = {'cxx': []} + extra_compile_args = {'cxx': ['-O2']} define_macros = [] if torch.cuda.is_available() and CUDA_HOME is not None: @@ -37,8 +37,9 @@ def get_extensions(): # "-D__CUDA_NO_HALF_CONVERSIONS__", # "-D__CUDA_NO_HALF2_OPERATORS__", ] + print("CUDA is available, building with CUDA support") else: - raise NotImplementedError('Cuda is not availabel') + print("CUDA is not available, building CPU-only version") sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] @@ -56,7 +57,7 @@ def get_extensions(): setup( name='DCNv3', - version='1.0', + version='1.1', author='InternImage', url='https://github.com/OpenGVLab/InternImage', description= diff --git a/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.cpp b/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.cpp index a3bddc18..a5b8ec29 100644 --- a/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.cpp +++ b/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.cpp @@ -1,18 +1,24 @@ /*! 
************************************************************************************************** -* InternImage -* Copyright (c) 2022 OpenGVLab -* Licensed under The MIT License [see LICENSE for details] -************************************************************************************************** -* Modified from -*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +* MIT 开源协议 +* 版权所有 (c) 2024 hxaxd +* 详见 LICENSE 文件 ************************************************************************************************** */ #include #include -#include +#include +#include + +#include "cpu/dcnv3_im2col_cpu.h" +#include "dcnv3_cpu.h" +#include "dcnv3_im2col_cpu.h" +#include +#include +#include +#include at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, const at::Tensor &mask, const int kernel_h, @@ -21,8 +27,63 @@ at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int group_channels, const float offset_scale, - const int im2col_step) { - AT_ERROR("Not implement on cpu"); + const int im2col_step, const int remove_center) { + AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); + AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); + AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); + AT_ASSERTM(input.is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(offset.is_cpu(), "offset must be a CPU tensor"); + AT_ASSERTM(mask.is_cpu(), "mask must be a CPU tensor"); + + const int batch = input.size(0); + const int height_in = input.size(1); + const int width_in = input.size(2); + const int channels = input.size(3); + const int height_out = + (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + + 1; + const int width_out = + (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + + 1; + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, + "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + AT_ASSERTM( + channels == (group * group_channels), + "Input channels and group times group channels wont match: (%d vs %d).", + channels, group * group_channels); + + auto output = + at::zeros({batch, height_out, width_out, group * group_channels}, + input.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch / batch_n, batch_n, height_out, + width_out, group * group_channels}); + + const int num_batches = batch / im2col_step_; + + at::parallel_for(0, num_batches, 0, [&](int64_t start_n, int64_t end_n) { + for (int64_t n = start_n; n < end_n; ++n) { + auto input_slice = input.narrow(0, n * batch_n, batch_n); + auto offset_slice = offset.narrow(0, n * batch_n, batch_n); + auto mask_slice = mask.narrow(0, n * batch_n, batch_n); + auto columns = output_n.select(0, n); + + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dcnv3_cpu_forward", [&] { + dcnv3::cpu::dcnv3_im2col_cpu( + input_slice, offset_slice, mask_slice, columns, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, + batch_n, height_in, width_in, height_out, width_out, + offset_scale, remove_center); + }); + } + }); + + return output; } std::vector @@ -32,6 +93,72 @@ dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int 
group_channels, const float offset_scale, - const at::Tensor &grad_output, const int im2col_step) { - AT_ERROR("Not implement on cpu"); + const at::Tensor &grad_output, const int im2col_step, const int remove_center) { + + AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); + AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); + AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), + "grad_output tensor has to be contiguous"); + AT_ASSERTM(input.is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(offset.is_cpu(), "offset must be a CPU tensor"); + AT_ASSERTM(mask.is_cpu(), "mask must be a CPU tensor"); + AT_ASSERTM(grad_output.is_cpu(), + "grad_output must be a CPU tensor"); + + const int batch = input.size(0); + const int height_in = input.size(1); + const int width_in = input.size(2); + const int channels = input.size(3); + const int height_out = + (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + + 1; + const int width_out = + (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + + 1; + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, + "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + AT_ASSERTM( + channels == (group * group_channels), + "Input channels and group times group channels wont match: (%d vs %d).", + channels, group * group_channels); + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + + const int batch_n = im2col_step_; + + auto grad_output_n = + grad_output.view({batch / im2col_step_, batch_n, height_out * width_out, + group, group_channels}); + + const int num_batches = batch / im2col_step_; + + at::parallel_for(0, num_batches, 0, [&](int64_t start_n, int64_t end_n) { + for (int64_t n = start_n; n < end_n; ++n) { + auto input_slice = input.narrow(0, n * batch_n, batch_n); + auto offset_slice = offset.narrow(0, n * batch_n, batch_n); + auto mask_slice = mask.narrow(0, n * batch_n, batch_n); + auto grad_output_g = grad_output_n.select(0, n); + + auto grad_input_slice = grad_input.narrow(0, n * batch_n, batch_n); + auto grad_offset_slice = grad_offset.narrow(0, n * batch_n, batch_n); + auto grad_mask_slice = grad_mask.narrow(0, n * batch_n, batch_n); + + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dcnv3_cpu_backward", [&] { + dcnv3::cpu::dcnv3_col2im_cpu( + grad_output_g, input_slice, offset_slice, mask_slice, + grad_input_slice, grad_offset_slice, grad_mask_slice, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, batch_n, + height_in, width_in, height_out, width_out, offset_scale, remove_center); + }); + } + }); + + return {grad_input, grad_offset, grad_mask}; } diff --git a/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.h b/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.h index d457bcbd..bbf219e5 100644 --- a/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.h +++ b/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.h @@ -1,11 +1,8 @@ /*! 
diff --git a/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.h b/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.h
index d457bcbd..bbf219e5 100644
--- a/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.h
+++ b/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.h
@@ -1,11 +1,8 @@
 /*!
 **************************************************************************************************
-* InternImage
-* Copyright (c) 2022 OpenGVLab
-* Licensed under The MIT License [see LICENSE for details]
-**************************************************************************************************
-* Modified from
-*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+* MIT License
+* Copyright (c) 2024 hxaxd
+* See the LICENSE file for details
 **************************************************************************************************
 */
 
@@ -19,7 +16,7 @@ at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
                             const int pad_w, const int dilation_h,
                             const int dilation_w, const int group,
                             const int group_channels, const float offset_scale,
-                            const int im2col_step);
+                            const int im2col_step, const int remove_center);
 
 std::vector<at::Tensor>
 dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
@@ -28,4 +25,4 @@ dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
                    const int pad_h, const int pad_w, const int dilation_h,
                    const int dilation_w, const int group,
                    const int group_channels, const float offset_scale,
-                   const at::Tensor &grad_output, const int im2col_step);
+                   const at::Tensor &grad_output, const int im2col_step, const int remove_center);
diff --git a/segmentation/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp b/segmentation/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp
new file mode 100644
index 00000000..330756b4
--- /dev/null
+++ b/segmentation/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.cpp
@@ -0,0 +1,297 @@
+/*!
+**************************************************************************************************
+* MIT License
+* Copyright (c) 2024 hxaxd
+* Optimized version: high-precision opmath accumulation + inline bilinear sampling (no hard-coded unrolling, to avoid bugs)
+**************************************************************************************************
+*/
+
+#include "dcnv3_im2col_cpu.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace dcnv3 {
+namespace cpu {
+
+template <typename scalar_t, typename opmath_t>
+inline opmath_t dcnv3_im2col_bilinear(const scalar_t *bottom_data,
+                                      const int &height, const int &width,
+                                      const int &group,
+                                      const int &group_channels,
+                                      const opmath_t &h, const opmath_t &w,
+                                      const int &g, const int &c) {
+    const int h_low = static_cast<int>(std::floor(h));
+    const int w_low = static_cast<int>(std::floor(w));
+    const int h_high = h_low + 1;
+    const int w_high = w_low + 1;
+
+    const opmath_t lh = h - h_low;
+    const opmath_t lw = w - w_low;
+    const opmath_t hh = opmath_t(1) - lh;
+    const opmath_t hw = opmath_t(1) - lw;
+
+    const int w_stride = group * group_channels;
+    const int h_stride = width * w_stride;
+    const int h_low_ptr_offset = h_low * h_stride;
+    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+    const int w_low_ptr_offset = w_low * w_stride;
+    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+    const int base_ptr = g * group_channels + c;
+
+    scalar_t v1 = scalar_t(0);
+    if (h_low >= 0 && w_low >= 0) {
+        v1 = bottom_data[h_low_ptr_offset + w_low_ptr_offset + base_ptr];
+    }
+    scalar_t v2 = scalar_t(0);
+    if (h_low >= 0 && w_high <= width - 1) {
+        v2 = bottom_data[h_low_ptr_offset + w_high_ptr_offset + base_ptr];
+    }
+    scalar_t v3 = scalar_t(0);
+    if (h_high <= height - 1 && w_low >= 0) {
+        v3 = bottom_data[h_high_ptr_offset + w_low_ptr_offset + base_ptr];
+    }
+    scalar_t v4 = scalar_t(0);
+    if (h_high <= height - 1 && w_high <= width - 1) {
+        v4 = bottom_data[h_high_ptr_offset + w_high_ptr_offset + base_ptr];
+    }
+
+    const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw,
+                   w4 = lh * lw;
+    return w1 * static_cast<opmath_t>(v1) + w2 * static_cast<opmath_t>(v2) +
+           w3 * static_cast<opmath_t>(v3) + w4 * static_cast<opmath_t>(v4);
+}
+
+void dcnv3_im2col_cpu(
+    const at::Tensor &data_im, const at::Tensor &data_offset,
+    const at::Tensor &data_mask, at::Tensor &data_col,
+    const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w,
+    const int group, const int group_channels,
+    const int batch_n, const int height_in, const int width_in,
+    const int height_out, const int width_out,
+    const double offset_scale, const int remove_center) {
+    const int num_kernels = data_col.numel();
+
+    AT_DISPATCH_FLOATING_TYPES(data_im.scalar_type(), "dcnv3_im2col_cpu", [&] {
+        using opmath_t = at::opmath_type<scalar_t>;
+
+        const scalar_t *data_im_ptr = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ptr = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ptr = data_mask.data_ptr<scalar_t>();
+        scalar_t *data_col_ptr = data_col.data_ptr<scalar_t>();
+
+        const int input_size = height_in * width_in;
+        const int kernel_size = kernel_h * kernel_w - remove_center;
+        const int qid_stride = group * group_channels;
+        const int center_h = kernel_h / 2;
+        const int center_w = kernel_w / 2;
+
+        for (int index = 0; index < num_kernels; ++index) {
+            int _temp = index;
+            const int c_col = _temp % group_channels;
+            _temp /= group_channels;
+            const int sampling_index = _temp;
+            const int g_col = _temp % group;
+            _temp /= group;
+            const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
+                             (_temp % width_out) * stride_w;
+            _temp /= width_out;
+            const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
+                             (_temp % height_out) * stride_h;
+            _temp /= height_out;
+            const int b_col = _temp;
+
+            scalar_t *data_col_ptr_current = data_col_ptr + index;
+            int data_weight_ptr = sampling_index * kernel_size;
+            int data_loc_w_ptr = data_weight_ptr << 1;
+            const scalar_t *data_im_ptr_current =
+                data_im_ptr + b_col * input_size * qid_stride;
+
+            const opmath_t p0_w_ =
+                static_cast<opmath_t>(p0_w) -
+                ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
+            const opmath_t p0_h_ =
+                static_cast<opmath_t>(p0_h) -
+                ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
+
+            opmath_t col = opmath_t(0);
+
+            for (int i = 0; i < kernel_w; ++i) {
+                for (int j = 0; j < kernel_h; ++j) {
+                    if (i == center_w && j == center_h && remove_center) {
+                        continue;
+                    }
+                    const opmath_t offset_w =
+                        static_cast<opmath_t>(data_offset_ptr[data_loc_w_ptr]);
+                    const opmath_t offset_h =
+                        static_cast<opmath_t>(data_offset_ptr[data_loc_w_ptr + 1]);
+                    const opmath_t loc_w =
+                        p0_w_ + (static_cast<opmath_t>(i) * dilation_w + offset_w) * offset_scale;
+                    const opmath_t loc_h =
+                        p0_h_ + (static_cast<opmath_t>(j) * dilation_h + offset_h) * offset_scale;
+                    const opmath_t weight =
+                        static_cast<opmath_t>(data_mask_ptr[data_weight_ptr]);
+
+                    if (loc_h > -1 && loc_w > -1 && loc_h < height_in && loc_w < width_in) {
+                        col += dcnv3_im2col_bilinear<scalar_t, opmath_t>(
+                                   data_im_ptr_current, height_in, width_in,
+                                   group, group_channels, loc_h, loc_w,
+                                   g_col, c_col) *
+                               weight;
+                    }
+                    data_weight_ptr += 1;
+                    data_loc_w_ptr += 2;
+                }
+            }
+
+            *data_col_ptr_current = static_cast<scalar_t>(col);
+        }
+    });
+}
+
+void dcnv3_col2im_cpu(
+    const at::Tensor &grad_col, const at::Tensor &data_im,
+    const at::Tensor &data_offset, const at::Tensor &data_mask,
+    at::Tensor &grad_im, at::Tensor &grad_offset, at::Tensor &grad_mask,
+    const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w,
+    const int group, const int group_channels,
+    const int batch_n, const int height_in, const int width_in,
+    const int height_out, const int width_out,
+    const double offset_scale, const int remove_center) {
+    const int num_kernels = grad_col.numel();
+
+    AT_DISPATCH_FLOATING_TYPES(grad_col.scalar_type(), "dcnv3_col2im_cpu", [&] {
+        using opmath_t = at::opmath_type<scalar_t>;
+
+        const scalar_t *grad_col_ptr = grad_col.data_ptr<scalar_t>();
+        const scalar_t *data_im_ptr = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ptr = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ptr = data_mask.data_ptr<scalar_t>();
+
+        scalar_t *grad_im_ptr = grad_im.data_ptr<scalar_t>();
+        scalar_t *grad_offset_ptr = grad_offset.data_ptr<scalar_t>();
+        scalar_t *grad_mask_ptr = grad_mask.data_ptr<scalar_t>();
+
+        const int input_size = height_in * width_in;
+        const int kernel_size = kernel_h * kernel_w - remove_center;
+        const int qid_stride = group * group_channels;
+        const int center_h = kernel_h / 2;
+        const int center_w = kernel_w / 2;
+
+        for (int index = 0; index < num_kernels; ++index) {
+            int _temp = index;
+            const int c_col = _temp % group_channels;
+            _temp /= group_channels;
+            const int sampling_index = _temp;
+            const int g_col = _temp % group;
+            _temp /= group;
+            const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
+                             (_temp % width_out) * stride_w;
+            _temp /= width_out;
+            const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
+                             (_temp % height_out) * stride_h;
+            _temp /= height_out;
+            const int b_col = _temp;
+
+            const opmath_t top_grad = static_cast<opmath_t>(grad_col_ptr[index]);
+            int data_weight_ptr = sampling_index * kernel_size;
+            int data_loc_w_ptr = data_weight_ptr << 1;
+            const int im_ptr_offset = b_col * input_size * qid_stride;
+            const scalar_t *data_im_ptr_current = data_im_ptr + im_ptr_offset;
+            scalar_t *grad_im_ptr_current = grad_im_ptr + im_ptr_offset;
+
+            const opmath_t p0_w_ =
+                static_cast<opmath_t>(p0_w) -
+                ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
+            const opmath_t p0_h_ =
+                static_cast<opmath_t>(p0_h) -
+                ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
+
+            for (int i = 0; i < kernel_w; ++i) {
+                for (int j = 0; j < kernel_h; ++j) {
+                    if (i == center_w && j == center_h && remove_center) {
+                        continue;
+                    }
+
+                    const opmath_t offset_w =
+                        static_cast<opmath_t>(data_offset_ptr[data_loc_w_ptr]);
+                    const opmath_t offset_h =
+                        static_cast<opmath_t>(data_offset_ptr[data_loc_w_ptr + 1]);
+                    const opmath_t loc_w =
+                        p0_w_ + (static_cast<opmath_t>(i) * dilation_w + offset_w) * offset_scale;
+                    const opmath_t loc_h =
+                        p0_h_ + (static_cast<opmath_t>(j) * dilation_h + offset_h) * offset_scale;
+                    const opmath_t weight =
+                        static_cast<opmath_t>(data_mask_ptr[data_weight_ptr]);
+
+                    if (loc_h > -1 && loc_w > -1 && loc_h < height_in && loc_w < width_in) {
+                        const int h_low = static_cast<int>(std::floor(loc_h));
+                        const int w_low = static_cast<int>(std::floor(loc_w));
+                        const int h_high = h_low + 1;
+                        const int w_high = w_low + 1;
+
+                        const opmath_t lh = loc_h - h_low;
+                        const opmath_t lw = loc_w - w_low;
+                        const opmath_t hh = opmath_t(1) - lh;
+                        const opmath_t hw_ = opmath_t(1) - lw;
+
+                        const int w_stride = group * group_channels;
+                        const int h_stride = width_in * w_stride;
+                        const int h_low_ptr_offset = h_low * h_stride;
+                        const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+                        const int w_low_ptr_offset = w_low * w_stride;
+                        const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+                        const int base_ptr = g_col * group_channels + c_col;
+
+                        const opmath_t w1 = hh * hw_;
+                        const opmath_t w2 = hh * lw;
+                        const opmath_t w3 = lh * hw_;
+                        const opmath_t w4 = lh * lw;
+                        const opmath_t top_grad_im = top_grad * weight;
+
+                        opmath_t grad_h_weight = opmath_t(0);
+                        opmath_t grad_w_weight = opmath_t(0);
+
+                        scalar_t v1 = scalar_t(0);
+                        if (h_low >= 0 && w_low >= 0) {
+                            const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+                            v1 = data_im_ptr_current[ptr1];
+                            grad_h_weight -= hw_ * static_cast<opmath_t>(v1);
+                            grad_w_weight -= hh * static_cast<opmath_t>(v1);
+                            grad_im_ptr_current[ptr1] += static_cast<scalar_t>(w1 * top_grad_im);
+                        }
+                        scalar_t v2 = scalar_t(0);
+                        if (h_low >= 0 && w_high <= width_in - 1) {
+                            const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+                            v2 = data_im_ptr_current[ptr2];
+                            grad_h_weight -= lw * static_cast<opmath_t>(v2);
+                            grad_w_weight += hh * static_cast<opmath_t>(v2);
+                            grad_im_ptr_current[ptr2] += static_cast<scalar_t>(w2 * top_grad_im);
+                        }
+                        scalar_t v3 = scalar_t(0);
+                        if (h_high <= height_in - 1 && w_low >= 0) {
+                            const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+                            v3 = data_im_ptr_current[ptr3];
+                            grad_h_weight += hw_ * static_cast<opmath_t>(v3);
+                            grad_w_weight -= lh * static_cast<opmath_t>(v3);
+                            grad_im_ptr_current[ptr3] += static_cast<scalar_t>(w3 * top_grad_im);
+                        }
+                        scalar_t v4 = scalar_t(0);
+                        if (h_high <= height_in - 1 && w_high <= width_in - 1) {
+                            const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+                            v4 = data_im_ptr_current[ptr4];
+                            grad_h_weight += lw * static_cast<opmath_t>(v4);
+                            grad_w_weight += lh * static_cast<opmath_t>(v4);
+                            grad_im_ptr_current[ptr4] += static_cast<scalar_t>(w4 * top_grad_im);
+                        }
+
+                        const opmath_t val =
+                            w1 * static_cast<opmath_t>(v1) + w2 * static_cast<opmath_t>(v2) +
+                            w3 * static_cast<opmath_t>(v3) + w4 * static_cast<opmath_t>(v4);
+
+                        grad_mask_ptr[data_weight_ptr] +=
+                            static_cast<scalar_t>(top_grad * val);
+                        grad_offset_ptr[data_loc_w_ptr] +=
+                            static_cast<scalar_t>(offset_scale * grad_w_weight * top_grad_im);
+                        grad_offset_ptr[data_loc_w_ptr + 1] +=
+                            static_cast<scalar_t>(offset_scale * grad_h_weight * top_grad_im);
+                    }
+                    data_weight_ptr += 1;
+                    data_loc_w_ptr += 2;
+                }
+            }
+        }
+    });
+}
+
+} // namespace cpu
+} // namespace dcnv3
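Note on the sampling math: dcnv3_im2col_bilinear blends the four pixels surrounding each (possibly fractional) deformable location, and corners that fall outside the image contribute zero; the backward kernel above re-derives the same four corner weights to scatter gradients. A rough NumPy sketch of that corner weighting (a single-channel (H, W) array only; the group/channel strides handled in the C++ are omitted):

    import numpy as np

    # Reproduces the w1..w4 corner weights used by the CPU kernels.
    def bilinear_sample(img, h, w):
        h_low, w_low = int(np.floor(h)), int(np.floor(w))
        lh, lw = h - h_low, w - w_low
        hh, hw = 1.0 - lh, 1.0 - lw
        H, W = img.shape

        def corner(y, x):
            # Out-of-bounds corners contribute zero, as in the C++ code.
            return float(img[y, x]) if 0 <= y < H and 0 <= x < W else 0.0

        return (hh * hw * corner(h_low, w_low) + hh * lw * corner(h_low, w_low + 1)
                + lh * hw * corner(h_low + 1, w_low) + lh * lw * corner(h_low + 1, w_low + 1))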
diff --git a/segmentation/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h b/segmentation/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h
new file mode 100644
index 00000000..479df187
--- /dev/null
+++ b/segmentation/ops_dcnv3/src/cpu/dcnv3_im2col_cpu.h
@@ -0,0 +1,68 @@
+/*
+**************************************************************************************************
+* MIT License
+* Copyright (c) 2024 hxaxd
+* See the LICENSE file for details
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include
+#include
+
+namespace dcnv3 {
+namespace cpu {
+
+template <typename scalar_t>
+using opmath_t = double;
+
+template <typename scalar_t>
+scalar_t dcnv3_im2col_bilinear(const scalar_t *bottom_data,
+                               const int &height, const int &width,
+                               const int &group,
+                               const int &group_channels,
+                               const double &h, const double &w,
+                               const int &g, const int &c);
+
+template <typename scalar_t>
+void dcnv3_col2im_bilinear(
+    const scalar_t *bottom_data, const int &height, const int &width,
+    const int &nheads, const int &group_channels, const double &h,
+    const double &w, const int &m, const int &c, const double offset_scale,
+    const double &top_grad, const double &mask, double *grad_im,
+    double *grad_offset, double *grad_mask);
+
+template <typename scalar_t>
+void dcnv3_col2im_bilinear_gm(
+    const scalar_t *bottom_data, const int &height, const int &width,
+    const int &nheads, const int &group_channels, const double &h,
+    const double &w, const int &m, const int &c, const double offset_scale,
+    const double &top_grad, const double &mask, double *grad_im,
+    double *grad_offset, double *grad_mask);
+
+void dcnv3_im2col_cpu(
+    const at::Tensor &data_im, const at::Tensor &data_offset,
+    const at::Tensor &data_mask, at::Tensor &data_col,
+    const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w,
+    const int group, const int group_channels,
+    const int batch_n, const int height_in, const int width_in,
+    const int height_out, const int width_out,
+    const double offset_scale, const int remove_center);
+
+void dcnv3_col2im_cpu(
+    const at::Tensor &grad_col, const at::Tensor &data_im,
+    const at::Tensor &data_offset, const at::Tensor &data_mask,
+    at::Tensor &grad_im, at::Tensor &grad_offset, at::Tensor &grad_mask,
+    const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w,
+    const int group, const int group_channels,
+    const int batch_n, const int height_in, const int width_in,
+    const int height_out, const int width_out,
+    const double offset_scale, const int remove_center);
+
+} // namespace cpu
+} // namespace dcnv3
\ No newline at end of file
diff --git a/segmentation/ops_dcnv3/src/dcnv3.h b/segmentation/ops_dcnv3/src/dcnv3.h
index 029648e1..da9a9c0b 100644
--- a/segmentation/ops_dcnv3/src/dcnv3.h
+++ b/segmentation/ops_dcnv3/src/dcnv3.h
@@ -23,18 +23,21 @@ at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset,
                          const int stride_w, const int pad_h, const int pad_w,
                          const int dilation_h, const int dilation_w,
                          const int group, const int group_channels,
-                         const float offset_scale, const int im2col_step) {
+                         const float offset_scale, const int im2col_step, const int remove_center) {
     if (input.type().is_cuda()) {
 #ifdef WITH_CUDA
         return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w,
                                   stride_h, stride_w, pad_h, pad_w, dilation_h,
                                   dilation_w, group, group_channels,
-                                  offset_scale, im2col_step);
+                                  offset_scale, im2col_step, remove_center);
 #else
         AT_ERROR("Not compiled with GPU support");
 #endif
     }
-    AT_ERROR("Not implemented on the CPU");
+    return dcnv3_cpu_forward(input, offset, mask, kernel_h, kernel_w,
+                             stride_h, stride_w, pad_h, pad_w, dilation_h,
+                             dilation_w, group, group_channels,
+                             offset_scale, im2col_step, remove_center);
 }
 
 std::vector<at::Tensor>
@@ -44,16 +47,19 @@ dcnv3_backward(const at::Tensor &input, const at::Tensor &offset,
               const int pad_w, const int dilation_h, const int dilation_w,
               const int group, const int group_channels,
               const float offset_scale, const at::Tensor &grad_output,
-              const int im2col_step) {
+              const int im2col_step, const int remove_center) {
     if (input.type().is_cuda()) {
 #ifdef WITH_CUDA
         return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w,
                                    stride_h, stride_w, pad_h, pad_w, dilation_h,
                                    dilation_w, group, group_channels,
-                                   offset_scale, grad_output, im2col_step);
+                                   offset_scale, grad_output, im2col_step, remove_center);
 #else
         AT_ERROR("Not compiled with GPU support");
 #endif
     }
-    AT_ERROR("Not implemented on the CPU");
+    return dcnv3_cpu_backward(input, offset, mask, kernel_h, kernel_w,
+                              stride_h, stride_w, pad_h, pad_w, dilation_h,
+                              dilation_w, group, group_channels,
+                              offset_scale, grad_output, im2col_step, remove_center);
 }
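With dcnv3.h now routing non-CUDA tensors to the new CPU kernels, the extension can be exercised without a GPU. The snippet below is only a smoke-test sketch under stated assumptions: the DCNv3 module name comes from this repo's setup.py, the positional argument order follows dcnv3_forward in dcnv3.h above, and the tensor layouts are read off the kernels (channel-last input of shape (N, H, W, group * group_channels); offset and mask carry group * (kernel_h * kernel_w - remove_center) sampling points per output location).

    import torch
    import DCNv3  # the compiled extension, e.g. built via `python setup.py build install`

    N, H, W = 2, 8, 8
    group, group_channels = 4, 16
    kh = kw = 3
    remove_center = 0
    pts = group * (kh * kw - remove_center)

    x = torch.randn(N, H, W, group * group_channels)
    offset = 0.1 * torch.randn(N, H, W, 2 * pts)  # stride/pad/dilation = 1 keeps H and W
    mask = torch.rand(N, H, W, pts)

    out = DCNv3.dcnv3_forward(x, offset, mask, kh, kw, 1, 1, 1, 1, 1, 1,
                              group, group_channels, 1.0, 2, remove_center)
    print(out.shape)  # expected: torch.Size([2, 8, 8, 64])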