Skip to content

Commit e1faabd

Browse files
committed
fix(fix the bug in interelem op).
1 parent d10aa8a commit e1faabd

File tree

3 files changed

+98
-7
lines changed

3 files changed

+98
-7
lines changed

core/op/naive/interelem.cpp

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,80 @@ IMPL_NAIVE_DOUBLE_INPUT_INTERNAL(interelem) {
6565
func = interelem_div<TA, TB, TC>;
6666
break;
6767
case Param::Mod:
68-
func = interelem_mod<TA, TB, TC>;
68+
if (la.dtype.enumv() == DTypeEnum::Float32 ||
69+
la.dtype.enumv() == DTypeEnum::Float64 ||
70+
lb.dtype.enumv() == DTypeEnum::Float32 ||
71+
lb.dtype.enumv() == DTypeEnum::Float64) {
72+
return Status(StatusCategory::NUMNET, StatusCode::MISMATCHED_DTYPE,
73+
"The mod operator only accepts int or long as input");
74+
}
75+
if (loup.dtype.enumv() == DTypeEnum::Int32) {
76+
func = interelem_mod<nn_int32, nn_int32, nn_int32>;
77+
} else if (loup.dtype.enumv() == DTypeEnum::Int64) {
78+
func = interelem_mod<nn_int64, nn_int64, nn_int64>;
79+
} else if (loup.dtype.enumv() == DTypeEnum::Bool) {
80+
func = interelem_mod<bool, bool, bool>;
81+
} else {
82+
return Status(StatusCategory::SYSTEM, StatusCode::FAIL,
83+
"Unexpected exception in interelem.cpp.");
84+
}
6985
break;
7086
case Param::And:
71-
func = interelem_and<TA, TB, TC>;
87+
if (la.dtype.enumv() == DTypeEnum::Float32 ||
88+
la.dtype.enumv() == DTypeEnum::Float64 ||
89+
lb.dtype.enumv() == DTypeEnum::Float32 ||
90+
lb.dtype.enumv() == DTypeEnum::Float64) {
91+
return Status(StatusCategory::NUMNET, StatusCode::MISMATCHED_DTYPE,
92+
"The and operator only accepts int or long as input");
93+
}
94+
if (loup.dtype.enumv() == DTypeEnum::Int32) {
95+
func = interelem_and<nn_int32, nn_int32, nn_int32>;
96+
} else if (loup.dtype.enumv() == DTypeEnum::Int64) {
97+
func = interelem_and<nn_int64, nn_int64, nn_int64>;
98+
} else if (loup.dtype.enumv() == DTypeEnum::Bool) {
99+
func = interelem_and<bool, bool, bool>;
100+
} else {
101+
return Status(StatusCategory::SYSTEM, StatusCode::FAIL,
102+
"Unexpected exception in interelem.cpp.");
103+
}
72104
break;
73105
case Param::Or:
74-
func = interelem_or<TA, TB, TC>;
106+
if (la.dtype.enumv() == DTypeEnum::Float32 ||
107+
la.dtype.enumv() == DTypeEnum::Float64 ||
108+
lb.dtype.enumv() == DTypeEnum::Float32 ||
109+
lb.dtype.enumv() == DTypeEnum::Float64) {
110+
return Status(StatusCategory::NUMNET, StatusCode::MISMATCHED_DTYPE,
111+
"The or operator only accepts int or long as input");
112+
}
113+
if (loup.dtype.enumv() == DTypeEnum::Int32) {
114+
func = interelem_or<nn_int32, nn_int32, nn_int32>;
115+
} else if (loup.dtype.enumv() == DTypeEnum::Int64) {
116+
func = interelem_or<nn_int64, nn_int64, nn_int64>;
117+
} else if (loup.dtype.enumv() == DTypeEnum::Bool) {
118+
func = interelem_or<bool, bool, bool>;
119+
} else {
120+
return Status(StatusCategory::SYSTEM, StatusCode::FAIL,
121+
"Unexpected exception in interelem.cpp.");
122+
}
75123
break;
76124
case Param::Xor:
77-
func = interelem_xor<TA, TB, TC>;
125+
if (la.dtype.enumv() == DTypeEnum::Float32 ||
126+
la.dtype.enumv() == DTypeEnum::Float64 ||
127+
lb.dtype.enumv() == DTypeEnum::Float32 ||
128+
lb.dtype.enumv() == DTypeEnum::Float64) {
129+
return Status(StatusCategory::NUMNET, StatusCode::MISMATCHED_DTYPE,
130+
"The xor operator only accepts int or long as input");
131+
}
132+
if (loup.dtype.enumv() == DTypeEnum::Int32) {
133+
func = interelem_xor<nn_int32, nn_int32, nn_int32>;
134+
} else if (loup.dtype.enumv() == DTypeEnum::Int64) {
135+
func = interelem_xor<nn_int64, nn_int64, nn_int64>;
136+
} else if (loup.dtype.enumv() == DTypeEnum::Bool) {
137+
func = interelem_xor<bool, bool, bool>;
138+
} else {
139+
return Status(StatusCategory::SYSTEM, StatusCode::FAIL,
140+
"Unexpected exception in interelem.cpp.");
141+
}
78142
break;
79143

80144
default:

core/test/naive/interelem.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,24 @@ TEST(Naive, Interelem) {
4949
Tensor pred4;
5050
ASSERT_TRUE(oprs->interelem(a4, b4, pred4, p4).is_ok());
5151
assert_same_data<int>(pred4, truth4, 0.0001f);
52+
53+
// Group 5
54+
Tensor a5 = F::from_list({1, 2, 35, 43, 5, 14}, {2, 3}, dtype::Int32());
55+
Tensor b5 = F::from_list({1, 2, 9, 5, 2, 4}, {2, 3}, dtype::Int32());
56+
Tensor truth5 = F::from_list({0, 0, 8, 3, 1, 2}, {2, 3}, dtype::Int32());
57+
Param p5(Param::Operation::Mod);
58+
59+
Tensor pred5;
60+
ASSERT_TRUE(oprs->interelem(a5, b5, pred5, p5).is_ok());
61+
assert_same_data<int>(pred5, truth5, 0.0001f);
62+
63+
// Group 5
64+
Tensor a6 = F::from_list({1, 2, 4}, {3}, dtype::Int32());
65+
Tensor b6 = F::from_list({1, 3, 1}, {3}, dtype::Int32());
66+
Tensor truth6 = F::from_list({0, 1, 5}, {3}, dtype::Int32());
67+
Param p6(Param::Operation::Xor);
68+
69+
Tensor pred6;
70+
ASSERT_TRUE(oprs->interelem(a6, b6, pred6, p6).is_ok());
71+
assert_same_data<int>(pred6, truth6, 0.0001f);
5272
}

csharp/Tensor.NET/Tensor.NET.csproj

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515
Github repository: https://github.com/AsakusaRinne/Tensor.NET
1616

1717
Email: AsakusaRinne@gmail.com
18+
19+
Update information:
20+
1. Add static method for load and save.
21+
2. Add mod, and, or, xor operators for Tensor.
22+
3. Add ForEach method for Tensor.
23+
4. Add docs for public APIs.
24+
1825
</Description>
19-
<PackageLicenseExpression>Apache2.0</PackageLicenseExpression>
26+
<PackageLicenseExpression>Apache-2.0</PackageLicenseExpression>
2027
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
2128
</PropertyGroup>
2229

2330
<ItemGroup>
24-
<Content Include="./CppLibrary/libnumnet.dll">
31+
<Content Include="./CppLibrary/libtensornet.dll">
2532
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
2633
<Pack>true</Pack>
2734
<PackagePath>lib\$(TargetFramework)</PackagePath>
@@ -36,7 +43,7 @@
3643
<Pack>true</Pack>
3744
<PackagePath>lib\$(TargetFramework)</PackagePath>
3845
</Content>
39-
<Content Include="libnumnet.so">
46+
<Content Include="libtensornet.so">
4047
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
4148
<PackageCopyToOutput>true</PackageCopyToOutput>
4249
<pack>true</pack>

0 commit comments

Comments
 (0)