5 changes: 3 additions & 2 deletions classification/ops_dcnv3/setup.py
@@ -24,7 +24,7 @@ def get_extensions():

sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {'cxx': []}
extra_compile_args = {'cxx': ['-O2']}
define_macros = []

if torch.cuda.is_available() and CUDA_HOME is not None:
@@ -37,8 +37,9 @@ def get_extensions():
# "-D__CUDA_NO_HALF_CONVERSIONS__",
# "-D__CUDA_NO_HALF2_OPERATORS__",
]
print("CUDA is available, building with CUDA support")
else:
raise NotImplementedError('Cuda is not availabel')
print("CUDA is not available, building CPU-only version")

sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
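For context, the effect of the setup.py change is that a missing CUDA toolchain no longer aborts the build. Below is a minimal sketch of how the modified `get_extensions()` branch might look; the glob-based source collection and the extension name `DCNv3` are assumptions about the surrounding file, not a verbatim copy.

```python
import glob
import os

import torch
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension


def get_extensions(extensions_dir="classification/ops_dcnv3/src"):
    # Collect the dispatcher and CPU sources; CUDA sources are optional.
    sources = glob.glob(os.path.join(extensions_dir, "*.cpp")) + \
              glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))

    extension = CppExtension                 # default: CPU-only build
    extra_compile_args = {"cxx": ["-O2"]}
    define_macros = []

    if torch.cuda.is_available() and CUDA_HOME is not None:
        extension = CUDAExtension           # upgrade to a CUDA build
        sources += glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = []
        print("CUDA is available, building with CUDA support")
    else:
        # The old code raised NotImplementedError here; the PR turns this
        # into a graceful CPU-only fallback instead.
        print("CUDA is not available, building CPU-only version")

    return [
        extension(
            "DCNv3",                         # extension name (assumed)
            sources,
            include_dirs=[extensions_dir],
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
```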
149 changes: 138 additions & 11 deletions classification/ops_dcnv3/src/cpu/dcnv3_cpu.cpp
@@ -1,18 +1,24 @@
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
* MIT License
* Copyright (c) 2024 hxaxd
* See the LICENSE file for details
**************************************************************************************************
*/

#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>

#include "cpu/dcnv3_im2col_cpu.h"
#include "dcnv3_cpu.h"
#include "dcnv3_im2col_cpu.h"
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <iostream>
#include <cmath>

at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
const at::Tensor &mask, const int kernel_h,
@@ -21,8 +27,63 @@ at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
const int pad_w, const int dilation_h,
const int dilation_w, const int group,
const int group_channels, const float offset_scale,
const int im2col_step) {
AT_ERROR("Not implement on cpu");
const int im2col_step, const int remove_center) {
AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
AT_ASSERTM(input.is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(offset.is_cpu(), "offset must be a CPU tensor");
AT_ASSERTM(mask.is_cpu(), "mask must be a CPU tensor");

const int batch = input.size(0);
const int height_in = input.size(1);
const int width_in = input.size(2);
const int channels = input.size(3);
const int height_out =
(height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
1;
const int width_out =
(width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int im2col_step_ = std::min(batch, im2col_step);

AT_ASSERTM(batch % im2col_step_ == 0,
           "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
AT_ASSERTM(
channels == (group * group_channels),
"Input channels and group times group channels wont match: (%d vs %d).",
channels, group * group_channels);

auto output =
at::zeros({batch, height_out, width_out, group * group_channels},
input.options());

const int batch_n = im2col_step_;
auto output_n = output.view({batch / batch_n, batch_n, height_out,
width_out, group * group_channels});

const int num_batches = batch / im2col_step_;

at::parallel_for(0, num_batches, 0, [&](int64_t start_n, int64_t end_n) {
for (int64_t n = start_n; n < end_n; ++n) {
auto input_slice = input.narrow(0, n * batch_n, batch_n);
auto offset_slice = offset.narrow(0, n * batch_n, batch_n);
auto mask_slice = mask.narrow(0, n * batch_n, batch_n);
auto columns = output_n.select(0, n);

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dcnv3_cpu_forward", [&] {
dcnv3::cpu::dcnv3_im2col_cpu(
input_slice, offset_slice, mask_slice, columns,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels,
batch_n, height_in, width_in, height_out, width_out,
offset_scale, remove_center);
});
}
});

return output;
}

std::vector<at::Tensor>
@@ -32,6 +93,72 @@ dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group,
const int group_channels, const float offset_scale,
const at::Tensor &grad_output, const int im2col_step) {
AT_ERROR("Not implement on cpu");
const at::Tensor &grad_output, const int im2col_step, const int remove_center) {

AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(),
"grad_output tensor has to be contiguous");
AT_ASSERTM(input.is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(offset.is_cpu(), "offset must be a CPU tensor");
AT_ASSERTM(mask.is_cpu(), "mask must be a CPU tensor");
AT_ASSERTM(grad_output.is_cpu(),
"grad_output must be a CPU tensor");

const int batch = input.size(0);
const int height_in = input.size(1);
const int width_in = input.size(2);
const int channels = input.size(3);
const int height_out =
(height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
1;
const int width_out =
(width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int im2col_step_ = std::min(batch, im2col_step);

AT_ASSERTM(batch % im2col_step_ == 0,
           "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
AT_ASSERTM(
channels == (group * group_channels),
"Input channels and group times group channels wont match: (%d vs %d).",
channels, group * group_channels);

auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset);
auto grad_mask = at::zeros_like(mask);

const int batch_n = im2col_step_;

auto grad_output_n =
grad_output.view({batch / im2col_step_, batch_n, height_out * width_out,
group, group_channels});

const int num_batches = batch / im2col_step_;

at::parallel_for(0, num_batches, 0, [&](int64_t start_n, int64_t end_n) {
for (int64_t n = start_n; n < end_n; ++n) {
auto input_slice = input.narrow(0, n * batch_n, batch_n);
auto offset_slice = offset.narrow(0, n * batch_n, batch_n);
auto mask_slice = mask.narrow(0, n * batch_n, batch_n);
auto grad_output_g = grad_output_n.select(0, n);

auto grad_input_slice = grad_input.narrow(0, n * batch_n, batch_n);
auto grad_offset_slice = grad_offset.narrow(0, n * batch_n, batch_n);
auto grad_mask_slice = grad_mask.narrow(0, n * batch_n, batch_n);

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dcnv3_cpu_backward", [&] {
dcnv3::cpu::dcnv3_col2im_cpu(
grad_output_g, input_slice, offset_slice, mask_slice,
grad_input_slice, grad_offset_slice, grad_mask_slice,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, batch_n,
height_in, width_in, height_out, width_out, offset_scale, remove_center);
});
}
});

return {grad_input, grad_offset, grad_mask};
}
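The new CPU forward reads channels-last tensors (batch, height, width, channels) and processes the batch in chunks of `im2col_step`. A quick smoke test of the resulting binding might look like the sketch below; it assumes the compiled extension is importable as `DCNv3`, that `DCNv3.dcnv3_forward` mirrors the positional C++ signature above, and that offset/mask follow the usual DCNv3 shape convention. All of these are assumptions rather than documented behaviour.

```python
import torch
import DCNv3  # compiled extension module (assumed name)

N, H, W = 2, 8, 8
group, group_channels = 4, 16
kernel_h = kernel_w = 3
stride_h = stride_w = 1
pad_h = pad_w = 1
dilation_h = dilation_w = 1
offset_scale, im2col_step, remove_center = 1.0, 2, 0

# Output size formula used by dcnv3_cpu_forward above.
H_out = (H + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
W_out = (W + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
K = kernel_h * kernel_w  # sampling points per group (remove_center == 0)

x = torch.randn(N, H, W, group * group_channels)      # channels-last input
offset = torch.randn(N, H_out, W_out, group * K * 2)  # (dx, dy) per point
mask = torch.softmax(torch.randn(N, H_out, W_out, group, K), dim=-1)
mask = mask.reshape(N, H_out, W_out, group * K).contiguous()

out = DCNv3.dcnv3_forward(
    x, offset, mask,
    kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
    dilation_h, dilation_w, group, group_channels,
    offset_scale, im2col_step, remove_center,
)
print(out.shape)  # expected: torch.Size([2, 8, 8, 64])
```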
13 changes: 5 additions & 8 deletions classification/ops_dcnv3/src/cpu/dcnv3_cpu.h
@@ -1,11 +1,8 @@
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
* MIT License
* Copyright (c) 2024 hxaxd
* See the LICENSE file for details
**************************************************************************************************
*/

@@ -19,7 +16,7 @@ at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
const int pad_w, const int dilation_h,
const int dilation_w, const int group,
const int group_channels, const float offset_scale,
const int im2col_step);
const int im2col_step, const int remove_center);

std::vector<at::Tensor>
dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
@@ -28,4 +25,4 @@ dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group,
const int group_channels, const float offset_scale,
const at::Tensor &grad_output, const int im2col_step);
const at::Tensor &grad_output, const int im2col_step, const int remove_center);
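Similarly, the new CPU backward can be exercised directly. The sketch below assumes a `DCNv3.dcnv3_backward` binding that follows the positional signature declared here (with `grad_output` placed after `offset_scale`); the module and argument order are assumptions, and the test only checks that the three returned gradients match the shapes of `input`, `offset`, and `mask`.

```python
import torch
import DCNv3  # compiled extension module (assumed name)

N, H, W = 2, 8, 8
group, group_channels = 4, 16
kernel_h = kernel_w = 3
stride_h = stride_w = 1
pad_h = pad_w = 1
dilation_h = dilation_w = 1
offset_scale, im2col_step, remove_center = 1.0, 2, 0

H_out = (H + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
W_out = (W + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
K = kernel_h * kernel_w

x = torch.randn(N, H, W, group * group_channels)
offset = torch.randn(N, H_out, W_out, group * K * 2)
mask = torch.softmax(torch.randn(N, H_out, W_out, group, K), dim=-1)
mask = mask.reshape(N, H_out, W_out, group * K).contiguous()
grad_out = torch.randn(N, H_out, W_out, group * group_channels)

grad_input, grad_offset, grad_mask = DCNv3.dcnv3_backward(
    x, offset, mask,
    kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
    dilation_h, dilation_w, group, group_channels,
    offset_scale, grad_out, im2col_step, remove_center,
)
# Gradients come back in the same channels-last shapes as their tensors.
assert grad_input.shape == x.shape
assert grad_offset.shape == offset.shape
assert grad_mask.shape == mask.shape
```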