@@ -6,31 +6,79 @@ namespace OnnxStack.Core
66 public static class TensorExtension
77 {
88 /// <summary>
9- /// Concatenates the specified tensors along the 0 axis.
9+ /// Concatenates the specified tensors along the specified axis.
1010 /// </summary>
1111 /// <param name="tensor1">The tensor1.</param>
1212 /// <param name="tensor2">The tensor2.</param>
1313 /// <param name="axis">The axis.</param>
1414 /// <returns></returns>
15- /// <exception cref="System.NotImplementedException">Only axis 0 is supported</exception>
15+ /// <exception cref="System.NotImplementedException">Only axis 0,1,2 is supported</exception>
1616 public static DenseTensor < float > Concatenate ( this DenseTensor < float > tensor1 , DenseTensor < float > tensor2 , int axis = 0 )
1717 {
1818 if ( tensor1 == null )
1919 return tensor2 . ToDenseTensor ( ) ;
2020
21- if ( axis != 0 && axis != 2 )
22- throw new NotImplementedException ( "Only axis 0, 2 is supported" ) ;
21+ return axis switch
22+ {
23+ 0 => ConcatenateAxis0 ( tensor1 , tensor2 ) ,
24+ 1 => ConcatenateAxis1 ( tensor1 , tensor2 ) ,
25+ 2 => ConcatenateAxis2 ( tensor1 , tensor2 ) ,
26+ _ => throw new NotImplementedException ( "Only axis 0, 1, 2 is supported" )
27+ } ;
28+ }
2329
24- if ( axis == 2 )
25- return Concatenate ( tensor1 , tensor2 ) ;
2630
31+ private static DenseTensor < float > ConcatenateAxis0 ( this DenseTensor < float > tensor1 , DenseTensor < float > tensor2 )
32+ {
2733 var dimensions = tensor1 . Dimensions . ToArray ( ) ;
2834 dimensions [ 0 ] += tensor2 . Dimensions [ 0 ] ;
2935
30- var buffer = new float [ tensor1 . Length + tensor2 . Length ] . AsMemory ( ) ;
31- tensor1 . Buffer . CopyTo ( buffer [ ..( int ) tensor1 . Length ] ) ;
32- tensor2 . Buffer . CopyTo ( buffer [ ( int ) tensor1 . Length ..] ) ;
33- return new DenseTensor < float > ( buffer , dimensions ) ;
36+ var buffer = new DenseTensor < float > ( dimensions ) ;
37+ tensor1 . Buffer . CopyTo ( buffer . Buffer [ ..( int ) tensor1 . Length ] ) ;
38+ tensor2 . Buffer . CopyTo ( buffer . Buffer [ ( int ) tensor1 . Length ..] ) ;
39+ return buffer ;
40+ }
41+
42+
43+ private static DenseTensor < float > ConcatenateAxis1 ( DenseTensor < float > tensor1 , DenseTensor < float > tensor2 )
44+ {
45+ var dimensions = tensor1 . Dimensions . ToArray ( ) ;
46+ dimensions [ 1 ] += tensor2 . Dimensions [ 1 ] ;
47+ var concatenatedTensor = new DenseTensor < float > ( dimensions ) ;
48+
49+ // Copy data from the first tensor
50+ for ( int i = 0 ; i < dimensions [ 0 ] ; i ++ )
51+ for ( int j = 0 ; j < tensor1 . Dimensions [ 1 ] ; j ++ )
52+ concatenatedTensor [ i , j ] = tensor1 [ i , j ] ;
53+
54+ // Copy data from the second tensor
55+ for ( int i = 0 ; i < dimensions [ 0 ] ; i ++ )
56+ for ( int j = 0 ; j < tensor1 . Dimensions [ 1 ] ; j ++ )
57+ concatenatedTensor [ i , j + tensor1 . Dimensions [ 1 ] ] = tensor2 [ i , j ] ;
58+
59+ return concatenatedTensor ;
60+ }
61+
62+
63+ private static DenseTensor < float > ConcatenateAxis2 ( DenseTensor < float > tensor1 , DenseTensor < float > tensor2 )
64+ {
65+ var dimensions = tensor1 . Dimensions . ToArray ( ) ;
66+ dimensions [ 2 ] += tensor2 . Dimensions [ 2 ] ;
67+ var concatenatedTensor = new DenseTensor < float > ( dimensions ) ;
68+
69+ // Copy data from the first tensor
70+ for ( int i = 0 ; i < dimensions [ 0 ] ; i ++ )
71+ for ( int j = 0 ; j < dimensions [ 1 ] ; j ++ )
72+ for ( int k = 0 ; k < tensor1 . Dimensions [ 2 ] ; k ++ )
73+ concatenatedTensor [ i , j , k ] = tensor1 [ i , j , k ] ;
74+
75+ // Copy data from the second tensor
76+ for ( int i = 0 ; i < dimensions [ 0 ] ; i ++ )
77+ for ( int j = 0 ; j < dimensions [ 1 ] ; j ++ )
78+ for ( int k = 0 ; k < tensor2 . Dimensions [ 2 ] ; k ++ )
79+ concatenatedTensor [ i , j , k + tensor1 . Dimensions [ 2 ] ] = tensor2 [ i , j , k ] ;
80+
81+ return concatenatedTensor ;
3482 }
3583 }
3684}
0 commit comments