@@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine
1010{
1111 public partial class Model
1212 {
13+ static Dictionary < string , List < ( string , NDArray ) > > weightsCache
14+ = new Dictionary < string , List < ( string , NDArray ) > > ( ) ;
15+
1316 public void load_weights ( string filepath , bool by_name = false , bool skip_mismatch = false , object options = null )
1417 {
18+ // Get from cache
19+ if ( weightsCache . ContainsKey ( filepath ) )
20+ {
21+ var filtered_layers = new List < ILayer > ( ) ;
22+ foreach ( var layer in Layers )
23+ {
24+ var weights = hdf5_format . _legacy_weights ( layer ) ;
25+ if ( weights . Count > 0 )
26+ filtered_layers . append ( layer ) ;
27+ }
28+
29+ var weight_value_tuples = new List < ( IVariableV1 , NDArray ) > ( ) ;
30+ filtered_layers . Select ( ( layer , i ) =>
31+ {
32+ var symbolic_weights = hdf5_format . _legacy_weights ( layer ) ;
33+ foreach ( var weight in symbolic_weights )
34+ {
35+ var weight_value = weightsCache [ filepath ] . First ( x => x . Item1 == weight . Name ) . Item2 ;
36+ weight_value_tuples . Add ( ( weight , weight_value ) ) ;
37+ }
38+ return layer ;
39+ } ) . ToList ( ) ;
40+
41+ keras . backend . batch_set_value ( weight_value_tuples ) ;
42+ return ;
43+ }
44+
1545 long fileId = Hdf5 . OpenFile ( filepath , true ) ;
1646 if ( fileId < 0 )
1747 {
@@ -29,8 +59,11 @@ public void load_weights(string filepath, bool by_name = false, bool skip_mismat
2959 throw new NotImplementedException ( "" ) ;
3060 else
3161 {
32- hdf5_format . load_weights_from_hdf5_group ( fileId , Layers ) ;
62+ var weight_value_tuples = hdf5_format . load_weights_from_hdf5_group ( fileId , Layers ) ;
3363 Hdf5 . CloseFile ( fileId ) ;
64+
65+ weightsCache [ filepath ] = weight_value_tuples . Select ( x => ( x . Item1 . Name , x . Item2 ) ) . ToList ( ) ;
66+ keras . backend . batch_set_value ( weight_value_tuples ) ;
3467 }
3568 }
3669
0 commit comments