Skip to content

Commit c45361c

Browse files
committed
added InputList
1 parent 1989988 commit c45361c

File tree

4 files changed

+27
-9
lines changed

4 files changed

+27
-9
lines changed

src/TensorFlowNET.Core/Eager/Execute.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@ namespace Tensorflow.Eager
66
{
77
public class Execute
88
{
9-
public void record_gradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
9+
public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
1010
{
11-
if (inputs == null)
12-
inputs = new Tensor[0];
13-
14-
pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name);
11+
pywrap_tfe_src.RecordGradient(op_name, inputs._inputs, attrs, results, name);
1512
}
1613
}
1714
}

src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<st
2525
}
2626
}
2727
if (!should_record) return;
28+
29+
var op_outputs = results;
30+
var op_inputs = inputs;
2831
}
2932
}
3033
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class InputList
8+
{
9+
public Tensor[] _inputs;
10+
11+
public InputList(Tensor[] inputs)
12+
{
13+
_inputs = inputs;
14+
}
15+
}
16+
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,21 +101,23 @@ public Tensor[] outputs
101101
}
102102
}
103103

104-
private Tensor[] _inputs;
105-
public Tensor[] inputs
104+
private InputList _inputs;
105+
public InputList inputs
106106
{
107107
get
108108
{
109109
if(_inputs == null)
110110
{
111-
_inputs = new Tensor[NumInputs];
111+
var retval = new Tensor[NumInputs];
112112

113113
for (int i = 0; i < NumInputs; i++)
114114
{
115115
var tf_outpus = Input(i);
116116
var op = new Operation(tf_outpus.oper);
117-
_inputs[i] = op.outputs[tf_outpus.index];
117+
retval[i] = op.outputs[tf_outpus.index];
118118
}
119+
120+
_inputs = new InputList(retval);
119121
}
120122

121123
return _inputs;

0 commit comments

Comments
 (0)