@@ -45,10 +45,10 @@ infiniStatus_t Descriptor::create(
4545 return INFINI_STATUS_SUCCESS;
4646}
4747
48+ template <typename Tdata>
4849infiniStatus_t conv_kernel (
4950 std::shared_ptr<device::kunlun::Handle::Internal> internal,
5051 const ConvInfo &info,
51- infiniDtype_t dtype,
5252 void *workspace,
5353 size_t workspace_size,
5454 void *y,
@@ -74,168 +74,87 @@ infiniStatus_t conv_kernel(
7474 std::initializer_list<int64_t > pad = {(int64_t )info.pad_info (0 )};
7575 int64_t dilation = (int64_t )info.dilation_info (0 );
7676
77- if (dtype == INFINI_DTYPE_F16) {
78-
79- if (bias_size > 0 ) {
80- CHECK_STATUS (internal->useXdnn (
81- (kunlunStream_t)stream,
82- [&](xdnnHandle_t handle) {
83- CHECK_KUNLUN ((xdnn::cast<float16, float >(handle, (float16 *)bias, bias_F32, bias_size)));
84- CHECK_KUNLUN ((xdnn::conv1d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
85- (int64_t )info.out_channels (), ksize,
86- stride, pad,
87- dilation, 1 , nullptr ,
88- nullptr , nullptr , true , bias_F32,
89- nullptr , baidu::xpu::api::Activation_t::LINEAR,
90- nullptr )));
91- return INFINI_STATUS_SUCCESS;
92- }));
93- } else {
94- CHECK_STATUS (internal->useXdnn (
95- (kunlunStream_t)stream,
96- [&](xdnnHandle_t handle) {
97- CHECK_KUNLUN ((xdnn::conv1d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
98- (int64_t )info.out_channels (), ksize,
99- stride, pad,
100- dilation, 1 , nullptr ,
101- nullptr , nullptr , true , nullptr ,
102- nullptr , baidu::xpu::api::Activation_t::LINEAR,
103- nullptr )));
104- return INFINI_STATUS_SUCCESS;
105- }));
106- }
107- return INFINI_STATUS_SUCCESS;
108-
109- } else if (dtype == INFINI_DTYPE_F32) {
110- CHECK_STATUS (internal->useXdnn (
111- (kunlunStream_t)stream,
112- [&](xdnnHandle_t handle) {
113- CHECK_KUNLUN ((xdnn::conv1d_fusion<float , float , float , int16_t >(handle, (float *)x, (float *)w, (float *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
114- (int64_t )info.out_channels (), ksize,
115- stride, pad,
116- dilation, 1 , nullptr ,
117- nullptr , nullptr , true , (float *)bias,
118- nullptr , baidu::xpu::api::Activation_t::LINEAR,
119- nullptr )));
120- return INFINI_STATUS_SUCCESS;
121- }));
122- } else {
123- return INFINI_STATUS_BAD_TENSOR_DTYPE;
124- }
125- break ;
77+ CHECK_STATUS (internal->useXdnn (
78+ (kunlunStream_t)stream,
79+ [&](xdnnHandle_t handle) {
80+ if (bias_size > 0 ) {
81+ if constexpr (std::is_same<Tdata, float16>::value) {
82+ CHECK_KUNLUN ((xdnn::cast<Tdata, float >(handle, (Tdata *)bias, bias_F32, bias_size)));
83+ } else if constexpr (std::is_same<Tdata, float >::value) {
84+ bias_F32 = (float *)bias;
85+ }
86+ } else {
87+ bias_F32 = nullptr ;
88+ }
89+ CHECK_KUNLUN ((xdnn::conv1d_fusion<Tdata, Tdata, Tdata, int16_t >(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
90+ (int64_t )info.out_channels (), ksize,
91+ stride, pad,
92+ dilation, 1 , nullptr ,
93+ nullptr , nullptr , true , bias_F32,
94+ nullptr , baidu::xpu::api::Activation_t::LINEAR,
95+ nullptr )));
96+ return INFINI_STATUS_SUCCESS;
97+ }));
98+ return INFINI_STATUS_SUCCESS;
12699 }
127100 case 2 : {
128101 std::vector<int64_t > ksize = {(int64_t )info.kernel_dim (0 ), (int64_t )info.kernel_dim (1 )};
129102 std::vector<int64_t > stride = {(int64_t )info.stride_info (0 ), (int64_t )info.stride_info (1 )};
130103 std::vector<int64_t > pad = {
131104 (int64_t )info.pad_info (0 ),
132- (int64_t )info.pad_info (0 ),
133- (int64_t )info.pad_info (1 ),
134105 (int64_t )info.pad_info (1 )};
135106 std::vector<int64_t > dilation = {(int64_t )info.dilation_info (0 ), (int64_t )info.dilation_info (1 )};
136-
137- if (dtype == INFINI_DTYPE_F16) {
138- if (bias_size > 0 ) {
139- CHECK_STATUS (internal->useXdnn (
140- (kunlunStream_t)stream,
141- [&](xdnnHandle_t handle) {
142- CHECK_KUNLUN ((xdnn::cast<float16, float >(handle, (float16 *)bias, bias_F32, bias_size)));
143- CHECK_KUNLUN ((xdnn::conv2d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
144- (int64_t )info.input_dim (1 ), (int64_t )info.out_channels (), ksize,
145- stride, pad,
146- dilation, 1 , nullptr ,
147- nullptr , nullptr , true , bias_F32,
148- nullptr , baidu::xpu::api::Activation_t::LINEAR, nullptr ,
149- nullptr , -1 )));
150- return INFINI_STATUS_SUCCESS;
151- }));
152- } else {
153- CHECK_STATUS (internal->useXdnn (
154- (kunlunStream_t)stream,
155- [&](xdnnHandle_t handle) {
156- CHECK_KUNLUN ((xdnn::conv2d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
157- (int64_t )info.input_dim (1 ), (int64_t )info.out_channels (), ksize,
158- stride, pad,
159- dilation, 1 , nullptr ,
160- nullptr , nullptr , true , nullptr ,
161- nullptr , baidu::xpu::api::Activation_t::LINEAR, nullptr ,
162- nullptr , -1 )));
163- return INFINI_STATUS_SUCCESS;
164- }));
165- }
166- return INFINI_STATUS_SUCCESS;
167-
168- } else if (dtype == INFINI_DTYPE_F32) {
169- CHECK_STATUS (internal->useXdnn (
170- (kunlunStream_t)stream,
171- [&](xdnnHandle_t handle) {
172- CHECK_KUNLUN ((xdnn::conv2d_fusion<float , float , float , int16_t >(handle, (float *)x, (float *)w, (float *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
173- (int64_t )info.input_dim (1 ), (int64_t )info.out_channels (), ksize,
174- stride, pad,
175- dilation, 1 , nullptr ,
176- nullptr , nullptr , true , (float *)bias,
177- nullptr , baidu::xpu::api::Activation_t::LINEAR, nullptr ,
178- nullptr , -1 )));
179- return INFINI_STATUS_SUCCESS;
180- }));
181- } else {
182- return INFINI_STATUS_BAD_TENSOR_DTYPE;
183- }
184- break ;
107+ CHECK_STATUS (internal->useXdnn (
108+ (kunlunStream_t)stream,
109+ [&](xdnnHandle_t handle) {
110+ if (bias_size > 0 ) {
111+ if constexpr (std::is_same<Tdata, float16>::value) {
112+ CHECK_KUNLUN ((xdnn::cast<Tdata, float >(handle, (Tdata *)bias, bias_F32, bias_size)));
113+ } else if constexpr (std::is_same<Tdata, float >::value) {
114+ bias_F32 = (float *)bias;
115+ }
116+ } else {
117+ bias_F32 = nullptr ;
118+ }
119+ CHECK_KUNLUN ((xdnn::conv2d_fusion<Tdata, Tdata, Tdata, int16_t >(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
120+ (int64_t )info.input_dim (1 ), (int64_t )info.out_channels (), ksize,
121+ stride, pad,
122+ dilation, 1 , nullptr ,
123+ nullptr , nullptr , true , bias_F32,
124+ nullptr , baidu::xpu::api::Activation_t::LINEAR, nullptr ,
125+ nullptr , -1 )));
126+ return INFINI_STATUS_SUCCESS;
127+ }));
128+ return INFINI_STATUS_SUCCESS;
185129 }
186130 case 3 : {
187131 std::vector<int64_t > ksize = {(int64_t )info.kernel_dim (0 ), (int64_t )info.kernel_dim (1 ), (int64_t )info.kernel_dim (2 )};
188132 std::vector<int64_t > stride = {(int64_t )info.stride_info (0 ), (int64_t )info.stride_info (1 ), (int64_t )info.stride_info (2 )};
189133 std::vector<int64_t > pad = {(int64_t )info.pad_info (0 ), (int64_t )info.pad_info (1 ), (int64_t )info.pad_info (2 )};
190134 std::vector<int64_t > dilation = {(int64_t )info.dilation_info (0 ), (int64_t )info.dilation_info (1 ), (int64_t )info.dilation_info (2 )};
191135
192- if (dtype == INFINI_DTYPE_F16) {
193- if (bias_size > 0 ) {
194- CHECK_STATUS (internal->useXdnn (
195- (kunlunStream_t)stream,
196- [&](xdnnHandle_t handle) {
197- CHECK_KUNLUN ((xdnn::cast<float16, float >(handle, (float16 *)bias, bias_F32, bias_size)));
198- CHECK_KUNLUN ((xdnn::conv3d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
199- (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.out_channels (), ksize,
200- stride, pad,
201- dilation, 1 , nullptr ,
202- nullptr , nullptr , true , bias_F32,
203- nullptr , baidu::xpu::api::Activation_t::LINEAR,
204- nullptr )));
205- return INFINI_STATUS_SUCCESS;
206- }));
207- } else {
208- CHECK_STATUS (internal->useXdnn (
209- (kunlunStream_t)stream,
210- [&](xdnnHandle_t handle) {
211- CHECK_KUNLUN ((xdnn::conv3d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
212- (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.out_channels (), ksize,
213- stride, pad,
214- dilation, 1 , nullptr ,
215- nullptr , nullptr , true , nullptr ,
216- nullptr , baidu::xpu::api::Activation_t::LINEAR,
217- nullptr )));
218- return INFINI_STATUS_SUCCESS;
219- }));
220- }
221- return INFINI_STATUS_SUCCESS;
222- } else if (dtype == INFINI_DTYPE_F32) {
223- CHECK_STATUS (internal->useXdnn (
224- (kunlunStream_t)stream,
225- [&](xdnnHandle_t handle) {
226- CHECK_KUNLUN ((xdnn::conv3d_fusion<float , float , float , int16_t >(handle, (float *)x, (float *)w, (float *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
227- (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.out_channels (), ksize,
228- stride, pad,
229- dilation, 1 , nullptr ,
230- nullptr , nullptr , true , (float *)bias,
231- nullptr , baidu::xpu::api::Activation_t::LINEAR,
232- nullptr )));
233- return INFINI_STATUS_SUCCESS;
234- }));
235- } else {
236- return INFINI_STATUS_BAD_TENSOR_DTYPE;
237- }
238- break ;
136+ CHECK_STATUS (internal->useXdnn (
137+ (kunlunStream_t)stream,
138+ [&](xdnnHandle_t handle) {
139+ if (bias_size > 0 ) {
140+ if constexpr (std::is_same<Tdata, float16>::value) {
141+ CHECK_KUNLUN ((xdnn::cast<Tdata, float >(handle, (Tdata *)bias, bias_F32, bias_size)));
142+ } else if constexpr (std::is_same<Tdata, float >::value) {
143+ bias_F32 = (float *)bias;
144+ }
145+ } else {
146+ bias_F32 = nullptr ;
147+ }
148+ CHECK_KUNLUN ((xdnn::conv3d_fusion<Tdata, Tdata, Tdata, int16_t >(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
149+ (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.out_channels (), ksize,
150+ stride, pad,
151+ dilation, 1 , nullptr ,
152+ nullptr , nullptr , true , bias_F32,
153+ nullptr , baidu::xpu::api::Activation_t::LINEAR,
154+ nullptr )));
155+ return INFINI_STATUS_SUCCESS;
156+ }));
157+ return INFINI_STATUS_SUCCESS;
239158 }
240159 default :
241160 return INFINI_STATUS_BAD_TENSOR_SHAPE;
@@ -254,17 +173,31 @@ infiniStatus_t Descriptor::calculate(
254173 if (workspace_size < _workspace_size) {
255174 return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
256175 }
257- CHECK_STATUS (conv_kernel (
258- _opaque->internal ,
259- _info,
260- _dtype,
261- workspace,
262- workspace_size,
263- y,
264- x,
265- w,
266- bias,
267- stream));
176+ if (_dtype == INFINI_DTYPE_F16) {
177+ CHECK_STATUS (conv_kernel<float16>(
178+ _opaque->internal ,
179+ _info,
180+ workspace,
181+ workspace_size,
182+ y,
183+ x,
184+ w,
185+ bias,
186+ stream));
187+ } else if (_dtype == INFINI_DTYPE_F32) {
188+ CHECK_STATUS (conv_kernel<float >(
189+ _opaque->internal ,
190+ _info,
191+ workspace,
192+ workspace_size,
193+ y,
194+ x,
195+ w,
196+ bias,
197+ stream));
198+ } else {
199+ return INFINI_STATUS_BAD_TENSOR_DTYPE;
200+ }
268201
269202 return INFINI_STATUS_SUCCESS;
270203}