Commit 78b7873

issue/360: success conv
1 parent 90bc9a7 commit 78b7873

2 files changed: +30 -73 lines changed

src/infiniop/ops/conv/kunlun/conv_kunlun.cc

Lines changed: 12 additions & 46 deletions
@@ -73,39 +73,16 @@ infiniStatus_t conv_kernel(
         int64_t stride = (int64_t)info.stride_info(0);
         std::initializer_list<int64_t> pad = {(int64_t)info.pad_info(0)};
         int64_t dilation = (int64_t)info.dilation_info(0);
-        printf("x_shape:(%ld, %ld, %ld)\n", info.batch(), info.in_channels(), info.input_dim(0));
-        printf("kernel_dim:(%ld)\n", ksize);
-        printf("stride:(%ld)\n", stride);
-        printf("pad:(%ld)\n", (int64_t)info.pad_info(0));
-        printf("dilation:(%ld)\n", dilation);
-        std::cout << "ndim: " << info.ndim() << " bias_size: " << bias_size << std::endl;
+
         if (dtype == INFINI_DTYPE_F16) {
-            // float16 *host_x, *host_w, *host_bias;
-            // host_x = (float16 *)malloc((int)info.batch() * (int)info.in_channels() * (int)info.input_dim(0) * sizeof(float16));
-            // host_w = (float16 *)malloc((int)bias_size * (int)info.in_channels() * (int)info.kernel_dim(0) * sizeof(float16));
-            // host_bias = (float16 *)malloc((int)bias_size * sizeof(float16));
-            // xpu_memcpy(host_x, x, (int)info.batch() * (int)info.in_channels() * (int)info.input_dim(0) * sizeof(float16), XPU_DEVICE_TO_HOST);
-            // xpu_memcpy(host_w, w, (int)bias_size * (int)info.in_channels() * (int)info.kernel_dim(0) * sizeof(float16), XPU_DEVICE_TO_HOST);
-            // xpu_memcpy(host_bias, bias, (int)bias_size * sizeof(float16), XPU_DEVICE_TO_HOST);
-            // for (int i = 0; i < (int)info.batch() * (int)info.in_channels() * (int)info.input_dim(0); i++) {
-            //     printf("%.4f ", static_cast<float>(host_x[i]));
-            // }
-            // printf("\n");
-            // for (int i = 0; i < (int)bias_size * (int)info.in_channels() * (int)info.kernel_dim(0); i++) {
-            //     printf("%.4f ", static_cast<float>(host_w[i]));
-            // }
-            // printf("\n");
-            // for (int i = 0; i < (int)bias_size; i++) {
-            //     printf("%.4f ", static_cast<float>(host_bias[i]));
-            // }
-            // printf("\n");
+
             if (bias_size > 0) {
                 CHECK_STATUS(internal->useXdnn(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::cast<float16, float>(handle, (float16 *)bias, bias_F32, bias_size)));
                         CHECK_KUNLUN((xdnn::conv1d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, bias_F32,
@@ -118,7 +95,7 @@ infiniStatus_t conv_kernel(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::conv1d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, nullptr,
@@ -134,7 +111,7 @@ infiniStatus_t conv_kernel(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::conv1d_fusion<float, float, float, int16_t>(handle, (float *)x, (float *)w, (float *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, (float *)bias,
@@ -156,20 +133,15 @@ infiniStatus_t conv_kernel(
                                          (int64_t)info.pad_info(1),
                                          (int64_t)info.pad_info(1)};
         std::vector<int64_t> dilation = {(int64_t)info.dilation_info(0), (int64_t)info.dilation_info(1)};
-        printf("x_shape:(%ld, %ld, %ld, %ld)\n", info.batch(), info.in_channels(), info.input_dim(0), info.input_dim(1));
-        printf("kernel_dim:(%ld, %ld)\n", ksize[0], ksize[1]);
-        printf("stride:(%ld, %ld)\n", stride[0], stride[1]);
-        printf("pad:(%ld, %ld)\n", pad[0], pad[1]);
-        printf("dilation:(%ld, %ld)\n", dilation[0], dilation[1]);
-        std::cout << "ndim: " << info.ndim() << " bias_size: " << bias_size << std::endl;
+
         if (dtype == INFINI_DTYPE_F16) {
             if (bias_size > 0) {
                 CHECK_STATUS(internal->useXdnn(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::cast<float16, float>(handle, (float16 *)bias, bias_F32, bias_size)));
                         CHECK_KUNLUN((xdnn::conv2d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.input_dim(1), (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.input_dim(1), (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, bias_F32,
@@ -182,7 +154,7 @@ infiniStatus_t conv_kernel(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::conv2d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.input_dim(1), (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.input_dim(1), (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, nullptr,
@@ -198,7 +170,7 @@ infiniStatus_t conv_kernel(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::conv2d_fusion<float, float, float, int16_t>(handle, (float *)x, (float *)w, (float *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.input_dim(1), (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.input_dim(1), (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, (float *)bias,
@@ -217,20 +189,14 @@ infiniStatus_t conv_kernel(
         std::vector<int64_t> pad = {(int64_t)info.pad_info(0), (int64_t)info.pad_info(1), (int64_t)info.pad_info(2)};
         std::vector<int64_t> dilation = {(int64_t)info.dilation_info(0), (int64_t)info.dilation_info(1), (int64_t)info.dilation_info(2)};

-        printf("x_shape:(%ld, %ld, %ld, %ld, %ld)\n", info.batch(), info.in_channels(), info.input_dim(0), info.input_dim(1), info.input_dim(2));
-        printf("kernel_dim:(%ld, %ld, %ld)\n", ksize[0], ksize[1], ksize[2]);
-        printf("stride:(%ld, %ld, %ld)\n", stride[0], stride[1], stride[2]);
-        printf("pad:(%ld, %ld, %ld)\n", pad[0], pad[1], pad[2]);
-        printf("dilation:(%ld, %ld, %ld)\n", dilation[0], dilation[1], dilation[2]);
-        std::cout << "ndim: " << info.ndim() << " bias_size: " << bias_size << std::endl;
         if (dtype == INFINI_DTYPE_F16) {
             if (bias_size > 0) {
                 CHECK_STATUS(internal->useXdnn(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::cast<float16, float>(handle, (float16 *)bias, bias_F32, bias_size)));
                         CHECK_KUNLUN((xdnn::conv3d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, bias_F32,
@@ -243,7 +209,7 @@ infiniStatus_t conv_kernel(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::conv3d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, nullptr,
@@ -258,7 +224,7 @@ infiniStatus_t conv_kernel(
                     (kunlunStream_t)stream,
                     [&](xdnnHandle_t handle) {
                         CHECK_KUNLUN((xdnn::conv3d_fusion<float, float, float, int16_t>(handle, (float *)x, (float *)w, (float *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
-                                          (int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.kernel_dim(0), ksize,
+                                          (int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.out_channels(), ksize,
                                           stride, pad,
                                           dilation, 1, nullptr,
                                           nullptr, nullptr, true, (float *)bias,
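
Across all of the hunks above the functional change is the same: the argument passed immediately before ksize in each conv1d/conv2d/conv3d_fusion call switches from (int64_t)info.kernel_dim(0) to (int64_t)info.out_channels(), and the debug printf/std::cout dumps are removed. As a minimal illustration in plain PyTorch (not the Kunlun XDNN API; shapes are invented for the example), the output-channel count and the kernel extent are independent quantities, so passing the kernel extent only appears correct when the two happen to coincide:

# Illustration only (plain PyTorch, not the XDNN API): the weight's leading
# dimension is the output-channel count; the kernel extent is a separate value.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32)   # (batch, in_channels, length)
w = torch.randn(4, 3, 5)    # (out_channels = 4, in_channels = 3, kernel = 5)
y = F.conv1d(x, w, stride=1, padding=2, dilation=1)
print(y.shape)              # torch.Size([1, 4, 32]): the 4 comes from out_channels, not from the kernel size 5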

test/infiniop/conv.py

Lines changed: 18 additions & 27 deletions
@@ -49,15 +49,6 @@
         (1, 2),
         (2, 1),
     ),
-    (
-        (1, 3, 32, 32),
-        (32 * 32 * 3, 32 * 32, 32, 1),
-        (2, 3, 5, 5),
-        (75, 25, 5, 1),
-        (2, 2),
-        (2, 2),
-        (1, 1),
-    ),
     (
         (32, 3, 32, 32),
         (32 * 32 * 3, 32 * 32, 32, 1),
@@ -105,27 +96,27 @@
 
 
 def conv(x, w, stride, padding, dilation, y_tensor, bias=None):
-    dim = len(x.shape) - 2
-    if dim == 1:
-        y_tensor.copy_(
-            F.conv1d(
-                x, w, bias=bias, stride=stride, padding=padding, dilation=dilation
+    match len(x.shape) - 2:
+        case 1:
+            y_tensor.copy_(
+                F.conv1d(
+                    x, w, bias=bias, stride=stride, padding=padding, dilation=dilation
+                )
             )
-        )
-    elif dim == 2:
-        y_tensor.copy_(
-            F.conv2d(
-                x, w, bias=bias, stride=stride, padding=padding, dilation=dilation
+        case 2:
+            y_tensor.copy_(
+                F.conv2d(
+                    x, w, bias=bias, stride=stride, padding=padding, dilation=dilation
+                )
             )
-        )
-    elif dim == 3:
-        y_tensor.copy_(
-            F.conv3d(
-                x, w, bias=bias, stride=stride, padding=padding, dilation=dilation
+        case 3:
+            y_tensor.copy_(
+                F.conv3d(
+                    x, w, bias=bias, stride=stride, padding=padding, dilation=dilation
+                )
             )
-        )
-    else:
-        print("Error: Pytorch -> Unsupported tensor dimension")
+        case _:
+            print("Error: Pytorch -> Unsupported tensor dimension")
 
 
 # infer the shape of the output given the inputs for a N-ary convolution
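
The refactored conv dispatcher above uses match/case, which requires Python 3.10 or newer. The trailing context line refers to the helper that infers the output shape of an N-ary convolution; below is a hedged sketch of that computation (the function name and argument order are illustrative, not copied from conv.py), using the output-size relation PyTorch documents for convolutions:

# Hedged sketch; infer_conv_output_shape is an illustrative name, not the helper in conv.py.
# Per spatial dim: out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
def infer_conv_output_shape(x_shape, w_shape, stride, padding, dilation):
    batch, _in_channels, *spatial = x_shape
    out_channels, _, *kernel = w_shape
    out_spatial = [
        (i + 2 * p - d * (k - 1) - 1) // s + 1
        for i, k, s, p, d in zip(spatial, kernel, stride, padding, dilation)
    ]
    return (batch, out_channels, *out_spatial)

# Example with the shapes of the test case removed above, reading its trailing
# tuples as stride, padding, dilation: a (1, 3, 32, 32) input with a (2, 3, 5, 5)
# kernel, stride (2, 2), padding (2, 2), dilation (1, 1) gives (1, 2, 16, 16).
print(infer_conv_output_shape((1, 3, 32, 32), (2, 3, 5, 5), (2, 2), (2, 2), (1, 1)))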

0 commit comments
