@@ -67,6 +67,20 @@ infiniStatus_t conv_kernel(
6767 bias_size = 0 ;
6868 }
6969 float *bias_F32 = (float *)workspace_value;
70+ CHECK_STATUS (internal->useXdnn (
71+ (kunlunStream_t)stream,
72+ [&](xdnnHandle_t handle) {
73+ if (bias_size > 0 ) {
74+ if constexpr (std::is_same<Tdata, float16>::value) {
75+ CHECK_KUNLUN ((xdnn::cast<Tdata, float >(handle, (Tdata *)bias, bias_F32, bias_size)));
76+ } else if constexpr (std::is_same<Tdata, float >::value) {
77+ bias_F32 = (float *)bias;
78+ }
79+ } else {
80+ bias_F32 = nullptr ;
81+ }
82+ return INFINI_STATUS_SUCCESS;
83+ }));
7084 switch (info.ndim ()) {
7185 case 1 : {
7286 int64_t ksize = (int64_t )info.kernel_dim (0 );
@@ -77,15 +91,6 @@ infiniStatus_t conv_kernel(
7791 CHECK_STATUS (internal->useXdnn (
7892 (kunlunStream_t)stream,
7993 [&](xdnnHandle_t handle) {
80- if (bias_size > 0 ) {
81- if constexpr (std::is_same<Tdata, float16>::value) {
82- CHECK_KUNLUN ((xdnn::cast<Tdata, float >(handle, (Tdata *)bias, bias_F32, bias_size)));
83- } else if constexpr (std::is_same<Tdata, float >::value) {
84- bias_F32 = (float *)bias;
85- }
86- } else {
87- bias_F32 = nullptr ;
88- }
8994 CHECK_KUNLUN ((xdnn::conv1d_fusion<Tdata, Tdata, Tdata, int16_t >(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
9095 (int64_t )info.out_channels (), ksize,
9196 stride, pad,
@@ -107,15 +112,6 @@ infiniStatus_t conv_kernel(
107112 CHECK_STATUS (internal->useXdnn (
108113 (kunlunStream_t)stream,
109114 [&](xdnnHandle_t handle) {
110- if (bias_size > 0 ) {
111- if constexpr (std::is_same<Tdata, float16>::value) {
112- CHECK_KUNLUN ((xdnn::cast<Tdata, float >(handle, (Tdata *)bias, bias_F32, bias_size)));
113- } else if constexpr (std::is_same<Tdata, float >::value) {
114- bias_F32 = (float *)bias;
115- }
116- } else {
117- bias_F32 = nullptr ;
118- }
119115 CHECK_KUNLUN ((xdnn::conv2d_fusion<Tdata, Tdata, Tdata, int16_t >(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
120116 (int64_t )info.input_dim (1 ), (int64_t )info.out_channels (), ksize,
121117 stride, pad,
@@ -136,15 +132,6 @@ infiniStatus_t conv_kernel(
136132 CHECK_STATUS (internal->useXdnn (
137133 (kunlunStream_t)stream,
138134 [&](xdnnHandle_t handle) {
139- if (bias_size > 0 ) {
140- if constexpr (std::is_same<Tdata, float16>::value) {
141- CHECK_KUNLUN ((xdnn::cast<Tdata, float >(handle, (Tdata *)bias, bias_F32, bias_size)));
142- } else if constexpr (std::is_same<Tdata, float >::value) {
143- bias_F32 = (float *)bias;
144- }
145- } else {
146- bias_F32 = nullptr ;
147- }
148135 CHECK_KUNLUN ((xdnn::conv3d_fusion<Tdata, Tdata, Tdata, int16_t >(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
149136 (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.out_channels (), ksize,
150137 stride, pad,
0 commit comments