Commit 983a45f

issue/360: optim kunlun code
1 parent 84da841 commit 983a45f

File tree

1 file changed: +14 -27 lines changed

1 file changed

+14
-27
lines changed

src/infiniop/ops/conv/kunlun/conv_kunlun.cc

Lines changed: 14 additions & 27 deletions

@@ -67,6 +67,20 @@ infiniStatus_t conv_kernel(
         bias_size = 0;
     }
     float *bias_F32 = (float *)workspace_value;
+    CHECK_STATUS(internal->useXdnn(
+        (kunlunStream_t)stream,
+        [&](xdnnHandle_t handle) {
+            if (bias_size > 0) {
+                if constexpr (std::is_same<Tdata, float16>::value) {
+                    CHECK_KUNLUN((xdnn::cast<Tdata, float>(handle, (Tdata *)bias, bias_F32, bias_size)));
+                } else if constexpr (std::is_same<Tdata, float>::value) {
+                    bias_F32 = (float *)bias;
+                }
+            } else {
+                bias_F32 = nullptr;
+            }
+            return INFINI_STATUS_SUCCESS;
+        }));
     switch (info.ndim()) {
     case 1: {
         int64_t ksize = (int64_t)info.kernel_dim(0);
@@ -77,15 +91,6 @@ infiniStatus_t conv_kernel(
         CHECK_STATUS(internal->useXdnn(
             (kunlunStream_t)stream,
             [&](xdnnHandle_t handle) {
-                if (bias_size > 0) {
-                    if constexpr (std::is_same<Tdata, float16>::value) {
-                        CHECK_KUNLUN((xdnn::cast<Tdata, float>(handle, (Tdata *)bias, bias_F32, bias_size)));
-                    } else if constexpr (std::is_same<Tdata, float>::value) {
-                        bias_F32 = (float *)bias;
-                    }
-                } else {
-                    bias_F32 = nullptr;
-                }
                 CHECK_KUNLUN((xdnn::conv1d_fusion<Tdata, Tdata, Tdata, int16_t>(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
                                                                                 (int64_t)info.out_channels(), ksize,
                                                                                 stride, pad,
@@ -107,15 +112,6 @@ infiniStatus_t conv_kernel(
         CHECK_STATUS(internal->useXdnn(
             (kunlunStream_t)stream,
             [&](xdnnHandle_t handle) {
-                if (bias_size > 0) {
-                    if constexpr (std::is_same<Tdata, float16>::value) {
-                        CHECK_KUNLUN((xdnn::cast<Tdata, float>(handle, (Tdata *)bias, bias_F32, bias_size)));
-                    } else if constexpr (std::is_same<Tdata, float>::value) {
-                        bias_F32 = (float *)bias;
-                    }
-                } else {
-                    bias_F32 = nullptr;
-                }
                 CHECK_KUNLUN((xdnn::conv2d_fusion<Tdata, Tdata, Tdata, int16_t>(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
                                                                                 (int64_t)info.input_dim(1), (int64_t)info.out_channels(), ksize,
                                                                                 stride, pad,
@@ -136,15 +132,6 @@ infiniStatus_t conv_kernel(
         CHECK_STATUS(internal->useXdnn(
             (kunlunStream_t)stream,
             [&](xdnnHandle_t handle) {
-                if (bias_size > 0) {
-                    if constexpr (std::is_same<Tdata, float16>::value) {
-                        CHECK_KUNLUN((xdnn::cast<Tdata, float>(handle, (Tdata *)bias, bias_F32, bias_size)));
-                    } else if constexpr (std::is_same<Tdata, float>::value) {
-                        bias_F32 = (float *)bias;
-                    }
-                } else {
-                    bias_F32 = nullptr;
-                }
                 CHECK_KUNLUN((xdnn::conv3d_fusion<Tdata, Tdata, Tdata, int16_t>(handle, (Tdata *)x, (Tdata *)w, (Tdata *)y, (int64_t)info.batch(), (int64_t)info.in_channels(), (int64_t)info.input_dim(0),
                                                                                 (int64_t)info.input_dim(1), (int64_t)info.input_dim(2), (int64_t)info.out_channels(), ksize,
                                                                                 stride, pad,
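
Note on the change: the diff removes the FP16-to-FP32 bias cast that was duplicated inside the conv1d/conv2d/conv3d lambdas and performs it once, in a single useXdnn call before the switch on info.ndim(). The sketch below illustrates that refactor pattern in isolation (hoisting shared setup out of per-case branches); Status, prepareBias, and runConvNd are hypothetical stand-ins for illustration, not part of InfiniOp or the XDNN API.

// Minimal sketch of the refactor pattern in this commit: setup that was
// duplicated inside every per-dimension branch is hoisted so it runs once
// before the switch. Names (Status, prepareBias, runConvNd) are hypothetical.
#include <cstddef>
#include <cstdio>

enum class Status { Ok, Error };

// Shared pre-processing, executed once (in the real code: casting the FP16
// bias to FP32 workspace memory, or pointing at the FP32 bias directly).
Status prepareBias(const float *bias, std::size_t bias_size, const float **bias_f32) {
    *bias_f32 = (bias_size > 0) ? bias : nullptr;
    return Status::Ok;
}

// Per-dimensionality work; after the hoist it no longer repeats the bias setup.
Status runConvNd(int ndim, const float *bias_f32) {
    std::printf("conv%dd, bias=%s\n", ndim, bias_f32 ? "fp32" : "none");
    return Status::Ok;
}

int main() {
    const float bias[4] = {0.1f, 0.2f, 0.3f, 0.4f};
    const float *bias_f32 = nullptr;

    // One shared setup call ...
    if (prepareBias(bias, 4, &bias_f32) != Status::Ok) {
        return 1;
    }
    // ... then each dimension case stays focused on the conv call itself.
    for (int ndim = 1; ndim <= 3; ++ndim) {
        if (runConvNd(ndim, bias_f32) != Status::Ok) {
            return 1;
        }
    }
    return 0;
}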
