
Commit 33cc3fb

feat(core/op): add naive implementation of interelem op and related test.
1 parent f76fe8a commit 33cc3fb

File tree

6 files changed, +187 -2 lines changed


core/op/common/interelem.cpp

Lines changed: 42 additions & 0 deletions
#include "core/op/common/ops.h"

namespace nncore {
namespace opr {

IMPL_DOUBLE_INPUT_LAYOUT_DEDUCE(interelem) {
  auto dim = std::max(a.ndim, b.ndim);
  // Walk both shapes from the trailing dimension and deduce the broadcast
  // output shape: the sizes must match, or one of them must be 1.
  for (nn_size i = 0, j = 0, k = 0; i < a.ndim && j < b.ndim; i++, j++, k++) {
    nn_size a_idx = a.ndim - i - 1;
    nn_size b_idx = b.ndim - j - 1;
    if (a.shape[a_idx] != b.shape[b_idx] && a.shape[a_idx] != 1 &&
        b.shape[b_idx] != 1) {
      return Status(
          StatusCategory::NUMNET, StatusCode::MISMATCHED_SHAPE,
          "Cannot broadcast the shapes of the two inputs against each other.");
    } else if (a.shape[a_idx] == b.shape[b_idx]) {
      res.shape[dim - k - 1] = a.shape[a_idx];
    } else if (a.shape[a_idx] == 1) {
      res.shape[dim - k - 1] = b.shape[b_idx];
    } else if (b.shape[b_idx] == 1) {
      res.shape[dim - k - 1] = a.shape[a_idx];
    } else {
      return Status(StatusCategory::NUMNET, StatusCode::MISMATCHED_SHAPE,
                    "Unknown error when deducing the layout.");
    }
  }
  // Copy through the leading dimensions that only the longer input has.
  for (nn_size i = std::min(a.ndim, b.ndim); i < a.ndim; i++) {
    res.shape[dim - i - 1] = a.shape[a.ndim - i - 1];
  }
  for (nn_size i = std::min(a.ndim, b.ndim); i < b.ndim; i++) {
    res.shape[dim - i - 1] = b.shape[b.ndim - i - 1];
  }
  res.dtype = DType::from_enum(
      deduce_double_input_op(a.dtype.enumv(), b.dtype.enumv()));
  res.ndim = dim;
  a.broadcast_inplace(res);
  b.broadcast_inplace(res);
  return Status::OK();
}

}  // namespace opr
}  // namespace nncore
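
Note: the deduction above is the standard NumPy-style broadcasting rule applied right-to-left over the trailing dimensions: two sizes are compatible when they are equal or when one of them is 1, and the leading dimensions of the longer shape are copied through unchanged. A minimal standalone sketch of that rule, using plain std::vector instead of the repository's layout types (broadcast_shape is a hypothetical helper for illustration only, not part of this commit):

#include <algorithm>
#include <cstdio>
#include <optional>
#include <vector>

// Right-aligned, NumPy-style shape broadcasting: compare trailing dimensions;
// a size of 1 stretches to match the other operand, anything else must match.
std::optional<std::vector<long>> broadcast_shape(const std::vector<long>& a,
                                                 const std::vector<long>& b) {
  size_t dim = std::max(a.size(), b.size());
  std::vector<long> res(dim, 1);
  for (size_t k = 0; k < dim; k++) {
    long da = k < a.size() ? a[a.size() - k - 1] : 1;  // missing dims act as 1
    long db = k < b.size() ? b[b.size() - k - 1] : 1;
    if (da != db && da != 1 && db != 1) return std::nullopt;  // incompatible
    res[dim - k - 1] = std::max(da, db);
  }
  return res;
}

int main() {
  // Shapes taken from Group 4 of the test added in this commit.
  auto r = broadcast_shape({1, 4, 2, 3}, {1, 4, 1, 3});
  if (r) {
    for (long d : *r) std::printf("%ld ", d);
    std::printf("\n");  // prints: 1 4 2 3
  }
  return 0;
}

Built as C++17, this prints 1 4 2 3, which matches the output shape that Group 4 of the test below expects.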

core/op/common/ops.h

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@ using namespace param;
   cb(transpose) cb(permute) cb(repeat) cb(flip) cb(matrix_inverse) cb(rotate) \
       cb(pad) cb(sort) cb(onehot) cb(sum) cb(max) cb(min) cb(negative)
 
-#define NN_FOREACH_DOUBLE_INPUT_OP(cb) cb(matmul) cb(dot) cb(boolindex)
+#define NN_FOREACH_DOUBLE_INPUT_OP(cb) \
+  cb(matmul) cb(dot) cb(boolindex) cb(interelem)
 
 #define NN_FOREACH_SINGLE_INPUT_OP_WITH_PARAM(cb, ...) \
   cb(permute, __VA_ARGS__) cb(transpose, __VA_ARGS__)

core/op/common/param.h

Lines changed: 8 additions & 0 deletions
@@ -208,6 +208,14 @@ struct min {
 
 struct negative {};
 
+struct interelem {
+  enum Operation { Add = 1, Sub = 2, Mul = 3, Div = 4 };
+
+  Operation op;
+
+  interelem(Operation op) : op(op) {}
+};
+
 struct argmxx {
   int axis;
   bool is_max;

core/op/naive/interelem.cpp

Lines changed: 82 additions & 0 deletions
#include <functional>

#include "core/op/naive/ops.h"

using Param = nncore::param::interelem;

namespace nncore {
namespace opr {
namespace naive {

template <typename TA, typename TB, typename TC>
TC interelem_add(TA a, TB b) {
  return static_cast<TC>(a + b);
}

template <typename TA, typename TB, typename TC>
TC interelem_sub(TA a, TB b) {
  return static_cast<TC>(a - b);
}

template <typename TA, typename TB, typename TC>
TC interelem_mul(TA a, TB b) {
  return static_cast<TC>(a * b);
}

template <typename TA, typename TB, typename TC>
TC interelem_div(TA a, TB b) {
  return static_cast<TC>(a / b);
}

IMPL_NAIVE_DOUBLE_INPUT_INTERNAL(interelem) {
  nn_size idx_offset = loup.ndim - NN_MAX_NDIM;
  // Select the scalar kernel for the requested element-wise operation.
  std::function<TC(TA, TB)> func;
  switch (param.op) {
    case Param::Add:
      func = interelem_add<TA, TB, TC>;
      break;
    case Param::Sub:
      func = interelem_sub<TA, TB, TC>;
      break;
    case Param::Mul:
      func = interelem_mul<TA, TB, TC>;
      break;
    case Param::Div:
      func = interelem_div<TA, TB, TC>;
      break;
    default:
      return Status(StatusCategory::NUMNET, StatusCode::FAIL,
                    "Invalid operation in interelem op.");
  }
  // Iterate over the last NN_MAX_NDIM dimensions of the output; leading
  // dimensions that do not exist collapse to a single iteration.
  for (nn_size n = 0; n < (idx_offset == 0 ? loup[idx_offset] : 1); n++) {
    nn_size n_offset_a = n * la.stride[idx_offset];
    nn_size n_offset_b = n * lb.stride[idx_offset];
    nn_size n_offset_oup = n * loup.stride[idx_offset];
    for (nn_size c = 0; c < (idx_offset >= -1 ? loup[idx_offset + 1] : 1);
         c++) {
      nn_size nc_offset_a = c * la.stride[idx_offset + 1] + n_offset_a;
      nn_size nc_offset_b = c * lb.stride[idx_offset + 1] + n_offset_b;
      nn_size nc_offset_oup = c * loup.stride[idx_offset + 1] + n_offset_oup;
      for (nn_size i = 0; i < (idx_offset >= -2 ? loup[idx_offset + 2] : 1);
           i++) {
        nn_size nch_offset_a = i * la.stride[idx_offset + 2] + nc_offset_a;
        nn_size nch_offset_b = i * lb.stride[idx_offset + 2] + nc_offset_b;
        nn_size nch_offset_oup =
            i * loup.stride[idx_offset + 2] + nc_offset_oup;
        for (nn_size j = 0; j < loup[idx_offset + 3]; j++) {
          nn_size a_pos = nch_offset_a + j * la.stride[idx_offset + 3];
          nn_size b_pos = nch_offset_b + j * lb.stride[idx_offset + 3];
          nn_size oup_pos = nch_offset_oup + j * loup.stride[idx_offset + 3];
          ptr_oup[oup_pos] = func(ptr_a[a_pos], ptr_b[b_pos]);
        }
      }
    }
  }
  return Status::OK();
}

}  // namespace naive
}  // namespace opr

}  // namespace nncore
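
Note: the kernel indexes both inputs with the output's coordinates across up to NN_MAX_NDIM nested loops. This relies on the layout deduction having already called a.broadcast_inplace(res) and b.broadcast_inplace(res); the usual way such a step works (an assumption here, the diff itself does not show it) is to give every broadcast dimension a stride of 0, so the same source element is re-read along that dimension. A small standalone sketch of that stride arithmetic (flat_offset is an illustrative helper, not part of the repository):

#include <cstddef>
#include <cstdio>

// Map an N-D coordinate to a flat buffer offset using per-dimension strides.
// A stride of 0 on a broadcast dimension keeps re-reading the same element.
size_t flat_offset(const size_t* coord, const long* stride, int ndim) {
  size_t off = 0;
  for (int d = 0; d < ndim; d++) off += coord[d] * stride[d];
  return off;
}

int main() {
  // A {3} input broadcast to a {2, 3} output: the broadcast row dimension
  // gets stride 0, the column dimension keeps stride 1.
  long b_stride[2] = {0, 1};
  for (size_t r = 0; r < 2; r++) {
    for (size_t c = 0; c < 3; c++) {
      size_t coord[2] = {r, c};
      std::printf("%zu ", flat_offset(coord, b_stride, 2));
    }
  }
  std::printf("\n");  // prints: 0 1 2 0 1 2
  return 0;
}

The printed offsets 0 1 2 0 1 2 show the single row of the smaller input being reused for both output rows, which is the pattern Group 2 of the test below exercises.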

core/op/naive/negative.cpp

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ namespace nncore {
 namespace opr {
 namespace naive {
 
-IMPL_NAIVE_SINGLE_INPUT_INTERNAL(flip) {
+IMPL_NAIVE_SINGLE_INPUT_INTERNAL(negative) {
   nn_size n = loup.total_elems();
   nn_size src_idx[NN_MAX_NDIM];
   for (nn_size i = 0; i < n; i++) {

core/test/naive/interelem.cpp

Lines changed: 52 additions & 0 deletions
#include "core/op/naive/ops.h"
#include "core/test/common/factory.h"
#include "core/test/common/utils.h"
#include "gtest/gtest.h"

using namespace nncore;
using namespace test;
using namespace opr;
using namespace opr::naive;

using F = NDArrayFactory;
using Param = param::interelem;

TEST(Naive, Interelem) {
  OpBase* oprs = OpNaiveImpl::get_instance();

  // Group 1: element-wise multiplication of two tensors with the same shape.
  Tensor a1 = F::from_list({1, 2, 3, 4, 5, 6}, {2, 3}, dtype::Int32());
  Tensor b1 = F::from_list({-1, 1, 2, 1, -1, 3}, {2, 3}, dtype::Int32());
  Tensor truth1 = F::from_list({-1, 2, 6, 4, -5, 18}, {2, 3}, dtype::Int32());
  Param p1(Param::Operation::Mul);

  Tensor pred1;
  ASSERT_TRUE(oprs->interelem(a1, b1, pred1, p1).is_ok());
  assert_same_data<int>(pred1, truth1, 0.0001f);

  // Group 2: a {3} vector broadcast against a {2, 3} matrix.
  Tensor a2 = F::from_list({1, 3, 5, 7, 9, 0}, {2, 3}, dtype::Int32());
  Tensor b2 = F::from_list({1, 2, 3}, {3}, dtype::Int32());
  Tensor truth2 = F::from_list({1, 6, 15, 7, 18, 0}, {2, 3}, dtype::Int32());
  Param p2(Param::Operation::Mul);

  Tensor pred2;
  ASSERT_TRUE(oprs->interelem(a2, b2, pred2, p2).is_ok());
  assert_same_data<int>(pred2, truth2, 0.0001f);

  // Group 4: 4-D tensors broadcast along a size-1 axis.
  Tensor a4 = F::from_list({1, 2, 5, -2, -4, 6, 1, 2, 5, -2, -4, 6,
                            1, 2, 5, -2, -4, 6, 1, 2, 5, -2, -4, 6},
                           {1, 4, 2, 3}, dtype::Int32());
  Tensor b4 = F::from_list({1, 1, 2, 3, -2, -4, 8, 15, -7, -1, 5, 0},
                           {1, 4, 1, 3}, dtype::Int32());
  Tensor truth4 =
      F::from_list({1, 2, 10, -2, -4, 12, 3, -4, -20, -6, 8, -24,
                    8, 30, -35, -16, -60, -42, -1, 10, 0, 2, -20, 0},
                   {1, 4, 2, 3}, dtype::Int32());
  Param p4(Param::Operation::Mul);

  Tensor pred4;
  ASSERT_TRUE(oprs->interelem(a4, b4, pred4, p4).is_ok());
  assert_same_data<int>(pred4, truth4, 0.0001f);
}
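
Note: the expected values in Group 2 follow directly from the broadcast: each row of a2 is multiplied element-wise by b2. A standalone check of that arithmetic, independent of the test harness:

#include <cassert>
#include <cstdio>

int main() {
  // Group 2 of the test above: a2 has shape {2, 3}, b2 has shape {3}.
  int a2[2][3] = {{1, 3, 5}, {7, 9, 0}};
  int b2[3] = {1, 2, 3};
  int truth2[2][3] = {{1, 6, 15}, {7, 18, 0}};
  for (int r = 0; r < 2; r++) {
    for (int c = 0; c < 3; c++) {
      // b2 is broadcast across the row dimension.
      assert(a2[r][c] * b2[c] == truth2[r][c]);
    }
  }
  std::printf("Group 2 expectation verified.\n");
  return 0;
}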
