Skip to content

Commit 8b32de7

Browse files
committed
Fixed #111
1 parent 918f757 commit 8b32de7

File tree

12 files changed

+166
-93
lines changed

12 files changed

+166
-93
lines changed

src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ public static void _GradientsHelper(object ys,
3030
if (src_graph == null)
3131
src_graph = ops.get_default_graph();
3232

33+
// If src_graph is a _FuncGraph (i.e. a function body), gather it and all
34+
// ancestor graphs. This is necessary for correctly handling captured values.
35+
var curr_graph = src_graph;
36+
3337
var ys1 = _AsList(ys);
3438
var xs1 = _AsList(xs);
3539
List<Tensor> grad_ys1 = null;
@@ -47,7 +51,10 @@ public static void _GradientsHelper(object ys,
4751

4852
string grad_scope = "";
4953
using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all))
54+
{
5055
grad_scope = namescope;
56+
57+
}
5158
}
5259

5360
private static List<Tensor> _AsList(object ys)

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,6 @@ public string name_scope(string name)
173173

174174
string new_stack = "";
175175

176-
177176
if (name.EndsWith("/"))
178177
new_stack = ops._name_from_scope_name(name);
179178
else

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 84 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@ public Operation _apply_op_helper(string op_type_name, string name = "", Diction
1515
var g = ops.get_default_graph();
1616
var op_def = g.GetOpDef(op_type_name);
1717

18+
// Default name if not specified.
1819
if (String.IsNullOrEmpty(name))
19-
{
2020
name = op_type_name;
21-
}
2221

23-
string scope = "";
24-
using (var namescope = new ops.name_scope<object>(name))
25-
scope = namescope;
22+
// Check for deprecation
23+
if(op_def.Deprecation != null && op_def.Deprecation.Version > 0)
24+
{
25+
26+
}
2627

2728
var default_type_attr_map = new Dictionary<string, object>();
2829
foreach (var attr_def in op_def.Attr)
@@ -39,101 +40,107 @@ public Operation _apply_op_helper(string op_type_name, string name = "", Diction
3940
var inputs = new List<Tensor>();
4041
var input_types = new List<TF_DataType>();
4142

42-
// Perform input type inference
43-
foreach (var input_arg in op_def.InputArg)
43+
string scope = "";
44+
using (var namescope = new ops.name_scope<object>(name))
4445
{
45-
var input_name = input_arg.Name;
46-
if (keywords[input_name] is double int_value)
47-
{
48-
keywords[input_name] = constant_op.Constant(int_value, input_name);
49-
}
46+
scope = namescope;
5047

51-
if (keywords[input_name] is Tensor value)
48+
// Perform input type inference
49+
foreach (var input_arg in op_def.InputArg)
5250
{
53-
if (keywords.ContainsKey(input_name))
51+
var input_name = input_arg.Name;
52+
if (keywords[input_name] is double int_value)
5453
{
55-
inputs.Add(value);
54+
keywords[input_name] = constant_op.Constant(int_value, input_name);
5655
}
5756

58-
if (!String.IsNullOrEmpty(input_arg.TypeAttr))
57+
if (keywords[input_name] is Tensor value)
5958
{
60-
attrs[input_arg.TypeAttr] = value.dtype;
59+
if (keywords.ContainsKey(input_name))
60+
{
61+
inputs.Add(value);
62+
}
63+
64+
if (!String.IsNullOrEmpty(input_arg.TypeAttr))
65+
{
66+
attrs[input_arg.TypeAttr] = value.dtype;
67+
}
68+
69+
if (input_arg.IsRef)
70+
{
71+
72+
}
73+
else
74+
{
75+
input_types.Add(value.dtype);
76+
}
6177
}
78+
}
6279

63-
if (input_arg.IsRef)
64-
{
65-
66-
}
67-
else
80+
// Process remaining attrs
81+
foreach (var attr in op_def.Attr)
82+
{
83+
if (keywords.ContainsKey(attr.Name))
6884
{
69-
input_types.Add(value.dtype);
85+
attrs[attr.Name] = keywords[attr.Name];
7086
}
7187
}
72-
}
7388

74-
// Process remaining attrs
75-
foreach (var attr in op_def.Attr)
76-
{
77-
if (keywords.ContainsKey(attr.Name))
89+
// Convert attr values to AttrValue protos.
90+
var attr_protos = new Dictionary<string, AttrValue>();
91+
foreach (var attr_def in op_def.Attr)
7892
{
79-
attrs[attr.Name] = keywords[attr.Name];
80-
}
81-
}
93+
var key = attr_def.Name;
94+
var value = attrs[key];
95+
var attr_value = new AttrValue();
8296

83-
// Convert attr values to AttrValue protos.
84-
var attr_protos = new Dictionary<string, AttrValue>();
85-
foreach (var attr_def in op_def.Attr)
86-
{
87-
var key = attr_def.Name;
88-
var value = attrs[key];
89-
var attr_value = new AttrValue();
90-
91-
switch (attr_def.Type)
92-
{
93-
case "string":
94-
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
95-
break;
96-
case "type":
97-
attr_value.Type = _MakeType((TF_DataType)value, attr_def);
98-
break;
99-
case "bool":
100-
attr_value.B = (bool)value;
101-
break;
102-
case "shape":
103-
attr_value.Shape = value == null ?
104-
attr_def.DefaultValue.Shape :
105-
tensor_util.as_shape((long[])value);
106-
break;
107-
default:
108-
throw new InvalidDataException($"attr_def.Type {attr_def.Type}");
109-
}
97+
switch (attr_def.Type)
98+
{
99+
case "string":
100+
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
101+
break;
102+
case "type":
103+
attr_value.Type = _MakeType((TF_DataType)value, attr_def);
104+
break;
105+
case "bool":
106+
attr_value.B = (bool)value;
107+
break;
108+
case "shape":
109+
attr_value.Shape = value == null ?
110+
attr_def.DefaultValue.Shape :
111+
tensor_util.as_shape((long[])value);
112+
break;
113+
default:
114+
throw new InvalidDataException($"attr_def.Type {attr_def.Type}");
115+
}
110116

111-
attr_protos[key] = attr_value;
112-
}
117+
attr_protos[key] = attr_value;
118+
}
113119

114-
// Determine output types (possibly using attrs)
115-
var output_types = new List<TF_DataType>();
120+
// Determine output types (possibly using attrs)
121+
var output_types = new List<TF_DataType>();
116122

117-
foreach (var arg in op_def.OutputArg)
118-
{
119-
if (!String.IsNullOrEmpty(arg.NumberAttr))
123+
foreach (var arg in op_def.OutputArg)
120124
{
125+
if (!String.IsNullOrEmpty(arg.NumberAttr))
126+
{
121127

128+
}
129+
else if (!String.IsNullOrEmpty(arg.TypeAttr))
130+
{
131+
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
132+
}
122133
}
123-
else if (!String.IsNullOrEmpty(arg.TypeAttr))
124-
{
125-
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
126-
}
127-
}
128134

129-
// Add Op to graph
130-
var op = g.create_op(op_type_name, inputs, output_types.ToArray(),
131-
name: scope,
132-
input_types: input_types.ToArray(),
133-
attrs: attr_protos,
134-
op_def: op_def);
135+
// Add Op to graph
136+
var op = g.create_op(op_type_name, inputs, output_types.ToArray(),
137+
name: scope,
138+
input_types: input_types.ToArray(),
139+
attrs: attr_protos,
140+
op_def: op_def);
135141

136-
return op;
142+
return op;
143+
}
137144
}
138145

