1+ using Tensornet . Native ;
2+ using Tensornet . Exceptions ;
3+ using Tensornet . Native . Param ;
4+
5+ namespace Tensornet {
6+ public static class MaxExtension {
7+ public static Tensor < T > Max < T > ( this Tensor < T > src , int [ ] axes , bool keepDims = false ) where T : struct , IEquatable < T > , IConvertible
8+ {
9+ Tensor < T > res = new Tensor < T > ( DeduceLayout ( src . TLayout , axes ) ) ;
10+ res . TLayout . InitContiguousLayout ( ) ;
11+ bool [ ] boolDims = new bool [ src . TLayout . NDim ] ;
12+ var span = boolDims . AsSpan ( ) ;
13+ span . Fill ( false ) ;
14+ foreach ( var axis in axes ) {
15+ span [ axis ] = true ;
16+ }
17+ MaxInternal ( src , res , boolDims , keepDims ) ;
18+ return res ;
19+ }
20+ public static Tensor < T > Max < T > ( this Tensor < T > src , int axis , bool keepDims = false ) where T : struct , IEquatable < T > , IConvertible
21+ {
22+ Tensor < T > res = new Tensor < T > ( DeduceLayout ( src . TLayout , axis ) ) ;
23+ res . TLayout . InitContiguousLayout ( ) ;
24+ bool [ ] boolDims = new bool [ src . TLayout . NDim ] ;
25+ var span = boolDims . AsSpan ( ) ;
26+ span . Fill ( false ) ;
27+ span [ axis ] = true ;
28+ MaxInternal ( src , res , boolDims , keepDims ) ;
29+ return res ;
30+ }
31+ public static Tensor < T > Max < T > ( this Tensor < T > src , bool keepDims = false ) where T : struct , IEquatable < T > , IConvertible
32+ {
33+ Tensor < T > res = new Tensor < T > ( DeduceLayout ( src . TLayout ) ) ;
34+ res . TLayout . InitContiguousLayout ( ) ;
35+ bool [ ] boolDims = new bool [ src . TLayout . NDim ] ;
36+ boolDims . AsSpan ( ) . Fill ( true ) ;
37+ MaxInternal ( src , res , boolDims , keepDims ) ;
38+ return res ;
39+ }
40+ private unsafe static void MaxInternal < T > ( Tensor < T > src , Tensor < T > dst , bool [ ] dims , bool keepDims ) where T : struct , IEquatable < T > , IConvertible {
41+ fixed( bool * ptr = dims ) {
42+ MaxParam p = new MaxParam ( ) { dims = new IntPtr ( ptr ) } ;
43+ IntPtr status = NativeExecutor . Execute ( NativeApi . Max , src . TMemory , dst . TMemory , src . TLayout , dst . TLayout , new IntPtr ( & p ) , Tensor < T > . Provider ) ;
44+ NativeStatus . AssertOK ( status ) ;
45+ }
46+ if ( ! keepDims ) {
47+ dst . TLayout . RemoveAllDanglingAxisInplace ( ) ;
48+ }
49+ }
50+ private static TensorLayout DeduceLayout ( TensorLayout src , int [ ] axes ) {
51+ var res = new TensorLayout ( src , true ) ;
52+ foreach ( var dim in axes ) {
53+ res . Shape [ dim ] = 1 ;
54+ }
55+ return res ;
56+ }
57+ private static TensorLayout DeduceLayout ( TensorLayout src , int axis ) {
58+ var res = new TensorLayout ( src , true ) ;
59+ res . Shape [ axis ] = 1 ;
60+ return res ;
61+ }
62+ private static TensorLayout DeduceLayout ( TensorLayout src ) {
63+ var res = new TensorLayout ( src , true ) ;
64+ res . Shape . AsSpan ( ) . Fill ( 1 ) ;
65+ return res ;
66+ }
67+ }
68+
69+ public static partial class Tensor {
70+ /// <Summary>
71+ /// Get the maximum elements of the tensor.
72+ /// </Summary>
73+ /// <typeparam name="T"></typeparam>
74+ /// <param name="src"> The tensor to get maximum elements. </param>
75+ /// <param name="axes"> The axes to execute. </param>
76+ /// <param name="keepDims"> Whether to keep the dims after the Max. If false, the NDim of the result may be different with the input. </param>
77+ /// <returns></returns>
78+ public static Tensor < T > Max < T > ( Tensor < T > src , int [ ] axes , bool keepDims = false ) where T : struct , IEquatable < T > , IConvertible {
79+ return src . Max ( axes , keepDims ) ;
80+ }
81+ /// <Summary>
82+ /// Get the maximum elements of the tensor.
83+ /// </Summary>
84+ /// <typeparam name="T"></typeparam>
85+ /// <param name="src"> The tensor to get maximum elements. </param>
86+ /// <param name="axis"> The axis to execute. </param>
87+ /// <param name="keepDims"> Whether to keep the dims after the Max. If false, the NDim of the result may be different with the input. </param>
88+ /// <returns></returns>
89+ public static Tensor < T > Max < T > ( Tensor < T > src , int axis , bool keepDims = false ) where T : struct , IEquatable < T > , IConvertible {
90+ return src . Max ( axis , keepDims ) ;
91+ }
92+ /// <Summary>
93+ /// Get the maximum elements of the tensor.
94+ /// </Summary>
95+ /// <typeparam name="T"></typeparam>
96+ /// <param name="src"> The tensor to get maximum elements. </param>
97+ /// <param name="keepDims"> Whether to keep the dims after the Max. If false, the NDim of the result may be different with the input. </param>
98+ /// <returns></returns>
99+ public static Tensor < T > Max < T > ( Tensor < T > src , bool keepDims = false ) where T : struct , IEquatable < T > , IConvertible {
100+ return src . Max ( keepDims ) ;
101+ }
102+ }
103+ }
0 commit comments