88using static Tensorflow . Binding ;
99using static Tensorflow . KerasApi ;
1010using System . Linq ;
11-
11+ using Tensorflow . Util ;
1212namespace Tensorflow . Keras . Saving
1313{
14- public class fdf5_format
14+ public class hdf5_format
1515 {
16-
16+ private static int HDF5_OBJECT_HEADER_LIMIT = 64512 ;
1717 public static void load_model_from_hdf5 ( string filepath = "" , Dictionary < string , object > custom_objects = null , bool compile = false )
1818 {
1919 long root = Hdf5 . OpenFile ( filepath , true ) ;
@@ -79,10 +79,7 @@ public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Di
7979 {
8080
8181 }
82- public static void save_weights_to_hdf5_group ( long filepath = - 1 , Dictionary < string , object > custom_objects = null , bool compile = false )
83- {
8482
85- }
8683 public static void load_weights_from_hdf5_group ( long f , List < ILayer > layers )
8784 {
8885 string original_keras_version = "2.4.0" ;
@@ -136,9 +133,14 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
136133 var weight_values = new List < NDArray > ( ) ;
137134 long g = H5G . open ( f , name ) ;
138135 var weight_names = load_attributes_from_hdf5_group ( g , "weight_names" ) ;
136+ var get_Name = "" ;
139137 foreach ( var i_ in weight_names )
140138 {
141- ( bool success , Array result ) = Hdf5 . ReadDataset < float > ( g , i_ ) ;
139+ get_Name = i_ ;
140+ if ( get_Name . IndexOf ( "/" ) > 1 ) {
141+ get_Name = get_Name . Split ( '/' ) [ 1 ] ;
142+ }
143+ ( bool success , Array result ) = Hdf5 . ReadDataset < float > ( g , get_Name ) ;
142144 if ( success )
143145 weight_values . Add ( np . array ( result ) ) ;
144146 }
@@ -165,9 +167,171 @@ public static void load_weights_from_hdf5_group_by_name(long filepath = -1, Dict
165167 {
166168
167169 }
168- public static void save_attributes_to_hdf5_group ( long filepath = - 1 , Dictionary < string , object > custom_objects = null , bool compile = false )
        /// <summary>
        /// Writes the weights of every layer in <paramref name="layers"/> into the open
        /// HDF5 file/group <paramref name="f"/> following the Keras weight-file layout:
        /// a "layer_names" attribute plus "backend"/"keras_version" metadata at the root,
        /// then one sub-group per layer holding a "weight_names" attribute and one
        /// dataset per weight variable.
        /// </summary>
        /// <param name="f">Handle of an already-open HDF5 file or group.</param>
        /// <param name="layers">Layers whose trainable/non-trainable weights are saved.</param>
        public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
        {
            List<string> layerName = new List<string>();
            foreach (var layer in layers)
            {
                layerName.Add(layer.Name);
            }
            save_attributes_to_hdf5_group(f, "layer_names", layerName.ToArray());
            Hdf5.WriteAttribute(f, "backend", "tensorflow");
            Hdf5.WriteAttribute(f, "keras_version", "2.5.0");

            long g = 0, crDataGroup = 0;
            List<IVariableV1> weights = new List<IVariableV1>();
            //List<IVariableV1> weight_values = new List<IVariableV1>();
            List<string> weight_names = new List<string>();
            foreach (var layer in layers)
            {
                weight_names = new List<string>();
                // One HDF5 sub-group per layer, name sanitized for HDF5.
                g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
                weights = _legacy_weights(layer);
                //weight_values= keras.backend.batch_get_value(weights);
                foreach (var weight in weights)
                {
                    weight_names.Add(weight.Name);
                }
                save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray());
                Tensor tensor = null;
                string get_Name = "";
                foreach (var (name, val) in zip(weight_names, weights))
                {
                    get_Name = name;
                    tensor = val.AsTensor();
                    // Variable names look like "layer/kernel:0"; when a "/" appears past
                    // position 1 the dataset is stored under the second path segment, and
                    // an (empty) sub-group of that name is created then closed so the
                    // on-disk structure matches what the load path expects.
                    // NOTE(review): only segment [1] is kept — a deeper path "a/b/c"
                    // would lose its tail; confirm against load_weights_from_hdf5_group.
                    if (get_Name.IndexOf("/") > 1)
                    {
                        get_Name = name.Split('/')[1];
                        crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(get_Name));
                        Hdf5.CloseGroup(crDataGroup);
                    }
                    WriteDataset(g, get_Name, tensor);
                    tensor = null;
                }
                Hdf5.CloseGroup(g);
                weight_names = null;
            }
            weights = null;
            // weight_values = null;
        }
217+ private static void save_attributes_to_hdf5_group ( long f , string name , Array data )
218+ {
219+ int num_chunks = 1 ;
220+
221+ var chunked_data = Split ( data , num_chunks ) ;
222+ int getSize = 0 ;
223+
224+ string getType = data . Length > 0 ? data . GetValue ( 0 ) . GetType ( ) . Name . ToLower ( ) : "string" ;
225+
226+ switch ( getType )
227+ {
228+ case "single" :
229+ getSize = sizeof ( float ) ;
230+ break ;
231+ case "double" :
232+ getSize = sizeof ( double ) ;
233+ break ;
234+ case "string" :
235+ getSize = - 1 ;
236+ break ;
237+ case "int32" :
238+ getSize = sizeof ( int ) ;
239+ break ;
240+ case "int64" :
241+ getSize = sizeof ( long ) ;
242+ break ;
243+ default :
244+ getSize = - 1 ;
245+ break ;
246+ }
247+ int getCount = chunked_data . Count ;
248+
249+ if ( getSize != - 1 ) {
250+ num_chunks = ( int ) Math . Ceiling ( ( double ) ( getCount * getSize ) / ( double ) HDF5_OBJECT_HEADER_LIMIT ) ;
251+ if ( num_chunks > 1 ) chunked_data = Split ( data , num_chunks ) ;
252+ }
253+
254+ if ( num_chunks > 1 )
255+ {
256+ foreach ( var ( chunk_id , chunk_data ) in enumerate ( chunked_data ) )
257+ {
258+
259+ WriteAttrs ( f , getType , $ "{ name } { chunk_id } ", chunk_data . ToArray ( ) ) ;
260+
261+ }
262+
263+ }
264+ else {
265+
266+ WriteAttrs ( f , getType , name , data ) ;
267+
268+ }
269+
270+ }
271+ private static void WriteDataset ( long f , string name , Tensor data )
272+ {
273+ switch ( data . dtype )
274+ {
275+ case TF_DataType . TF_FLOAT :
276+ Hdf5 . WriteDatasetFromArray < float > ( f , name , data . numpy ( ) . ToMuliDimArray < float > ( ) ) ;
277+ break ;
278+ case TF_DataType . TF_DOUBLE :
279+ Hdf5 . WriteDatasetFromArray < double > ( f , name , data . numpy ( ) . ToMuliDimArray < float > ( ) ) ;
280+ break ;
281+ case TF_DataType . TF_INT32 :
282+ Hdf5 . WriteDatasetFromArray < int > ( f , name , data . numpy ( ) . ToMuliDimArray < float > ( ) ) ;
283+ break ;
284+ case TF_DataType . TF_INT64 :
285+ Hdf5 . WriteDatasetFromArray < long > ( f , name , data . numpy ( ) . ToMuliDimArray < float > ( ) ) ;
286+ break ;
287+ default :
288+ Hdf5 . WriteDatasetFromArray < float > ( f , name , data . numpy ( ) . ToMuliDimArray < float > ( ) ) ;
289+ break ;
290+ }
291+ }
292+ private static void WriteAttrs ( long f , string typename , string name , Array data )
169293 {
294+ switch ( typename )
295+ {
296+ case "single" :
297+ Hdf5 . WriteAttributes < float > ( f , name , data ) ;
298+ break ;
299+ case "double" :
300+ Hdf5 . WriteAttributes < double > ( f , name , data ) ;
301+ break ;
302+ case "string" :
303+ Hdf5 . WriteAttributes < string > ( f , name , data ) ;
304+ break ;
305+ case "int32" :
306+ Hdf5 . WriteAttributes < int > ( f , name , data ) ;
307+ break ;
308+ case "int64" :
309+ Hdf5 . WriteAttributes < long > ( f , name , data ) ;
310+ break ;
311+ default :
312+ Hdf5 . WriteAttributes < string > ( f , name , data ) ;
313+ break ;
314+ }
315+ }
316+ private static List < List < object > > Split ( Array list , int chunkSize )
317+ {
318+ var splitList = new List < List < object > > ( ) ;
319+ var chunkCount = ( int ) Math . Ceiling ( ( double ) list . Length / ( double ) chunkSize ) ;
320+
321+ for ( int c = 0 ; c < chunkCount ; c ++ )
322+ {
323+ var skip = c * chunkSize ;
324+ var take = skip + chunkSize ;
325+ var chunk = new List < object > ( chunkSize ) ;
326+
327+ for ( int e = skip ; e < take && e < list . Length ; e ++ )
328+ {
329+ chunk . Add ( list . GetValue ( e ) ) ;
330+ }
331+ splitList . Add ( chunk ) ;
332+ }
170333
334+ return splitList ;
171335 }
172336 public static string [ ] load_attributes_from_hdf5_group ( long group , string name )
173337 {
0 commit comments