@@ -73,39 +73,16 @@ infiniStatus_t conv_kernel(
7373 int64_t stride = (int64_t )info.stride_info (0 );
7474 std::initializer_list<int64_t > pad = {(int64_t )info.pad_info (0 )};
7575 int64_t dilation = (int64_t )info.dilation_info (0 );
76- printf (" x_shape:(%ld, %ld, %ld)\n " , info.batch (), info.in_channels (), info.input_dim (0 ));
77- printf (" kernel_dim:(%ld)\n " , ksize);
78- printf (" stride:(%ld)\n " , stride);
79- printf (" pad:(%ld)\n " , (int64_t )info.pad_info (0 ));
80- printf (" dilation:(%ld)\n " , dilation);
81- std::cout << " ndim: " << info.ndim () << " bias_size: " << bias_size << std::endl;
76+
8277 if (dtype == INFINI_DTYPE_F16) {
83- // float16 *host_x, *host_w, *host_bias;
84- // host_x = (float16 *)malloc((int)info.batch() * (int)info.in_channels() * (int)info.input_dim(0) * sizeof(float16));
85- // host_w = (float16 *)malloc((int)bias_size * (int)info.in_channels() * (int)info.kernel_dim(0) * sizeof(float16));
86- // host_bias = (float16 *)malloc((int)bias_size * sizeof(float16));
87- // xpu_memcpy(host_x, x, (int)info.batch() * (int)info.in_channels() * (int)info.input_dim(0) * sizeof(float16), XPU_DEVICE_TO_HOST);
88- // xpu_memcpy(host_w, w, (int)bias_size * (int)info.in_channels() * (int)info.kernel_dim(0) * sizeof(float16), XPU_DEVICE_TO_HOST);
89- // xpu_memcpy(host_bias, bias, (int)bias_size * sizeof(float16), XPU_DEVICE_TO_HOST);
90- // for (int i = 0; i < (int)info.batch() * (int)info.in_channels() * (int)info.input_dim(0); i++) {
91- // printf("%.4f ", static_cast<float>(host_x[i]));
92- // }
93- // printf("\n");
94- // for (int i = 0; i < (int)bias_size * (int)info.in_channels() * (int)info.kernel_dim(0); i++) {
95- // printf("%.4f ", static_cast<float>(host_w[i]));
96- // }
97- // printf("\n");
98- // for (int i = 0; i < (int)bias_size; i++) {
99- // printf("%.4f ", static_cast<float>(host_bias[i]));
100- // }
101- // printf("\n");
78+
10279 if (bias_size > 0 ) {
10380 CHECK_STATUS (internal->useXdnn (
10481 (kunlunStream_t)stream,
10582 [&](xdnnHandle_t handle) {
10683 CHECK_KUNLUN ((xdnn::cast<float16, float >(handle, (float16 *)bias, bias_F32, bias_size)));
10784 CHECK_KUNLUN ((xdnn::conv1d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
108- (int64_t )info.kernel_dim ( 0 ), ksize,
85+ (int64_t )info.out_channels ( ), ksize,
10986 stride, pad,
11087 dilation, 1 , nullptr ,
11188 nullptr , nullptr , true , bias_F32,
@@ -118,7 +95,7 @@ infiniStatus_t conv_kernel(
11895 (kunlunStream_t)stream,
11996 [&](xdnnHandle_t handle) {
12097 CHECK_KUNLUN ((xdnn::conv1d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
121- (int64_t )info.kernel_dim ( 0 ), ksize,
98+ (int64_t )info.out_channels ( ), ksize,
12299 stride, pad,
123100 dilation, 1 , nullptr ,
124101 nullptr , nullptr , true , nullptr ,
@@ -134,7 +111,7 @@ infiniStatus_t conv_kernel(
134111 (kunlunStream_t)stream,
135112 [&](xdnnHandle_t handle) {
136113 CHECK_KUNLUN ((xdnn::conv1d_fusion<float , float , float , int16_t >(handle, (float *)x, (float *)w, (float *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
137- (int64_t )info.kernel_dim ( 0 ), ksize,
114+ (int64_t )info.out_channels ( ), ksize,
138115 stride, pad,
139116 dilation, 1 , nullptr ,
140117 nullptr , nullptr , true , (float *)bias,
@@ -156,20 +133,15 @@ infiniStatus_t conv_kernel(
156133 (int64_t )info.pad_info (1 ),
157134 (int64_t )info.pad_info (1 )};
158135 std::vector<int64_t > dilation = {(int64_t )info.dilation_info (0 ), (int64_t )info.dilation_info (1 )};
159- printf (" x_shape:(%ld, %ld, %ld, %ld)\n " , info.batch (), info.in_channels (), info.input_dim (0 ), info.input_dim (1 ));
160- printf (" kernel_dim:(%ld, %ld)\n " , ksize[0 ], ksize[1 ]);
161- printf (" stride:(%ld, %ld)\n " , stride[0 ], stride[1 ]);
162- printf (" pad:(%ld, %ld)\n " , pad[0 ], pad[1 ]);
163- printf (" dilation:(%ld, %ld)\n " , dilation[0 ], dilation[1 ]);
164- std::cout << " ndim: " << info.ndim () << " bias_size: " << bias_size << std::endl;
136+
165137 if (dtype == INFINI_DTYPE_F16) {
166138 if (bias_size > 0 ) {
167139 CHECK_STATUS (internal->useXdnn (
168140 (kunlunStream_t)stream,
169141 [&](xdnnHandle_t handle) {
170142 CHECK_KUNLUN ((xdnn::cast<float16, float >(handle, (float16 *)bias, bias_F32, bias_size)));
171143 CHECK_KUNLUN ((xdnn::conv2d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
172- (int64_t )info.input_dim (1 ), (int64_t )info.kernel_dim ( 0 ), ksize,
144+ (int64_t )info.input_dim (1 ), (int64_t )info.out_channels ( ), ksize,
173145 stride, pad,
174146 dilation, 1 , nullptr ,
175147 nullptr , nullptr , true , bias_F32,
@@ -182,7 +154,7 @@ infiniStatus_t conv_kernel(
182154 (kunlunStream_t)stream,
183155 [&](xdnnHandle_t handle) {
184156 CHECK_KUNLUN ((xdnn::conv2d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
185- (int64_t )info.input_dim (1 ), (int64_t )info.kernel_dim ( 0 ), ksize,
157+ (int64_t )info.input_dim (1 ), (int64_t )info.out_channels ( ), ksize,
186158 stride, pad,
187159 dilation, 1 , nullptr ,
188160 nullptr , nullptr , true , nullptr ,
@@ -198,7 +170,7 @@ infiniStatus_t conv_kernel(
198170 (kunlunStream_t)stream,
199171 [&](xdnnHandle_t handle) {
200172 CHECK_KUNLUN ((xdnn::conv2d_fusion<float , float , float , int16_t >(handle, (float *)x, (float *)w, (float *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
201- (int64_t )info.input_dim (1 ), (int64_t )info.kernel_dim ( 0 ), ksize,
173+ (int64_t )info.input_dim (1 ), (int64_t )info.out_channels ( ), ksize,
202174 stride, pad,
203175 dilation, 1 , nullptr ,
204176 nullptr , nullptr , true , (float *)bias,
@@ -217,20 +189,14 @@ infiniStatus_t conv_kernel(
217189 std::vector<int64_t > pad = {(int64_t )info.pad_info (0 ), (int64_t )info.pad_info (1 ), (int64_t )info.pad_info (2 )};
218190 std::vector<int64_t > dilation = {(int64_t )info.dilation_info (0 ), (int64_t )info.dilation_info (1 ), (int64_t )info.dilation_info (2 )};
219191
220- printf (" x_shape:(%ld, %ld, %ld, %ld, %ld)\n " , info.batch (), info.in_channels (), info.input_dim (0 ), info.input_dim (1 ), info.input_dim (2 ));
221- printf (" kernel_dim:(%ld, %ld, %ld)\n " , ksize[0 ], ksize[1 ], ksize[2 ]);
222- printf (" stride:(%ld, %ld, %ld)\n " , stride[0 ], stride[1 ], stride[2 ]);
223- printf (" pad:(%ld, %ld, %ld)\n " , pad[0 ], pad[1 ], pad[2 ]);
224- printf (" dilation:(%ld, %ld, %ld)\n " , dilation[0 ], dilation[1 ], dilation[2 ]);
225- std::cout << " ndim: " << info.ndim () << " bias_size: " << bias_size << std::endl;
226192 if (dtype == INFINI_DTYPE_F16) {
227193 if (bias_size > 0 ) {
228194 CHECK_STATUS (internal->useXdnn (
229195 (kunlunStream_t)stream,
230196 [&](xdnnHandle_t handle) {
231197 CHECK_KUNLUN ((xdnn::cast<float16, float >(handle, (float16 *)bias, bias_F32, bias_size)));
232198 CHECK_KUNLUN ((xdnn::conv3d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
233- (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.kernel_dim ( 0 ), ksize,
199+ (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.out_channels ( ), ksize,
234200 stride, pad,
235201 dilation, 1 , nullptr ,
236202 nullptr , nullptr , true , bias_F32,
@@ -243,7 +209,7 @@ infiniStatus_t conv_kernel(
243209 (kunlunStream_t)stream,
244210 [&](xdnnHandle_t handle) {
245211 CHECK_KUNLUN ((xdnn::conv3d_fusion<float16, float16, float16, int16_t >(handle, (float16 *)x, (float16 *)w, (float16 *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
246- (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.kernel_dim ( 0 ), ksize,
212+ (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.out_channels ( ), ksize,
247213 stride, pad,
248214 dilation, 1 , nullptr ,
249215 nullptr , nullptr , true , nullptr ,
@@ -258,7 +224,7 @@ infiniStatus_t conv_kernel(
258224 (kunlunStream_t)stream,
259225 [&](xdnnHandle_t handle) {
260226 CHECK_KUNLUN ((xdnn::conv3d_fusion<float , float , float , int16_t >(handle, (float *)x, (float *)w, (float *)y, (int64_t )info.batch (), (int64_t )info.in_channels (), (int64_t )info.input_dim (0 ),
261- (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.kernel_dim ( 0 ), ksize,
227+ (int64_t )info.input_dim (1 ), (int64_t )info.input_dim (2 ), (int64_t )info.out_channels ( ), ksize,
262228 stride, pad,
263229 dilation, 1 , nullptr ,
264230 nullptr , nullptr , true , (float *)bias,
0 commit comments