Skip to content

Commit 1989988

Browse files
committed
fix Operation.inputs value is null. #117
1 parent 3265a38 commit 1989988

File tree

3 files changed

+55
-11
lines changed

3 files changed

+55
-11
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,17 @@ public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataTy
118118
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
119119
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
120120

121+
if (inputs == null)
122+
inputs = new List<Tensor>();
123+
124+
var input_ops = inputs.Select(x => x.op).ToArray();
125+
var control_inputs = _control_dependencies_for_inputs(input_ops);
126+
121127
var op = new Operation(node_def,
122128
this,
123129
inputs: inputs,
124130
output_types: dtypes,
125-
control_inputs: new object[] { },
131+
control_inputs: control_inputs,
126132
input_types: input_types,
127133
original_op: null,
128134
op_def: op_def);
@@ -131,6 +137,16 @@ public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataTy
131137
return op;
132138
}
133139

140+
/// <summary>
141+
/// For an op that takes `input_ops` as inputs, compute control inputs.
142+
/// </summary>
143+
/// <param name="input_ops">The data input ops for an op to be created.</param>
144+
/// <returns>A list of control inputs for the op to be created.</returns>
145+
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
146+
{
147+
return new Operation[0];
148+
}
149+
134150
private void _create_op_helper(Operation op, bool compute_device = true)
135151
{
136152

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,42 @@ public unsafe Operation[] GetControlOutputs()
8585
}
8686

8787
private Tensor[] _outputs;
88-
public Tensor[] outputs => _outputs;
89-
public Tensor[] inputs;
88+
public Tensor[] outputs
89+
{
90+
get
91+
{
92+
if(_outputs == null)
93+
{
94+
_outputs = new Tensor[NumOutputs];
95+
96+
for (int i = 0; i < NumOutputs; i++)
97+
_outputs[i] = new Tensor(this, i, OutputType(i));
98+
}
99+
100+
return _outputs;
101+
}
102+
}
103+
104+
private Tensor[] _inputs;
105+
public Tensor[] inputs
106+
{
107+
get
108+
{
109+
if(_inputs == null)
110+
{
111+
_inputs = new Tensor[NumInputs];
112+
113+
for (int i = 0; i < NumInputs; i++)
114+
{
115+
var tf_outpus = Input(i);
116+
var op = new Operation(tf_outpus.oper);
117+
_inputs[i] = op.outputs[tf_outpus.index];
118+
}
119+
}
120+
121+
return _inputs;
122+
}
123+
}
90124

91125
public Operation(IntPtr handle)
92126
{
@@ -115,14 +149,10 @@ public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataT
115149

116150
_handle = ops._create_c_op(g, node_def, inputs);
117151

118-
_outputs = new Tensor[NumOutputs];
119152
output_types = new TF_DataType[NumOutputs];
120153

121154
for (int i = 0; i < NumOutputs; i++)
122-
{
123155
output_types[i] = OutputType(i);
124-
_outputs[i] = new Tensor(this, i, output_types[i]);
125-
}
126156

127157
Graph._add_op(this);
128158
}
@@ -131,8 +161,6 @@ public object get_attr<T>(string name)
131161
{
132162
AttrValue x = null;
133163

134-
var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" };
135-
136164
using (var buf = new Buffer())
137165
{
138166
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name =
4040
_attrs["container"] = _op.get_attr<string>("container");
4141
_attrs["shared_name"] = _op.get_attr<string>("shared_name");
4242

43-
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
43+
_execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name);
4444

4545
return new Tensor(_op, 0, dtype);
4646
}
@@ -74,7 +74,7 @@ public static Tensor assign(Tensor tensor, Tensor value,
7474
_attrs["validate_shape"] = _op.get_attr<bool>("validate_shape");
7575
_attrs["use_locking"] = _op.get_attr<bool>("use_locking");
7676

77-
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
77+
_execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name);
7878

7979
return _result[0];
8080
}

0 commit comments

Comments
 (0)