@@ -67,7 +67,7 @@ public static NDArray MakeNdarray(TensorProto tensor)
6767
6868 T [ ] ExpandArrayToSize < T > ( IList < T > src )
6969 {
70- if ( src . Count == 0 )
70+ if ( src . Count == 0 )
7171 {
7272 return new T [ 0 ] ;
7373 }
@@ -77,7 +77,7 @@ T[] ExpandArrayToSize<T>(IList<T> src)
7777 var first_elem = src [ 0 ] ;
7878 var last_elem = src [ src . Count - 1 ] ;
7979 T [ ] res = new T [ num_elements ] ;
80- for ( long i = 0 ; i < num_elements ; i ++ )
80+ for ( long i = 0 ; i < num_elements ; i ++ )
8181 {
8282 if ( i < pre ) res [ i ] = first_elem ;
8383 else if ( i >= num_elements - after ) res [ i ] = last_elem ;
@@ -121,7 +121,7 @@ T[] ExpandArrayToSize<T>(IList<T> src)
121121 $ "https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.") ;
122122 }
123123
124- if ( values . size == 0 )
124+ if ( values . size == 0 )
125125 {
126126 return np . zeros ( shape , tensor_dtype ) ;
127127 }
@@ -135,23 +135,47 @@ T[] ExpandArrayToSize<T>(IList<T> src)
135135 TF_DataType . TF_QINT32
136136 } ;
137137
138- private static TOut [ , ] ConvertArray2D < TIn , TOut > ( TIn [ , ] inputArray , Func < TIn , TOut > converter )
138+ private static Array ConvertArray < TOut > ( Array inputArray , Func < object , TOut > converter )
139139 {
140- var rows = inputArray . GetLength ( 0 ) ;
141- var cols = inputArray . GetLength ( 1 ) ;
142- var outputArray = new TOut [ rows , cols ] ;
140+ if ( inputArray == null )
141+ throw new ArgumentNullException ( nameof ( inputArray ) ) ;
143142
144- for ( var i = 0 ; i < rows ; i ++ )
143+ var elementType = typeof ( TOut ) ;
144+ var lengths = new int [ inputArray . Rank ] ;
145+ for ( var i = 0 ; i < inputArray . Rank ; i ++ )
145146 {
146- for ( var j = 0 ; j < cols ; j ++ )
147- {
148- outputArray [ i , j ] = converter ( inputArray [ i , j ] ) ;
149- }
147+ lengths [ i ] = inputArray . GetLength ( i ) ;
150148 }
151149
150+ var outputArray = Array . CreateInstance ( elementType , lengths ) ;
151+
152+ FillArray ( inputArray , outputArray , converter , new int [ inputArray . Rank ] , 0 ) ;
153+
152154 return outputArray ;
153155 }
154156
157+ private static void FillArray < TIn , TOut > ( Array inputArray , Array outputArray , Func < TIn , TOut > converter , int [ ] indices , int dimension )
158+ {
159+ if ( dimension == inputArray . Rank - 1 )
160+ {
161+ for ( int i = 0 ; i < inputArray . GetLength ( dimension ) ; i ++ )
162+ {
163+ indices [ dimension ] = i ;
164+ var inputValue = ( TIn ) inputArray . GetValue ( indices ) ;
165+ var convertedValue = converter ( inputValue ) ;
166+ outputArray . SetValue ( convertedValue , indices ) ;
167+ }
168+ }
169+ else
170+ {
171+ for ( int i = 0 ; i < inputArray . GetLength ( dimension ) ; i ++ )
172+ {
173+ indices [ dimension ] = i ;
174+ FillArray ( inputArray , outputArray , converter , indices , dimension + 1 ) ;
175+ }
176+ }
177+ }
178+
155179 /// <summary>
156180 /// Create a TensorProto, invoked in graph mode
157181 /// </summary>
@@ -171,24 +195,30 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
171195 var origin_dtype = values . GetDataType ( ) ;
172196 if ( dtype == TF_DataType . DtInvalid )
173197 dtype = origin_dtype ;
174- else if ( origin_dtype != dtype )
198+ else if ( origin_dtype != dtype )
175199 {
176200 var new_system_dtype = dtype . as_system_dtype ( ) ;
177-
178- values = values switch
201+
202+ if ( dtype != TF_DataType . TF_STRING && dtype != TF_DataType . TF_VARIANT && dtype != TF_DataType . TF_RESOURCE )
203+ {
204+ if ( values is Array arrayValues )
205+ {
206+ values = dtype switch
207+ {
208+ TF_DataType . TF_INT32 => ConvertArray ( arrayValues , Convert . ToInt32 ) ,
209+ TF_DataType . TF_FLOAT => ConvertArray ( arrayValues , Convert . ToSingle ) ,
210+ TF_DataType . TF_DOUBLE => ConvertArray ( arrayValues , Convert . ToDouble ) ,
211+ _ => values ,
212+ } ;
213+ } else
214+ {
215+ values = Convert . ChangeType ( values , new_system_dtype ) ;
216+ }
217+
218+ } else
179219 {
180- long [ ] longValues when dtype == TF_DataType . TF_INT32 => longValues . Select ( x => ( int ) x ) . ToArray ( ) ,
181- long [ ] longValues => values ,
182- float [ ] floatValues when dtype == TF_DataType . TF_DOUBLE => floatValues . Select ( x => ( double ) x ) . ToArray ( ) ,
183- float [ ] floatValues => values ,
184- float [ , ] float2DValues when dtype == TF_DataType . TF_DOUBLE => ConvertArray2D ( float2DValues , Convert . ToDouble ) ,
185- float [ , ] float2DValues => values ,
186- double [ ] doubleValues when dtype == TF_DataType . TF_FLOAT => doubleValues . Select ( x => ( float ) x ) . ToArray ( ) ,
187- double [ ] doubleValues => values ,
188- double [ , ] double2DValues when dtype == TF_DataType . TF_FLOAT => ConvertArray2D ( double2DValues , Convert . ToSingle ) ,
189- double [ , ] double2DValues => values ,
190- _ => Convert . ChangeType ( values , new_system_dtype ) ,
191- } ;
220+
221+ }
192222 dtype = values . GetDataType ( ) ;
193223 }
194224
@@ -306,7 +336,7 @@ bool hasattr(Graph property, string attr)
306336
307337 if ( tensor is EagerTensor eagerTensor )
308338 {
309- if ( tensor . dtype == tf . int64 )
339+ if ( tensor . dtype == tf . int64 )
310340 return new Shape ( tensor . ToArray < long > ( ) ) ;
311341 else
312342 return new Shape ( tensor . ToArray < int > ( ) ) ;
@@ -481,7 +511,7 @@ bool hasattr(Graph property, string attr)
481511 var d_ = new int [ value . size ] ;
482512 foreach ( var ( index , d ) in enumerate ( value . ToArray < int > ( ) ) )
483513 d_ [ index ] = d >= 0 ? d : - 1 ;
484-
514+
485515 ret = ret . merge_with ( new Shape ( d_ ) ) ;
486516 }
487517 return ret ;
0 commit comments