139146
public DataType _MakeType(TF_DataType v, AttrDef attr_def)

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
<TargetFramework>netstandard2.0</TargetFramework>
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
7-
<Version>0.0.2</Version>
7+
<Version>0.0.3</Version>
88
<Authors>Haiping Chen</Authors>
9-
<Company>SciSharp.org</Company>
9+
<Company>SciSharp STACK</Company>
1010
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
1111
<Copyright>Apache 2.0</Copyright>
1212
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
@@ -16,7 +16,7 @@
1616
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags>
1717
<Description>Google's TensorFlow binding in .NET Standard.
1818
Docs: https://tensorflownet.readthedocs.io</Description>
19-
<AssemblyVersion>0.0.2.0</AssemblyVersion>
19+
<AssemblyVersion>0.0.3.0</AssemblyVersion>
2020
<PackageReleaseNotes>API updated</PackageReleaseNotes>
2121
<LangVersion>7.2</LangVersion>
2222
</PropertyGroup>

src/TensorFlowNET.Core/Tensors/TF_DataType.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
namespace Tensorflow
66
{
7+
/// <summary>
8+
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
9+
/// The enum values here are identical to corresponding values in types.proto.
10+
/// </summary>
711
public enum TF_DataType
812
{
913
DtInvalid = 0,
@@ -30,6 +34,8 @@ public enum TF_DataType
3034
TF_RESOURCE = 20,
3135
TF_VARIANT = 21,
3236
TF_UINT32 = 22,
33-
TF_UINT64 = 23
37+
TF_UINT64 = 23,
38+
39+
DtDoubleRef = 102, // DT_DOUBLE_REF
3440
}
3541
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ public partial class Tensor : IDisposable
1919
public Graph Graph => op.Graph;
2020
public Operation op { get; }
2121

22-
public string name;
22+
/// <summary>
23+
/// The string name of this tensor.
24+
/// </summary>
25+
public string name => $"{(op == null ? "Operation was not named" : $"{op.Name}:{value_index}")}";
2326

2427
public int value_index { get; }
2528

@@ -222,7 +225,7 @@ public override string ToString()
222225
}
223226
}
224227

