Skip to content

Commit 84da841

Browse files
committed
issue/360: optimize code
1 parent 78b7873 commit 84da841

File tree

1 file changed

+92
-159
lines changed

1 file changed

+92
-159
lines changed

src/infiniop/ops/conv/kunlun/conv_kunlun.cc

Lines changed: 92 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ infiniStatus_t Descriptor::create(
4545
return INFINI_STATUS_SUCCESS;
4646
}
4747

48+
template <typename Tdata>
4849
infiniStatus_t conv_kernel(
4950
std::shared_ptr<device::kunlun::Handle::Internal> internal,
5051
const ConvInfo &info,
51-
infiniDtype_t dtype,
5252
void *workspace,
5353
size_t workspace_size,
5454
void *y,
@@ -74,168 +74,87 @@ infiniStatus_t conv_kernel(
7474
std::initializer_list<int64_t> pad = {(int64_t)info.pad_info(0)};
7575
int64_t dilation = (int64_t)info.dilation_info(0);
7676

77-
if (dtype == INFINI_DTYPE_F16) {
78-
79-
if (bias_size > 0) {
80-
CHECK_STATUS(internal->useXdnn(
81-
(kunlunStream_t)stream,
82-
[&](xdnnHandle_t handle) {
83-
CHECK_KUNLUN((xdnn::cast<float16, float>(handle, (float16 *)bias, bias_F32, bias_size)));
84-
CHECK_KUNLUN((xdnn::conv1d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
85-
(int64_t)info.out_channels(), ksize,
86-
stride, pad,
87-
dilation, 1, nullptr,
88-
nullptr, nullptr, true, bias_F32,
89-
nullptr, baidu::xpu::api::Activation_t::LINEAR,
90-
nullptr)));
91-
return INFINI_STATUS_SUCCESS;
92-
}));
93-
} else {
94-
CHECK_STATUS(internal->useXdnn(
95-
(kunlunStream_t)stream,
96-
[&](xdnnHandle_t handle) {
97-
CHECK_KUNLUN((xdnn::conv1d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
98-
(int64_t)info.out_channels(), ksize,
99-
stride, pad,
100-
dilation, 1, nullptr,
101-
nullptr, nullptr, true, nullptr,
102-
nullptr, baidu::xpu::api::Activation_t::LINEAR,
103-
nullptr)));
104-
return INFINI_STATUS_SUCCESS;
105-
}));
106-
}
107-
return INFINI_STATUS_SUCCESS;
108-
109-
} else if (dtype == INFINI_DTYPE_F32) {
110-
CHECK_STATUS(internal->useXdnn(
111-
(kunlunStream_t)stream,
112-
[&](xdnnHandle_t handle) {
113-
CHECK_KUNLUN((xdnn::conv1d_fusion<float, float, float, int16_t>(handle, (float *)x, (float *)w, (float *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
114-
(int64_t)info.out_channels(), ksize,
115-
stride, pad,
116-
dilation, 1, nullptr,
117-
nullptr, nullptr, true, (float *)bias,
118-
nullptr, baidu::xpu::api::Activation_t::LINEAR,
119-
nullptr)));
120-
return INFINI_STATUS_SUCCESS;
121-
}));
122-
} else {
123-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
124-
}
125-
break;
77+
CHECK_STATUS(internal->useXdnn(
78+
(kunlunStream_t)stream,
79+
[&](xdnnHandle_t handle) {
80+
if (bias_size > 0) {
81+
if constexpr (std::is_same<Tdata, float16>::value) {
82+
CHECK_KUNLUN((xdnn::cast<Tdata, float>(handle, (Tdata *)bias, bias_F32, bias_size)));
83+
} else if constexpr (std::is_same<Tdata, float>::value) {
84+
bias_F32 = (float *)bias;
85+
}
86+
} else {
87+
bias_F32 = nullptr;
88+
}
89+
CHECK_KUNLUN((xdnn::conv1d_fusion<Tdata, Tdata, Tdata, int16_t>(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
90+
(int64_t)info.out_channels(), ksize,
91+
stride, pad,
92+
dilation, 1, nullptr,
93+
nullptr, nullptr, true, bias_F32,
94+
nullptr, baidu::xpu::api::Activation_t::LINEAR,
95+
nullptr)));
96+
return INFINI_STATUS_SUCCESS;
97+
}));
98+
return INFINI_STATUS_SUCCESS;
12699
}
127100
case 2: {
128101
std::vector<int64_t> ksize = {(int64_t)info.kernel_dim(0), (int64_t)info.kernel_dim(1)};
129102
std::vector<int64_t> stride = {(int64_t)info.stride_info(0), (int64_t)info.stride_info(1)};
130103
std::vector<int64_t> pad = {
131104
(int64_t)info.pad_info(0),
132-
(int64_t)info.pad_info(0),
133-
(int64_t)info.pad_info(1),
134105
(int64_t)info.pad_info(1)};
135106
std::vector<int64_t> dilation = {(int64_t)info.dilation_info(0), (int64_t)info.dilation_info(1)};
136-
137-
if (dtype == INFINI_DTYPE_F16) {
138-
if (bias_size > 0) {
139-
CHECK_STATUS(internal->useXdnn(
140-
(kunlunStream_t)stream,
141-
[&](xdnnHandle_t handle) {
142-
CHECK_KUNLUN((xdnn::cast<float16, float>(handle, (float16 *)bias, bias_F32, bias_size)));
143-
CHECK_KUNLUN((xdnn::conv2d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
144-
(int64_t)info.input_dim(1), (int64_t)info.out_channels(), ksize,
145-
stride, pad,
146-
dilation, 1, nullptr,
147-
nullptr, nullptr, true, bias_F32,
148-
nullptr, baidu::xpu::api::Activation_t::LINEAR, nullptr,
149-
nullptr, -1)));
150-
return INFINI_STATUS_SUCCESS;
151-
}));
152-
} else {
153-
CHECK_STATUS(internal->useXdnn(
154-
(kunlunStream_t)stream,
155-
[&](xdnnHandle_t handle) {
156-
CHECK_KUNLUN((xdnn::conv2d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
157-
(int64_t)info.input_dim(1), (int64_t)info.out_channels(), ksize,
158-
stride, pad,
159-
dilation, 1, nullptr,
160-
nullptr, nullptr, true, nullptr,
161-
nullptr, baidu::xpu::api::Activation_t::LINEAR, nullptr,
162-
nullptr, -1)));
163-
return INFINI_STATUS_SUCCESS;
164-
}));
165-
}
166-
return INFINI_STATUS_SUCCESS;
167-
168-
} else if (dtype == INFINI_DTYPE_F32) {
169-
CHECK_STATUS(internal->useXdnn(
170-
(kunlunStream_t)stream,
171-
[&](xdnnHandle_t handle) {
172-
CHECK_KUNLUN((xdnn::conv2d_fusion<float, float, float, int16_t>(handle, (float *)x, (float *)w, (float *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
173-
(int64_t)info.input_dim(1), (int64_t)info.out_channels(), ksize,
174-
stride, pad,
175-
dilation, 1, nullptr,
176-
nullptr, nullptr, true, (float *)bias,
177-
nullptr, baidu::xpu::api::Activation_t::LINEAR, nullptr,
178-
nullptr, -1)));
179-
return INFINI_STATUS_SUCCESS;
180-
}));
181-
} else {
182-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
183-
}
184-
break;
107+
CHECK_STATUS(internal->useXdnn(
108+
(kunlunStream_t)stream,
109+
[&](xdnnHandle_t handle) {
110+
if (bias_size > 0) {
111+
if constexpr (std::is_same<Tdata, float16>::value) {
112+
CHECK_KUNLUN((xdnn::cast<Tdata, float>(handle, (Tdata *)bias, bias_F32, bias_size)));
113+
} else if constexpr (std::is_same<Tdata, float>::value) {
114+
bias_F32 = (float *)bias;
115+
}
116+
} else {
117+
bias_F32 = nullptr;
118+
}
119+
CHECK_KUNLUN((xdnn::conv2d_fusion<Tdata, Tdata, Tdata, int16_t>(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
120+
(int64_t)info.input_dim(1), (int64_t)info.out_channels(), ksize,
121+
stride, pad,
122+
dilation, 1, nullptr,
123+
nullptr, nullptr, true, bias_F32,
124+
nullptr, baidu::xpu::api::Activation_t::LINEAR, nullptr,
125+
nullptr, -1)));
126+
return INFINI_STATUS_SUCCESS;
127+
}));
128+
return INFINI_STATUS_SUCCESS;
185129
}
186130
case 3: {
187131
std::vector<int64_t> ksize = {(int64_t)info.kernel_dim(0), (int64_t)info.kernel_dim(1), (int64_t)info.kernel_dim(2)};
188132
std::vector<int64_t> stride = {(int64_t)info.stride_info(0), (int64_t)info.stride_info(1), (int64_t)info.stride_info(2)};
189133
std::vector<int64_t> pad = {(int64_t)info.pad_info(0), (int64_t)info.pad_info(1), (int64_t)info.pad_info(2)};
190134
std::vector<int64_t> dilation = {(int64_t)info.dilation_info(0), (int64_t)info.dilation_info(1), (int64_t)info.dilation_info(2)};
191135

192-
if (dtype == INFINI_DTYPE_F16) {
193-
if (bias_size > 0) {
194-
CHECK_STATUS(internal->useXdnn(
195-
(kunlunStream_t)stream,
196-
[&](xdnnHandle_t handle) {
197-
CHECK_KUNLUN((xdnn::cast<float16, float>(handle, (float16 *)bias, bias_F32, bias_size)));
198-
CHECK_KUNLUN((xdnn::conv3d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
199-
(int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.out_channels(), ksize,
200-
stride, pad,
201-
dilation, 1, nullptr,
202-
nullptr, nullptr, true, bias_F32,
203-
nullptr, baidu::xpu::api::Activation_t::LINEAR,
204-
nullptr)));
205-
return INFINI_STATUS_SUCCESS;
206-
}));
207-
} else {
208-
CHECK_STATUS(internal->useXdnn(
209-
(kunlunStream_t)stream,
210-
[&](xdnnHandle_t handle) {
211-
CHECK_KUNLUN((xdnn::conv3d_fusion<float16, float16, float16, int16_t>(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
212-
(int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.out_channels(), ksize,
213-
stride, pad,
214-
dilation, 1, nullptr,
215-
nullptr, nullptr, true, nullptr,
216-
nullptr, baidu::xpu::api::Activation_t::LINEAR,
217-
nullptr)));
218-
return INFINI_STATUS_SUCCESS;
219-
}));
220-
}
221-
return INFINI_STATUS_SUCCESS;
222-
} else if (dtype == INFINI_DTYPE_F32) {
223-
CHECK_STATUS(internal->useXdnn(
224-
(kunlunStream_t)stream,
225-
[&](xdnnHandle_t handle) {
226-
CHECK_KUNLUN((xdnn::conv3d_fusion<float, float, float, int16_t>(handle, (float *)x, (float *)w, (float *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
227-
(int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.out_channels(), ksize,
228-
stride, pad,
229-
dilation, 1, nullptr,
230-
nullptr, nullptr, true, (float *)bias,
231-
nullptr, baidu::xpu::api::Activation_t::LINEAR,
232-
nullptr)));
233-
return INFINI_STATUS_SUCCESS;
234-
}));
235-
} else {
236-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
237-
}
238-
break;
136+
CHECK_STATUS(internal->useXdnn(
137+
(kunlunStream_t)stream,
138+
[&](xdnnHandle_t handle) {
139+
if (bias_size > 0) {
140+
if constexpr (std::is_same<Tdata, float16>::value) {
141+
CHECK_KUNLUN((xdnn::cast<Tdata, float>(handle, (Tdata *)bias, bias_F32, bias_size)));
142+
} else if constexpr (std::is_same<Tdata, float>::value) {
143+
bias_F32 = (float *)bias;
144+
}
145+
} else {
146+
bias_F32 = nullptr;
147+
}
148+
CHECK_KUNLUN((xdnn::conv3d_fusion<Tdata, Tdata, Tdata, int16_t>(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
149+
(int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.out_channels(), ksize,
150+
stride, pad,
151+
dilation, 1, nullptr,
152+
nullptr, nullptr, true, bias_F32,
153+
nullptr, baidu::xpu::api::Activation_t::LINEAR,
154+
nullptr)));
155+
return INFINI_STATUS_SUCCESS;
156+
}));
157+
return INFINI_STATUS_SUCCESS;
239158
}
240159
default:
241160
return INFINI_STATUS_BAD_TENSOR_SHAPE;
@@ -254,17 +173,31 @@ infiniStatus_t Descriptor::calculate(
254173
if (workspace_size < _workspace_size) {
255174
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
256175
}
257-
CHECK_STATUS(conv_kernel(
258-
_opaque->internal,
259-
_info,
260-
_dtype,
261-
workspace,
262-
workspace_size,
263-
y,
264-
x,
265-
w,
266-
bias,
267-
stream));
176+
if (_dtype == INFINI_DTYPE_F16) {
177+
CHECK_STATUS(conv_kernel<float16>(
178+
_opaque->internal,
179+
_info,
180+
workspace,
181+
workspace_size,
182+
y,
183+
x,
184+
w,
185+
bias,
186+
stream));
187+
} else if (_dtype == INFINI_DTYPE_F32) {
188+
CHECK_STATUS(conv_kernel<float>(
189+
_opaque->internal,
190+
_info,
191+
workspace,
192+
workspace_size,
193+
y,
194+
x,
195+
w,
196+
bias,
197+
stream));
198+
} else {
199+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
200+
}
268201

269202
return INFINI_STATUS_SUCCESS;
270203
}

0 commit comments

Comments
 (0)