Skip to content

Commit 3265a38

Browse files
committed
add data type when determine which property should be taken out. #115
1 parent 1876cc9 commit 3265a38

File tree

6 files changed

+80
-9
lines changed

6 files changed

+80
-9
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Eager
6+
{
7+
public class Tape
8+
{
9+
public static bool IsDtypeTrainable(DataType dtype)
10+
{
11+
switch (dtype)
12+
{
13+
case DataType.DtHalf:
14+
case DataType.DtBfloat16:
15+
case DataType.DtFloat:
16+
case DataType.DtDouble:
17+
case DataType.DtComplex64:
18+
case DataType.DtComplex128:
19+
case DataType.DtResource:
20+
case DataType.DtVariant:
21+
return true;
22+
default:
23+
return false;
24+
}
25+
}
26+
}
27+
}

src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,19 @@ public class pywrap_tfe_src
1212
{
1313
public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
1414
{
15-
15+
var input_ids = inputs.Select(x => x.Id).ToArray();
16+
var input_dtypes = inputs.Select(x => x.dtype).ToArray();
17+
18+
bool should_record = false;
19+
foreach (var input_dtype in input_dtypes)
20+
{
21+
if (Tape.IsDtypeTrainable(input_dtype.as_datatype_enum()))
22+
{
23+
should_record = true;
24+
break;
25+
}
26+
}
27+
if (!should_record) return;
1628
}
1729
}
1830
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Runtime.InteropServices;
45
using System.Text;
56

@@ -126,11 +127,11 @@ public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataT
126127
Graph._add_op(this);
127128
}
128129

129-
public object get_attr(string name)
130+
public object get_attr<T>(string name)
130131
{
131132
AttrValue x = null;
132133

133-
var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" };
134+
var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" };
134135

135136
using (var buf = new Buffer())
136137
{
@@ -141,12 +142,21 @@ public object get_attr(string name)
141142

142143
switch (name)
143144
{
145+
case "T":
144146
case "dtype":
145147
return x.Type;
146148
case "shape":
147149
return x.Shape;
148150
default:
149-
throw new NotImplementedException($"{name}");
151+
switch (typeof(T).Name)
152+
{
153+
case "Boolean":
154+
return x.B;
155+
case "String":
156+
return x.S;
157+
default:
158+
throw new NotImplementedException($"Unsupported field type in {x.ToString()}");
159+
}
150160
}
151161
}
152162

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, st
2121
var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords);
2222
var _result = _op.outputs;
2323
var _inputs_flat = _op.inputs;
24-
var _attrs = new Dictionary<string, object>();
2524

26-
_attrs["dtype"] = _op.get_attr("dtype");
27-
_attrs["shape"] = _op.get_attr("shape");
25+
var _attrs = new Dictionary<string, object>();
26+
_attrs["dtype"] = _op.get_attr<DataType>("dtype");
27+
_attrs["shape"] = _op.get_attr<int[]>("shape");
2828

2929
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
30+
3031
return new Tensor(_op, 0, dtype);
3132
}
3233

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ public partial class Tensor : IDisposable
1616
{
1717
private readonly IntPtr _handle;
1818

19+
private int _id;
20+
public int Id => _id;
21+
1922
public Graph Graph => op.Graph;
2023
public Operation op { get; }
2124

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
using System;
33
using System.Collections.Generic;
44
using System.Text;
5+
using Tensorflow.Eager;
56

67
namespace Tensorflow
78
{
89
public class gen_state_ops
910
{
1011
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
12+
public static Execute _execute = new Execute();
1113

1214
/// <summary>
1315
/// Holds state in the form of a tensor that persists across steps.
@@ -32,6 +34,14 @@ public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name =
3234
var _result = _op.outputs;
3335
var _inputs_flat = _op.inputs;
3436

37+
var _attrs = new Dictionary<string, object>();
38+
_attrs["dtype"] = _op.get_attr<DataType>("dtype");
39+
_attrs["shape"] = _op.get_attr<int[]>("shape");
40+
_attrs["container"] = _op.get_attr<string>("container");
41+
_attrs["shared_name"] = _op.get_attr<string>("shared_name");
42+
43+
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
44+
3545
return new Tensor(_op, 0, dtype);
3646
}
3747

@@ -56,9 +66,17 @@ public static Tensor assign(Tensor tensor, Tensor value,
5666

5767
var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords);
5868

59-
var _result = _op.outputs[0];
69+
var _result = _op.outputs;
6070
var _inputs_flat = _op.inputs;
61-
return _result;
71+
72+
var _attrs = new Dictionary<string, object>();
73+
_attrs["T"] = _op.get_attr<DataType>("T");
74+
_attrs["validate_shape"] = _op.get_attr<bool>("validate_shape");
75+
_attrs["use_locking"] = _op.get_attr<bool>("use_locking");
76+
77+
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
78+
79+
return _result[0];
6280
}
6381
}
6482
}

0 commit comments

Comments
 (0)