Skip to content

Commit 8b53eb3

Browse files
SanftMonsterOceania2018
authored andcommitted
fix: partially fix the error when saving model after loading.
1 parent d9988d7 commit 8b53eb3

File tree

6 files changed

+66
-12
lines changed

6 files changed

+66
-12
lines changed

src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Diagnostics;
44
using System.IO;
55
using System.Linq;
6+
using Tensorflow.Functions;
67
using Tensorflow.Train;
78
using Tensorflow.Training;
89
using pbc = global::Google.Protobuf.Collections;
@@ -13,7 +14,7 @@ public static class CheckPointUtils
1314
{
1415
private static string _ESCAPE_CHAR = ".";
1516
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
16-
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
17+
IDictionary<Trackable, pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
1718
IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
1819
{
1920
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();

src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,14 @@ private void initialize_nodes_and_concrete_functions()
9393
//
9494
// }
9595

96-
foreach (var obj in _nodes)
97-
{
98-
if (obj is ConcreteFunction)
99-
{
100-
_concrete_functions.Add((ConcreteFunction)obj);
101-
}
102-
}
96+
//_concrete_functions = new();
97+
//foreach (var obj in _nodes)
98+
//{
99+
// if (obj is ConcreteFunction)
100+
// {
101+
// _concrete_functions.Add((ConcreteFunction)obj);
102+
// }
103+
//}
103104
}
104105

105106
public List<ConcreteFunction> get_concrete_resource_initializers()
@@ -225,8 +226,8 @@ private static void write_object_proto(Trackable obj, SavedObject proto,
225226
}
226227
else if (obj is ConcreteFunction)
227228
{
228-
// TODO: complete it.
229-
throw new NotImplementedException();
229+
// TODO(Rinne): complete it.
230+
// throw new NotImplementedException();
230231
}
231232
// skip the process of type `_CapturedTensor` and `CapturableResource`.
232233
else

src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@ public class BaseResourceVariable : DisposableTrackableObject
1717
{
1818
protected string _name;
1919
public virtual string Name => _handle_name;
20-
public virtual string SharedName => _name;
20+
public virtual string SharedName
21+
{
22+
get
23+
{
24+
// TODO(Rinne): optimize the implementation with refactor of variable.
25+
return _handle_name.Substring(0, _handle_name.IndexOf(':') + 1);
26+
}
27+
}
2128
protected TF_DataType _dtype;
2229
public TF_DataType dtype => _dtype;
2330
protected string _handle_name;

src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,39 @@ public void finalize_objects()
152152
_reconstruct_all_models();
153153
}
154154

155+
/// <summary>
156+
/// Removes tracked references that are only used when loading the model.
157+
/// Now that the node object has been fully loaded, and the checkpoint has
158+
/// been restored, the object no longer needs to track objects added from
159+
/// SerializedAttributes. (Note that saving a training checkpoint still
160+
/// functions correctly, because layers and variables are tracked
161+
/// separately by the Layer object.)
162+
/// </summary>
163+
public void del_tracking()
164+
{
165+
foreach(var (node, _) in loaded_nodes.Values)
166+
{
167+
if(node is not Layer layer)
168+
{
169+
continue;
170+
}
171+
foreach(var name in PUBLIC_ATTRIBUTES.Keys)
172+
{
173+
layer._delete_tracking(name);
174+
}
175+
if(node is Functional functional)
176+
{
177+
foreach(var name in functional.UnconditionalDependencyNames.Keys)
178+
{
179+
if(Regex.Match(name, @"^layer(_with_weights)?-[\d+]").Success)
180+
{
181+
functional._delete_tracking(name);
182+
}
183+
}
184+
}
185+
}
186+
}
187+
155188
private void _reconstruct_all_models()
156189
{
157190
HashSet<int> all_initialized_models = new();

src/TensorFlowNET.Keras/Saving/SavedModel/load.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ private static Trackable load(string path, bool compile = true, LoadOptions? opt
7777
var loaded = Loader.load_partial(path, nodes_to_load, options);
7878

7979
keras_loader.finalize_objects();
80-
// keras_loader.del_tracking();
80+
keras_loader.del_tracking();
8181

8282
var model = loaded["root"];
8383

test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,5 +196,17 @@ public void AlexnetFromSequential()
196196
// )
197197
#endregion
198198
}
199+
200+
[TestMethod]
201+
public void SaveAfterLoad()
202+
{
203+
var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
204+
model.summary();
205+
206+
model.save("Assets/saved_auto_compile_after_loading");
207+
208+
//model = tf.keras.models.load_model(@"Assets/saved_auto_compile_after_loading");
209+
//model.summary();
210+
}
199211
}
200212
}

0 commit comments

Comments
 (0)