225-
return $"{name} {dtype} {rank} {string.Join(",", shape)}";
228+
return $"{name} {dtype.ToString()} {rank} {string.Join(",", shape)}";
226229
}
227230

228231
public void Dispose()

src/TensorFlowNET.Core/Train/Optimizer.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,22 @@ public abstract class Optimizer
1717
public string Name { get; set; }
1818
public double LearningRate { get; set; }
1919
public Tensor LearningRateTensor { get; set; }
20+
public bool _use_locking;
21+
public Dictionary<string, object> _slots;
22+
public Dictionary<string, object> _non_slot_dict;
23+
public Dictionary<string, object> _deferred_slot_restorations;
2024

2125
public Optimizer(double learning_rate, bool use_locking, string name = "")
2226
{
2327
if (String.IsNullOrEmpty(name))
2428
throw new NotImplementedException("Must specify the optimizer name");
2529

2630
Name = name;
31+
_use_locking = use_locking;
32+
// Dictionary of slots.
33+
_slots = new Dictionary<string, object>();
34+
_non_slot_dict = new Dictionary<string, object>();
35+
_deferred_slot_restorations = new Dictionary<string, object>();
2736
}
2837

2938
/// <summary>
@@ -68,7 +77,7 @@ public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
6877
break;
6978
}
7079

71-
var processors = var_list.Select(v => optimizer._get_processor(v));
80+
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
7281
var var_refs = processors.Select(x => x.target()).ToList();
7382

7483
gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss,

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,17 @@ private void _init_from_args(object initial_value,
7979
// have an issue if these other variables aren't initialized first by
8080
// using their initialized_value() method.
8181

82+
var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;
83+
84+
if (!String.IsNullOrEmpty(caching_device))
85+
{
86+
87+
}
88+
else
89+
{
90+
91+
}
92+
8293
ops.add_to_collections(collections, this);
8394
}
8495
}

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
using NumSharp.Core;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
45

@@ -33,5 +34,31 @@ public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name =
3334

3435
return new Tensor(_op, 0, dtype);
3536
}
37+
38+
/// <summary>
39+
/// Update 'ref' by assigning 'value' to it
40+
/// </summary>
41+
/// <param name="REF"></param>
42+
/// <param name="value"></param>
43+
/// <param name="validate_shape"></param>
44+
/// <param name="use_locking"></param>
45+
/// <param name="name"></param>
46+
public static Tensor assign(Tensor tensor, Tensor value,
47+
bool validate_shape = true,
48+
bool use_locking = true,
49+
string name = "")
50+
{
51+
var keywords = new Dictionary<string, object>();
52+
keywords.Add("ref", tensor);
53+
keywords.Add("value", value);
54+
keywords.Add("validate_shape", validate_shape);
55+
keywords.Add("use_locking", use_locking);
56+
57+
var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords);
58+
59+
var _result = _op.outputs[0];
60+
var _inputs_flat = _op.inputs;
61+
return _result;
62+
}
3663
}
3764
}

0 commit comments

Comments
 (0)