1- /*****************************************************************************
1+ /*****************************************************************************
22 Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
33
44 Licensed under the Apache License, Version 2.0 (the "License");
@@ -135,6 +135,23 @@ T[] ExpandArrayToSize<T>(IList<T> src)
135135 TF_DataType . TF_QINT32
136136 } ;
137137
138+ private static TOut [ , ] ConvertArray2D < TIn , TOut > ( TIn [ , ] inputArray , Func < TIn , TOut > converter )
139+ {
140+ var rows = inputArray . GetLength ( 0 ) ;
141+ var cols = inputArray . GetLength ( 1 ) ;
142+ var outputArray = new TOut [ rows , cols ] ;
143+
144+ for ( var i = 0 ; i < rows ; i ++ )
145+ {
146+ for ( var j = 0 ; j < cols ; j ++ )
147+ {
148+ outputArray [ i , j ] = converter ( inputArray [ i , j ] ) ;
149+ }
150+ }
151+
152+ return outputArray ;
153+ }
154+
138155 /// <summary>
139156 /// Create a TensorProto, invoked in graph mode
140157 /// </summary>
@@ -157,19 +174,16 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
157174 else if ( origin_dtype != dtype )
158175 {
159176 var new_system_dtype = dtype . as_system_dtype ( ) ;
160- if ( values is long [ ] long_values )
161- {
162- if ( dtype == TF_DataType . TF_INT32 )
163- values = long_values . Select ( x => ( int ) Convert . ChangeType ( x , new_system_dtype ) ) . ToArray ( ) ;
164- }
165- else if ( values is double [ ] double_values )
177+
178+ values = values switch
166179 {
167- if ( dtype == TF_DataType . TF_FLOAT )
168- values = double_values . Select ( x => ( float ) Convert . ChangeType ( x , new_system_dtype ) ) . ToArray ( ) ;
169- }
170- else
171- values = Convert . ChangeType ( values , new_system_dtype ) ;
172-
180+ long [ ] longValues when dtype == TF_DataType . TF_INT32 => longValues . Select ( x => ( int ) x ) . ToArray ( ) ,
181+ float [ ] floatValues when dtype == TF_DataType . TF_DOUBLE => floatValues . Select ( x => ( double ) x ) . ToArray ( ) ,
182+ float [ , ] float2DValues when dtype == TF_DataType . TF_DOUBLE => ConvertArray2D ( float2DValues , Convert . ToDouble ) ,
183+ double [ ] doubleValues when dtype == TF_DataType . TF_FLOAT => doubleValues . Select ( x => ( float ) x ) . ToArray ( ) ,
184+ double [ , ] double2DValues when dtype == TF_DataType . TF_DOUBLE => ConvertArray2D ( double2DValues , Convert . ToSingle ) ,
185+ _ => Convert . ChangeType ( values , new_system_dtype ) ,
186+ } ;
173187 dtype = values . GetDataType ( ) ;
174188 }
175189
0 commit comments