11using Tensornet . Exceptions ;
2+ using Tensornet . Native ;
3+ using Tensornet . Native . Param ;
24
35namespace Tensornet . Common {
4- public static class InterElemOperation {
6+ internal enum InterElemOperationType {
7+ Add = 1 ,
8+ Sub = 2 ,
9+ Mul = 3 ,
10+ Div = 4
11+ }
12+ internal static class InterElemOperation {
13+ public static unsafe Tensor < T > Execute < T > ( Tensor < T > a , Tensor < T > b , InterElemOperationType operationType )
14+ where T : struct , IEquatable < T > , IConvertible {
15+ TensorLayout resLayout = new TensorLayout ( ) ;
16+ resLayout . DType = TensorTypeInfo . GetTypeInfo ( typeof ( T ) ) . _dtype ;
17+ resLayout . NDim = System . Math . Max ( a . TLayout . NDim , b . TLayout . NDim ) ;
18+ for ( int i = a . TLayout . NDim - 1 , j = b . TLayout . NDim - 1 , idx = resLayout . NDim - 1 ; i >= 0 || j >= 0 ; i -- , j -- , idx -- ) {
19+ if ( i < 0 ) {
20+ resLayout . Shape [ idx ] = b . TLayout . Shape [ j ] ;
21+ }
22+ else if ( j < 0 ) {
23+ resLayout . Shape [ idx ] = a . TLayout . Shape [ i ] ;
24+ }
25+ else if ( a . TLayout . Shape [ i ] == b . TLayout . Shape [ j ] ) {
26+ resLayout . Shape [ idx ] = a . TLayout . Shape [ i ] ;
27+ }
28+ else if ( a . TLayout . Shape [ i ] == 1 ) {
29+ resLayout . Shape [ idx ] = b . TLayout . Shape [ j ] ;
30+ }
31+ else if ( b . TLayout . Shape [ j ] == 1 ) {
32+ resLayout . Shape [ idx ] = a . TLayout . Shape [ i ] ;
33+ }
34+ else {
35+ throw new MismatchedShapeException ( $ "Cannot broadcast between the shape { a . TLayout as TensorShape } and shape { b . TLayout as TensorShape } .") ;
36+ }
37+ }
38+ resLayout . InitContiguousLayout ( ) ;
39+ Tensor < T > tempA = a . Broadcast ( resLayout ) ;
40+ Tensor < T > tempB = b . Broadcast ( resLayout ) ;
41+ Tensor < T > res = new Tensor < T > ( resLayout ) ;
42+ InterelemParam p = new InterelemParam ( ) { operationType = operationType } ;
43+ IntPtr status = NativeExecutor . Execute ( NativeApi . Interelem , tempA . TMemory , tempB . TMemory , res . TMemory , tempA . TLayout , tempB . TLayout , res . TLayout , new IntPtr ( & p ) , Tensor < T > . Provider ) ;
44+ NativeStatus . AssertOK ( status ) ;
45+ return res ;
46+ }
47+ public static unsafe Tensor < TResult > Execute < TA , TB , TResult > ( Tensor < TA > a , Tensor < TB > b , InterElemOperationType operationType )
48+ where TA : struct , IEquatable < TA > , IConvertible
49+ where TB : struct , IEquatable < TB > , IConvertible
50+ where TResult : struct , IEquatable < TResult > , IConvertible {
51+ TensorLayout resLayout = new TensorLayout ( ) ;
52+ resLayout . DType = TensorTypeInfo . GetTypeInfo ( typeof ( TResult ) ) . _dtype ;
53+ resLayout . NDim = System . Math . Max ( a . TLayout . NDim , b . TLayout . NDim ) ;
54+ for ( int i = a . TLayout . NDim - 1 , j = b . TLayout . NDim - 1 , idx = resLayout . NDim - 1 ; i >= 0 || j >= 0 ; i -- , j -- , idx -- ) {
55+ if ( i < 0 ) {
56+ resLayout . Shape [ idx ] = b . TLayout . Shape [ j ] ;
57+ }
58+ else if ( j < 0 ) {
59+ resLayout . Shape [ idx ] = a . TLayout . Shape [ i ] ;
60+ }
61+ else if ( a . TLayout . Shape [ i ] == b . TLayout . Shape [ j ] ) {
62+ resLayout . Shape [ idx ] = a . TLayout . Shape [ i ] ;
63+ }
64+ else if ( a . TLayout . Shape [ i ] == 1 ) {
65+ resLayout . Shape [ idx ] = b . TLayout . Shape [ j ] ;
66+ }
67+ else if ( b . TLayout . Shape [ j ] == 1 ) {
68+ resLayout . Shape [ idx ] = a . TLayout . Shape [ i ] ;
69+ }
70+ else {
71+ throw new MismatchedShapeException ( $ "Cannot broadcast between the shape { a . TLayout as TensorShape } and shape { b . TLayout as TensorShape } .") ;
72+ }
73+ }
74+ resLayout . InitContiguousLayout ( ) ;
75+ Tensor < TA > tempA = a . Broadcast ( resLayout ) ;
76+ Tensor < TB > tempB = b . Broadcast ( resLayout ) ;
77+ Tensor < TResult > res = new Tensor < TResult > ( resLayout ) ;
78+ InterelemParam p = new InterelemParam ( ) { operationType = operationType } ;
79+ IntPtr status = NativeExecutor . Execute ( NativeApi . Interelem , tempA . TMemory , tempB . TMemory , res . TMemory , tempA . TLayout , tempB . TLayout , res . TLayout , new IntPtr ( & p ) , Tensor < TResult > . Provider ) ;
80+ NativeStatus . AssertOK ( status ) ;
81+ return res ;
82+ }
583 public static Tensor < TResult > Execute < TA , TB , TResult > ( Tensor < TA > a , Tensor < TB > b , Func < TA , TB , TResult > operation )
684 where TA : struct , IEquatable < TA > , IConvertible
785 where TB : struct , IEquatable < TB > , IConvertible
@@ -26,7 +104,7 @@ public static Tensor<TResult> Execute<TA, TB, TResult>(Tensor<TA> a, Tensor<TB>
26104 resLayout . Shape [ idx ] = a . TLayout . Shape [ i ] ;
27105 }
28106 else {
29- throw new MismatchedShapeException ( $ "Cannot broadcast between the shape { a . TLayout as TensorShape } and shape {{ b.TLayout as TensorShape} }.") ;
107+ throw new MismatchedShapeException ( $ "Cannot broadcast between the shape { a . TLayout as TensorShape } and shape { b . TLayout as TensorShape } .") ;
30108 }
31109 }
32110 resLayout . InitContiguousLayout ( ) ;
0 commit comments