Skip to content

Commit 2fc45b9

Browse files
committed
1 parent 8b32de7 commit 2fc45b9

File tree

17 files changed

+149
-22
lines changed

17 files changed

+149
-22
lines changed

src/TensorFlowNET.Core/APIs/c_api.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace Tensorflow
2727
/// string => IntPtr c_api.StringPiece(IntPtr)
2828
/// unsigned char => byte
2929
/// </summary>
30-
public static partial class c_api
30+
public partial class c_api
3131
{
3232
public const string TensorFlowLibName = "tensorflow";
3333

src/TensorFlowNET.Core/Attributes/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Tensorflow
77
{
8-
public static partial class c_api
8+
public partial class c_api
99
{
1010
/// <summary>
1111
/// Fills in `value` with the value of the attribute `attr_name`. `value` must

src/TensorFlowNET.Core/Buffers/c_api.buffer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Tensorflow
77
{
8-
public static partial class c_api
8+
public partial class c_api
99
{
1010
[DllImport(TensorFlowLibName)]
1111
public static extern void TF_DeleteBuffer(IntPtr buffer);

src/TensorFlowNET.Core/Functions/c_api.function.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Tensorflow
77
{
8-
public static partial class c_api
8+
public partial class c_api
99
{
1010
/// <summary>
1111
/// Write out a serialized representation of `func` (as a FunctionDef protocol

src/TensorFlowNET.Core/Gradients/c_api.gradient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Tensorflow
77
{
8-
public static partial class c_api
8+
public partial class c_api
99
{
1010
/// <summary>
1111
/// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,

src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,76 @@ public static void _GradientsHelper(object ys,
5353
using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all))
5454
{
5555
grad_scope = namescope;
56+
// Get a uid for this call to gradients that can be used to help
57+
// cluster ops for compilation.
58+
var gradient_uid = ops.get_default_graph().unique_name("uid");
5659

60+
var to_ops = ys1.Select(x => x.op).ToList();
61+
var from_ops = xs1.Select(x => x.op).ToList();
62+
var stop_gradient_ops = stop_gradients1.Select(x => x.op).ToList();
63+
_PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs1);
5764
}
5865
}
5966

67+
/// <summary>
68+
///
69+
/// </summary>
70+
/// <param name="grad_ys"></param>
71+
/// <param name="ys"></param>
72+
/// <param name="colocate_gradients_with_ops"></param>
73+
/// <param name="gradient_uid"></param>
74+
private void _DefaultGradYs(List<Tensor> grad_ys, List<Tensor> ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__")
75+
{
76+
77+
}
78+
79+
/// <summary>
80+
/// Initialize the pending count for ops between two lists of Operations.
81+
/// 'pending_count[op]' indicates the number of backprop inputs
82+
/// to this operation.
83+
/// </summary>
84+
/// <param name="to_ops"></param>
85+
/// <param name="from_ops"></param>
86+
/// <param name="colocate_gradients_with_ops"></param>
87+
/// <param name="func_graphs"></param>
88+
/// <param name="xs"></param>
89+
private static void _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, List<Tensor> xs)
90+
{
91+
List<Operation> reached_ops = new List<Operation>();
92+
_MarkReachedOps(from_ops, reached_ops, func_graphs);
93+
}
94+
95+
/// <summary>
96+
/// Mark all ops reached from "from_ops"
97+
/// </summary>
98+
/// <param name="from_ops"></param>
99+
/// <param name="reached_ops"></param>
100+
/// <param name="func_graphs"></param>
101+
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs)
102+
{
103+
foreach(var op in from_ops)
104+
{
105+
reached_ops.Add(op);
106+
foreach(var output in op.outputs)
107+
{
108+
reached_ops.AddRange(_Consumers(output, func_graphs));
109+
}
110+
}
111+
112+
reached_ops.Reverse();
113+
}
114+
115+
/// <summary>
116+
/// Returns the consumers of t, crossing closure boundaries where necessary.
117+
/// </summary>
118+
/// <param name="t"></param>
119+
/// <param name="func_graphs"></param>
120+
private static List<Operation> _Consumers(Tensor t, List<object> func_graphs)
121+
{
122+
var consumers = t.consumers();
123+
return consumers;
124+
}
125+
60126
private static List<Tensor> _AsList(object ys)
61127
{
62128
List<Tensor> ret = null;

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Tensorflow
77
{
8-
public static partial class c_api
8+
public partial class c_api
99
{
1010
/// <summary>
1111
/// Destroy an options object. Graph will be deleted once no more

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ public class Operation
3232
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
3333
{
3434
int size = Marshal.SizeOf<TF_Input>();
35-
var handle = (TF_Input*)Marshal.AllocHGlobal(size);
35+
var handle = Marshal.AllocHGlobal(size);
3636
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
3737
var consumers = new TF_Input[num];
3838
for(int i = 0; i < num; i++)
3939
{
40-
consumers[i] = new TF_Input((*handle).oper + i * size, (*handle).index);
40+
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
4141
}
4242

4343
return consumers;
@@ -161,6 +161,11 @@ public NodeDef GetNodeDef()
161161
}
162162
}
163163

164+
public override string ToString()
165+
{
166+
return $"'{Name}' type={OpType}";
167+
}
168+
164169
public static implicit operator Operation(IntPtr handle) => new Operation(handle);
165170
public static implicit operator IntPtr(Operation op) => op._handle;
166171

src/TensorFlowNET.Core/Operations/c_api.ops.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Tensorflow
77
{
8-
public static partial class c_api
8+
public partial class c_api
99
{
1010
/// <summary>
1111
/// Request that `desc` be co-located on the device where `op`
@@ -154,12 +154,15 @@ public static partial class c_api
154154
/// an operation. Returns the number of output consumers (should match
155155
/// TF_OperationOutputNumConsumers(oper_out)).
156156
/// </summary>
157-
/// <param name="oper_out"></param>
158-
/// <param name="consumers"></param>
159-
/// <param name="max_consumers"></param>
157+
/// <param name="oper_out">TF_Output</param>
158+
/// <param name="consumers">TF_Input*</param>
159+
/// <param name="max_consumers">int</param>
160160
/// <returns></returns>
161161
[DllImport(TensorFlowLibName)]
162-
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, int max_consumers);
162+
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers);
163+
164+
[DllImport(TensorFlowLibName)]
165+
public static extern int TF_OperationOutputConsumers(TF_Output oper_out);
163166

164167
[DllImport(TensorFlowLibName)]
165168
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,20 @@ public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
2626

2727
return new Tensor(_op, 0, dtype);
2828
}
29+
30+
/// <summary>
31+
/// Return a tensor with the same shape and contents as the input tensor or value.
32+
/// </summary>
33+
/// <param name="input"></param>
34+
/// <param name="name"></param>
35+
public static Tensor identity(Tensor input, string name = "")
36+
{
37+
var keywords = new Dictionary<string, object>();
38+
keywords.Add("input", input);
39+
40+
var _op = _op_def_lib._apply_op_helper("Identity", name, keywords);
41+
42+
return _op.outputs[0];
43+
}
2944
}
3045
}

0 commit comments

Comments
 (0)