1+ #include " core/op/naive/ops.h"
2+ #include " core/test/common/factory.h"
3+ #include " core/test/common/utils.h"
4+ #include " gtest/gtest.h"
5+
6+ using namespace nncore ;
7+ using namespace test ;
8+ using namespace opr ;
9+ using namespace opr ::naive;
10+
11+ using F = NDArrayFactory;
12+ using Param = param::interelem;
13+
14+ TEST (Naive, Interelem) {
15+ OpBase* oprs = OpNaiveImpl::get_instance ();
16+
17+ // Group 1
18+ Tensor a1 = F::from_list ({1 , 2 , 3 , 4 , 5 , 6 }, {2 , 3 }, dtype::Int32 ());
19+ Tensor b1 = F::from_list ({-1 , 1 , 2 , 1 , -1 , 3 }, {2 , 3 }, dtype::Int32 ());
20+ Tensor truth1 = F::from_list ({-1 , 2 , 6 , 4 , -5 , 18 }, {2 , 3 }, dtype::Int32 ());
21+ Param p1 (Param::Operation::Mul);
22+
23+ Tensor pred1;
24+ ASSERT_TRUE (oprs->interelem (a1, b1, pred1, p1).is_ok ());
25+ assert_same_data<int >(pred1, truth1, 0 .0001f );
26+
27+ // Group 2
28+ Tensor a2 = F::from_list ({1 , 3 , 5 , 7 , 9 , 0 }, {2 , 3 }, dtype::Int32 ());
29+ Tensor b2 = F::from_list ({1 , 2 , 3 }, {3 }, dtype::Int32 ());
30+ Tensor truth2 = F::from_list ({1 , 6 , 15 , 7 , 18 , 0 }, {2 , 3 }, dtype::Int32 ());
31+ Param p2 (Param::Operation::Mul);
32+
33+ Tensor pred2;
34+ ASSERT_TRUE (oprs->interelem (a2, b2, pred2, p2).is_ok ());
35+ assert_same_data<int >(pred2, truth2, 0 .0001f );
36+
37+ // Group 4
38+ Tensor a4 = F::from_list ({1 , 2 , 5 , -2 , -4 , 6 , 1 , 2 , 5 , -2 , -4 , 6 ,
39+ 1 , 2 , 5 , -2 , -4 , 6 , 1 , 2 , 5 , -2 , -4 , 6 },
40+ {1 , 4 , 2 , 3 }, dtype::Int32 ());
41+ Tensor b4 = F::from_list ({1 , 1 , 2 , 3 , -2 , -4 , 8 , 15 , -7 , -1 , 5 , 0 },
42+ {1 , 4 , 1 , 3 }, dtype::Int32 ());
43+ Tensor truth4 =
44+ F::from_list ({1 , 2 , 10 , -2 , -4 , 12 , 3 , -4 , -20 , -6 , 8 , -24 ,
45+ 8 , 30 , -35 , -16 , -60 , -42 , -1 , 10 , 0 , 2 , -20 , 0 },
46+ {1 , 4 , 2 , 3 }, dtype::Int32 ());
47+ Param p4 (Param::Operation::Mul);
48+
49+ Tensor pred4;
50+ ASSERT_TRUE (oprs->interelem (a4, b4, pred4, p4).is_ok ());
51+ assert_same_data<int >(pred4, truth4, 0 .0001f );
52+ }
0 commit comments