Skip to content

Commit 1876cc9

Browse files
committed
Fix Operation.get_attr #115
1 parent fd789d4 commit 1876cc9

File tree

5 files changed

+55
-12
lines changed

5 files changed

+55
-12
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) st
1212
![tensors_flowing](docs/assets/tensors_flowing.gif)
1313

1414
### How to use
15-
Download the pre-compiled dll [here](tensorflow.so) and place it in the working folder.
16-
This is only need for Linux and Mac OS, and already packed for Windows.
17-
1815
Install TensorFlow.NET through NuGet.
1916
```sh
2017
PM> Install-Package TensorFlow.NET
2118
```
2219

20+
If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflow.so) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows.
21+
2322
Import tensorflow.net.
23+
2424
```cs
2525
using Tensorflow;
2626
```
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Eager
6+
{
7+
public class Execute
8+
{
9+
public void record_gradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
10+
{
11+
if (inputs == null)
12+
inputs = new Tensor[0];
13+
14+
pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name);
15+
}
16+
}
17+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Eager
7+
{
8+
/// <summary>
9+
/// python\eager\pywrap_tfe_src.cc
10+
/// </summary>
11+
public class pywrap_tfe_src
12+
{
13+
public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
14+
{
15+
16+
}
17+
}
18+
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataT
113113
op_def = g.GetOpDef(node_def.Op);
114114

115115
_handle = ops._create_c_op(g, node_def, inputs);
116-
116+
117117
_outputs = new Tensor[NumOutputs];
118118
output_types = new TF_DataType[NumOutputs];
119119

@@ -128,21 +128,26 @@ public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataT
128128

129129
public object get_attr(string name)
130130
{
131-
object ret = null;
131+
AttrValue x = null;
132132

133133
var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" };
134134

135+
using (var buf = new Buffer())
136+
{
137+
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
138+
status.Check(true);
139+
x = AttrValue.Parser.ParseFrom(buf);
140+
}
141+
135142
switch (name)
136143
{
137144
case "dtype":
138-
ret = _outputs[0];
139-
break;
145+
return x.Type;
140146
case "shape":
141-
ret = new TensorShapeProto();
142-
break;
147+
return x.Shape;
148+
default:
149+
throw new NotImplementedException($"{name}");
143150
}
144-
145-
return ret;
146151
}
147152

148153
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
using System.IO;
44
using System.Text;
55
using Tensorflow;
6+
using Tensorflow.Eager;
67

78
namespace Tensorflow
89
{
910
public static class gen_array_ops
1011
{
1112
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
13+
public static Execute _execute = new Execute();
1214

13-
public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
15+
public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = "")
1416
{
1517
var keywords = new Dictionary<string, object>();
1618
keywords.Add("dtype", dtype);
@@ -24,6 +26,7 @@ public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
2426
_attrs["dtype"] = _op.get_attr("dtype");
2527
_attrs["shape"] = _op.get_attr("shape");
2628

29+
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
2730
return new Tensor(_op, 0, dtype);
2831
}
2932

0 commit comments

Comments
 (0)