Skip to content

Commit 0dda694

Browse files
committed
feat(csharp/Tensor.NET): add interelem method and api.
1 parent 33cc3fb commit 0dda694

File tree

9 files changed

+168
-2
lines changed

9 files changed

+168
-2
lines changed

apis/numnet_c_cxx_apis.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,25 @@ Status *Dot(NativeTensor *a, NativeTensor *b, NativeTensor *oup,
122122
}
123123
}
124124

125+
Status *Interelem(NativeTensor *a, NativeTensor *b, NativeTensor *oup,
126+
param::interelem *param, ProviderEnum provider) {
127+
Tensor t_a, t_b, t_oup;
128+
a->ToTensor(t_a, false);
129+
b->ToTensor(t_b, false);
130+
oup->ToTensor(t_oup, true);
131+
OpBase *impl = GetImpl(provider);
132+
if (impl == nullptr) {
133+
return new Status(StatusCategory::NUMNET, StatusCode::INVALID_ARGUMENT,
134+
"Unsupported provider.");
135+
}
136+
auto status = impl->interelem(t_a, t_b, t_oup, *param);
137+
if (status.is_ok()) {
138+
return nullptr;
139+
} else {
140+
return new Status(status);
141+
}
142+
}
143+
125144
Status *BoolIndex(NativeTensor *a, NativeTensor *b, NativeTensor *oup,
126145
param::boolindex *param, ProviderEnum provider) {
127146
printf("Enter!\n");

apis/numnet_c_cxx_apis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ Status *Dot(NativeTensor *a, NativeTensor *b, NativeTensor *oup,
5151
Status *BoolIndex(NativeTensor *a, NativeTensor *b, NativeTensor *oup,
5252
param::boolindex *param, ProviderEnum provider);
5353

54+
Status *Interelem(NativeTensor *a, NativeTensor *b, NativeTensor *oup,
55+
param::interelem *param, ProviderEnum provider);
56+
5457
Status *Permute(NativeTensor *inp, NativeTensor *oup, param::permute *param,
5558
ProviderEnum provider);
5659

csharp/Tensor.NET/Native/NativeApi.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ internal static class NativeApi{
1616
public static extern IntPtr Dot(IntPtr a, IntPtr b, IntPtr oup, IntPtr param, NativeProvider provider);
1717
[DllImport("libnumnet")]
1818
public static extern IntPtr BoolIndex(IntPtr a, IntPtr b, IntPtr oup, IntPtr param, NativeProvider provider);
19+
[DllImport("libnumnet")]
20+
public static extern IntPtr Interelem(IntPtr a, IntPtr b, IntPtr oup, IntPtr param, NativeProvider provider);
1921

2022

2123
[DllImport("libnumnet")]

csharp/Tensor.NET/Native/NativeParam.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
using Tensornet.Exceptions;
2+
using Tensornet.Common;
23

34
namespace Tensornet.Native.Param{
45
internal struct MatmulParam{
56

67
}
78
internal struct DotParam{
89

10+
}
11+
internal struct InterelemParam{
12+
internal InterElemOperationType operationType;
913
}
1014
internal struct PermuteParam{
1115
internal IntPtr dims;

csharp/Tensor.NET/Operators/Add.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Tensornet.Common;
2+
3+
namespace Tensornet{
4+
public partial class Tensor<T>{
5+
public static Tensor<T> operator+(Tensor<T> lhs, Tensor<T> rhs){
6+
return InterElemOperation.Execute<T>(lhs, rhs, InterElemOperationType.Add);
7+
}
8+
public static Tensor<T> operator+(Tensor<T> lhs, T rhs){
9+
return InterElemOperation.Execute<T>(lhs, (Tensor<T>)rhs, InterElemOperationType.Add);
10+
}
11+
public static Tensor<T> operator+(T lhs, Tensor<T> rhs){
12+
return InterElemOperation.Execute<T>((Tensor<T>)lhs, rhs, InterElemOperationType.Add);
13+
}
14+
}
15+
}

csharp/Tensor.NET/Operators/Div.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Tensornet.Common;
2+
3+
namespace Tensornet{
4+
public partial class Tensor<T>{
5+
public static Tensor<T> operator/(Tensor<T> lhs, Tensor<T> rhs){
6+
return InterElemOperation.Execute<T>(lhs, rhs, InterElemOperationType.Div);
7+
}
8+
public static Tensor<T> operator/(Tensor<T> lhs, T rhs){
9+
return InterElemOperation.Execute<T>(lhs, (Tensor<T>)rhs, InterElemOperationType.Div);
10+
}
11+
public static Tensor<T> operator/(T lhs, Tensor<T> rhs){
12+
return InterElemOperation.Execute<T>((Tensor<T>)lhs, rhs, InterElemOperationType.Div);
13+
}
14+
}
15+
}

csharp/Tensor.NET/Operators/Mul.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Tensornet.Common;
2+
3+
namespace Tensornet{
4+
public partial class Tensor<T>{
5+
public static Tensor<T> operator*(Tensor<T> lhs, Tensor<T> rhs){
6+
return InterElemOperation.Execute<T>(lhs, rhs, InterElemOperationType.Mul);
7+
}
8+
public static Tensor<T> operator*(Tensor<T> lhs, T rhs){
9+
return InterElemOperation.Execute<T>(lhs, (Tensor<T>)rhs, InterElemOperationType.Mul);
10+
}
11+
public static Tensor<T> operator*(T lhs, Tensor<T> rhs){
12+
return InterElemOperation.Execute<T>((Tensor<T>)lhs, rhs, InterElemOperationType.Mul);
13+
}
14+
}
15+
}

csharp/Tensor.NET/Operators/Sub.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Tensornet.Common;
2+
3+
namespace Tensornet{
4+
public partial class Tensor<T>{
5+
public static Tensor<T> operator-(Tensor<T> lhs, Tensor<T> rhs){
6+
return InterElemOperation.Execute<T>(lhs, rhs, InterElemOperationType.Sub);
7+
}
8+
public static Tensor<T> operator-(Tensor<T> lhs, T rhs){
9+
return InterElemOperation.Execute<T>(lhs, (Tensor<T>)rhs, InterElemOperationType.Sub);
10+
}
11+
public static Tensor<T> operator-(T lhs, Tensor<T> rhs){
12+
return InterElemOperation.Execute<T>((Tensor<T>)lhs, rhs, InterElemOperationType.Sub);
13+
}
14+
}
15+
}

csharp/Tensor.NET/Tensor/Common/InterElemOperation.cs

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,85 @@
11
using Tensornet.Exceptions;
2+
using Tensornet.Native;
3+
using Tensornet.Native.Param;
24

35
namespace Tensornet.Common{
4-
public static class InterElemOperation{
6+
internal enum InterElemOperationType{
7+
Add = 1,
8+
Sub = 2,
9+
Mul = 3,
10+
Div = 4
11+
}
12+
internal static class InterElemOperation{
13+
public static unsafe Tensor<T> Execute<T>(Tensor<T> a, Tensor<T> b, InterElemOperationType operationType)
14+
where T : struct, IEquatable<T>, IConvertible{
15+
TensorLayout resLayout = new TensorLayout();
16+
resLayout.DType = TensorTypeInfo.GetTypeInfo(typeof(T))._dtype;
17+
resLayout.NDim = System.Math.Max(a.TLayout.NDim, b.TLayout.NDim);
18+
for (int i = a.TLayout.NDim - 1, j = b.TLayout.NDim - 1, idx = resLayout.NDim - 1; i >= 0 || j >= 0; i--, j--, idx--){
19+
if(i < 0){
20+
resLayout.Shape[idx] = b.TLayout.Shape[j];
21+
}
22+
else if(j < 0){
23+
resLayout.Shape[idx] = a.TLayout.Shape[i];
24+
}
25+
else if(a.TLayout.Shape[i] == b.TLayout.Shape[j]){
26+
resLayout.Shape[idx] = a.TLayout.Shape[i];
27+
}
28+
else if(a.TLayout.Shape[i] == 1){
29+
resLayout.Shape[idx] = b.TLayout.Shape[j];
30+
}
31+
else if(b.TLayout.Shape[j] == 1){
32+
resLayout.Shape[idx] = a.TLayout.Shape[i];
33+
}
34+
else{
35+
throw new MismatchedShapeException($"Cannot broadcast between the shape {a.TLayout as TensorShape} and shape {b.TLayout as TensorShape}.");
36+
}
37+
}
38+
resLayout.InitContiguousLayout();
39+
Tensor<T> tempA = a.Broadcast(resLayout);
40+
Tensor<T> tempB = b.Broadcast(resLayout);
41+
Tensor<T> res = new Tensor<T>(resLayout);
42+
InterelemParam p = new InterelemParam() { operationType = operationType };
43+
IntPtr status = NativeExecutor.Execute(NativeApi.Interelem, tempA.TMemory, tempB.TMemory, res.TMemory, tempA.TLayout, tempB.TLayout, res.TLayout, new IntPtr(&p), Tensor<T>.Provider);
44+
NativeStatus.AssertOK(status);
45+
return res;
46+
}
47+
public static unsafe Tensor<TResult> Execute<TA, TB, TResult>(Tensor<TA> a, Tensor<TB> b, InterElemOperationType operationType)
48+
where TA : struct, IEquatable<TA>, IConvertible
49+
where TB : struct, IEquatable<TB>, IConvertible
50+
where TResult : struct, IEquatable<TResult>, IConvertible{
51+
TensorLayout resLayout = new TensorLayout();
52+
resLayout.DType = TensorTypeInfo.GetTypeInfo(typeof(TResult))._dtype;
53+
resLayout.NDim = System.Math.Max(a.TLayout.NDim, b.TLayout.NDim);
54+
for (int i = a.TLayout.NDim - 1, j = b.TLayout.NDim - 1, idx = resLayout.NDim - 1; i >= 0 || j >= 0; i--, j--, idx--){
55+
if(i < 0){
56+
resLayout.Shape[idx] = b.TLayout.Shape[j];
57+
}
58+
else if(j < 0){
59+
resLayout.Shape[idx] = a.TLayout.Shape[i];
60+
}
61+
else if(a.TLayout.Shape[i] == b.TLayout.Shape[j]){
62+
resLayout.Shape[idx] = a.TLayout.Shape[i];
63+
}
64+
else if(a.TLayout.Shape[i] == 1){
65+
resLayout.Shape[idx] = b.TLayout.Shape[j];
66+
}
67+
else if(b.TLayout.Shape[j] == 1){
68+
resLayout.Shape[idx] = a.TLayout.Shape[i];
69+
}
70+
else{
71+
throw new MismatchedShapeException($"Cannot broadcast between the shape {a.TLayout as TensorShape} and shape {b.TLayout as TensorShape}.");
72+
}
73+
}
74+
resLayout.InitContiguousLayout();
75+
Tensor<TA> tempA = a.Broadcast(resLayout);
76+
Tensor<TB> tempB = b.Broadcast(resLayout);
77+
Tensor<TResult> res = new Tensor<TResult>(resLayout);
78+
InterelemParam p = new InterelemParam() { operationType = operationType };
79+
IntPtr status = NativeExecutor.Execute(NativeApi.Interelem, tempA.TMemory, tempB.TMemory, res.TMemory, tempA.TLayout, tempB.TLayout, res.TLayout, new IntPtr(&p), Tensor<TResult>.Provider);
80+
NativeStatus.AssertOK(status);
81+
return res;
82+
}
583
public static Tensor<TResult> Execute<TA, TB, TResult>(Tensor<TA> a, Tensor<TB> b, Func<TA, TB, TResult> operation)
684
where TA : struct, IEquatable<TA>, IConvertible
785
where TB : struct, IEquatable<TB>, IConvertible
@@ -26,7 +104,7 @@ public static Tensor<TResult> Execute<TA, TB, TResult>(Tensor<TA> a, Tensor<TB>
26104
resLayout.Shape[idx] = a.TLayout.Shape[i];
27105
}
28106
else{
29-
throw new MismatchedShapeException($"Cannot broadcast between the shape {a.TLayout as TensorShape} and shape {{b.TLayout as TensorShape}}.");
107+
throw new MismatchedShapeException($"Cannot broadcast between the shape {a.TLayout as TensorShape} and shape {b.TLayout as TensorShape}.");
30108
}
31109
}
32110
resLayout.InitContiguousLayout();

0 commit comments

Comments
 (0)