Skip to content

Commit 444cc42

Browse files
committed
node_def property in Operation #134
1 parent 43273b3 commit 444cc42

File tree

6 files changed

+81
-17
lines changed

6 files changed

+81
-17
lines changed

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
3636
var handle = Marshal.AllocHGlobal(size);
3737
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
3838
var consumers = new TF_Input[num];
39-
for(int i = 0; i < num; i++)
39+
for (int i = 0; i < num; i++)
4040
{
4141
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
4242
}
@@ -50,7 +50,7 @@ public unsafe Operation[] GetControlInputs()
5050
{
5151
var control_inputs = new Operation[NumControlInputs];
5252

53-
if(NumControlInputs > 0)
53+
if (NumControlInputs > 0)
5454
{
5555
IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs);
5656
c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
@@ -70,7 +70,7 @@ public unsafe Operation[] GetControlOutputs()
7070
{
7171
var control_outputs = new Operation[NumControlOutputs];
7272

73-
if(NumControlOutputs > 0)
73+
if (NumControlOutputs > 0)
7474
{
7575
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
7676
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
@@ -89,7 +89,7 @@ public Tensor[] outputs
8989
{
9090
get
9191
{
92-
if(_outputs == null)
92+
if (_outputs == null)
9393
{
9494
_outputs = new Tensor[NumOutputs];
9595

@@ -106,7 +106,7 @@ public InputList inputs
106106
{
107107
get
108108
{
109-
if(_inputs == null)
109+
if (_inputs == null)
110110
{
111111
var retval = new Tensor[NumInputs];
112112

@@ -124,6 +124,18 @@ public InputList inputs
124124
}
125125
}
126126

127+
private NodeDef _node_def;
128+
public NodeDef node_def
129+
{
130+
get
131+
{
132+
if(_node_def == null)
133+
_node_def = GetNodeDef();
134+
135+
return _node_def;
136+
}
137+
}
138+
127139
public Operation(IntPtr handle)
128140
{
129141
if (handle == IntPtr.Zero)
@@ -195,7 +207,7 @@ public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
195207
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
196208
}
197209

198-
public NodeDef GetNodeDef()
210+
private NodeDef GetNodeDef()
199211
{
200212
using (var s = new Status())
201213
using (var buffer = new Buffer())

src/TensorFlowNET.Core/Tensors/TF_DataType.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ public enum TF_DataType
3636
TF_UINT32 = 22,
3737
TF_UINT64 = 23,
3838

39+
DtFloatRef = 101, // DT_FLOAT_REF
3940
DtDoubleRef = 102, // DT_DOUBLE_REF
41+
DtInt32Ref = 103, // DT_INT32_REF
4042
}
4143
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ public Tensor(Operation op, int value_index, TF_DataType dtype)
162162
this.op = op;
163163
this.value_index = value_index;
164164
this._dtype = dtype;
165+
_id = ops.uid();
165166
}
166167

167168
public List<Operation> consumers()

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow
@@ -69,14 +70,15 @@ private void _init_from_args(object initial_value,
6970
{
7071

7172
}
73+
// Or get the initial value from a Tensor or Python object.
7274
else
7375
{
7476
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
75-
}
7677

77-
var shape = _initial_value.shape;
78-
dtype = _initial_value.dtype;
79-
_variable = gen_state_ops.variable_v2(shape, dtype, name);
78+
var shape = _initial_value.shape;
79+
dtype = _initial_value.dtype;
80+
_variable = gen_state_ops.variable_v2(shape, dtype, name);
81+
}
8082

8183
// Manually overrides the variable's shape with the initial value's.
8284
if (validate_shape)
@@ -87,8 +89,9 @@ private void _init_from_args(object initial_value,
8789
// If 'initial_value' makes use of other variables, make sure we don't
8890
// have an issue if these other variables aren't initialized first by
8991
// using their initialized_value() method.
92+
var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);
9093

91-
_initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;
94+
_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;
9295

9396
if (!String.IsNullOrEmpty(caching_device))
9497
{
@@ -112,5 +115,51 @@ public Tensor _AsTensor()
112115
{
113116
return _snapshot;
114117
}
118+
119+
/// <summary>
120+
/// Attempt to guard against dependencies on uninitialized variables.
121+
/// </summary>
122+
/// <param name="initial_value"></param>
123+
private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value)
124+
{
125+
return _safe_initial_value_from_tensor(initial_value, new Dictionary<string, Operation>());
126+
}
127+
128+
/// <summary>
129+
/// Replace dependencies on variables with their initialized values.
130+
/// </summary>
131+
/// <param name="tensor">A `Tensor`. The tensor to replace.</param>
132+
/// <param name="op_cache">A dict mapping operation names to `Operation`s.</param>
133+
/// <returns>A `Tensor` compatible with `tensor`.</returns>
134+
private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary<string, Operation> op_cache)
135+
{
136+
var op = tensor.op;
137+
var new_op = op_cache.ContainsKey(op.Name) ? op_cache[op.Name] : null;
138+
if(new_op == null)
139+
{
140+
new_op = _safe_initial_value_from_op(op, op_cache);
141+
op_cache[op.Name] = new_op;
142+
}
143+
return new_op.outputs[tensor.value_index];
144+
}
145+
146+
private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, Operation> op_cache)
147+
{
148+
var op_type = op.node_def.Op;
149+
switch (op_type)
150+
{
151+
case "IsVariableInitialized":
152+
case "VarIsInitializedOp":
153+
case "ReadVariableOp":
154+
return op;
155+
case "Variable":
156+
case "VariableV2":
157+
case "VarHandleOp":
158+
break;
159+
}
160+
161+
// Recursively build initializer expressions for inputs.
162+
return op;
163+
}
115164
}
116165
}

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name =
4242

4343
_execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name);
4444

45-
return new Tensor(_op, 0, dtype);
45+
return _result[0];
4646
}
4747

4848
/// <summary>

test/TensorFlowNET.UnitTest/GraphTest.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ public void Graph()
130130
EXPECT_EQ(TF_Code.TF_OK, s.Code);
131131

132132
// Serialize to NodeDef.
133-
var node_def = neg.GetNodeDef();
133+
var node_def = neg.node_def;
134134

135135
// Validate NodeDef is what we expect.
136136
ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));
@@ -145,13 +145,13 @@ public void Graph()
145145
// Look up some nodes by name.
146146
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
147147
EXPECT_EQ(neg, neg2);
148-
var node_def2 = neg2.GetNodeDef();
148+
var node_def2 = neg2.node_def;
149149
EXPECT_EQ(node_def.ToString(), node_def2.ToString());
150150

151151
Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
152152
EXPECT_EQ(feed, feed2);
153-
node_def = feed.GetNodeDef();
154-
node_def2 = feed2.GetNodeDef();
153+
node_def = feed.node_def;
154+
node_def2 = feed2.node_def;
155155
EXPECT_EQ(node_def.ToString(), node_def2.ToString());
156156

157157
// Test iterating through the nodes of a graph.
@@ -186,7 +186,7 @@ public void Graph()
186186
}
187187
else
188188
{
189-
node_def = oper.GetNodeDef();
189+
node_def = oper.node_def;
190190
Assert.Fail($"Unexpected Node: {node_def.ToString()}");
191191
}
192192
}

0 commit comments

Comments
 (0)