Skip to content

Commit 89f305c

Browse files
committed
fix shape issue for IndexedSlice
1 parent 6af53e6 commit 89f305c

File tree

7 files changed

+41
-10
lines changed

7 files changed

+41
-10
lines changed

src/TensorFlowNET.Core/Gradients/gradients_util.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
153153
{
154154
if (in_grad != null)
155155
{
156-
if (in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE)
156+
if (in_grad is Tensor &&
157+
in_grad.Tag == null && // maybe a IndexedSlice
158+
t_in.dtype != TF_DataType.TF_RESOURCE)
157159
{
158160
in_grad.shape = t_in.shape;
159161
}

src/TensorFlowNET.Core/Graphs/Graph.Control.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation
4343
ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x)));
4444
}
4545

46-
return ret.ToArray();
46+
return ret.OrderBy(x => x.op.name).ToArray();
4747
}
4848

4949
/// <summary>

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,12 @@ public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[]
245245
// If a names ends with a '/' it is a "name scope" and we use it as-is,
246246
// after removing the trailing '/'.
247247
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
248-
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
248+
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
249+
250+
var input_ops = inputs.Select(x => x.op).ToArray();
251+
if (name == "loss/gradients/embedding/embedding_lookup_grad/Reshape")
252+
;
249253

250-
var input_ops = inputs.Select(x => x.op).ToArray();
251254
var control_inputs = _control_dependencies_for_inputs(input_ops);
252255

253256
var op = new Operation(node_def,

src/TensorFlowNET.Core/Operations/Operation.Output.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
using System;
1+
#if GRAPH_SERIALIZE
2+
using Newtonsoft.Json;
3+
#endif
4+
using System;
25
using System.Collections.Generic;
36
using System.Linq;
47
using System.Runtime.InteropServices;
@@ -14,7 +17,9 @@ public partial class Operation
1417

1518
private Tensor[] _outputs;
1619
public Tensor[] outputs => _outputs;
17-
20+
#if GRAPH_SERIALIZE
21+
[JsonIgnore]
22+
#endif
1823
public Tensor output => _outputs.FirstOrDefault();
1924

2025
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
3030

3131
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
3232
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
33-
<DefineConstants>TRACE;DEBUG</DefineConstants>
33+
<DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants>
3434
</PropertyGroup>
3535

3636
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -54,6 +54,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
5454
<ItemGroup>
5555
<PackageReference Include="Google.Protobuf" Version="3.8.0" />
5656
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.13.0" />
57+
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
5758
<PackageReference Include="NumSharp" Version="0.10.3" />
5859
</ItemGroup>
5960

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,18 @@ public partial class Tensor : IDisposable, ITensorOrOperation
5050

5151
private TF_DataType _dtype = TF_DataType.DtInvalid;
5252
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
53+
#if GRAPH_SERIALIZE
54+
[JsonIgnore]
55+
#endif
5356
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
57+
#if GRAPH_SERIALIZE
58+
[JsonIgnore]
59+
#endif
5460
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
5561
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
62+
#if GRAPH_SERIALIZE
63+
[JsonIgnore]
64+
#endif
5665
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
5766
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
5867

@@ -61,6 +70,9 @@ public partial class Tensor : IDisposable, ITensorOrOperation
6170
/// <summary>
6271
/// used for keep other pointer when do implicit operating
6372
/// </summary>
73+
#if GRAPH_SERIALIZE
74+
[JsonIgnore]
75+
#endif
6476
public object Tag { get; set; }
6577

6678
public int[] shape
@@ -131,7 +143,9 @@ public int rank
131143
}
132144
}
133145
}
134-
146+
#if GRAPH_SERIALIZE
147+
[JsonIgnore]
148+
#endif
135149
public int NDims => rank;
136150

137151
public string Device => op.Device;

test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ public bool Run()
6262
train_y = y[new Slice(stop: train_size)];
6363
valid_y = y[new Slice(start: train_size)];
6464
Console.WriteLine("\tDONE");
65+
66+
train_x = np.Load<int[,]>(Path.Join("word_cnn", "train_x.npy"));
67+
valid_x = np.Load<int[,]>(Path.Join("word_cnn", "valid_x.npy"));
68+
train_y = np.Load<int[]>(Path.Join("word_cnn", "train_y.npy"));
69+
valid_y = np.Load<int[]>(Path.Join("word_cnn", "valid_y.npy"));
70+
6571
return (train_x, valid_x, train_y, valid_y);
6672
}
6773

@@ -114,7 +120,7 @@ public void PrepareData()
114120
int alphabet_size = 0;
115121

116122
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
117-
vocabulary_size = len(word_dict);
123+
//vocabulary_size = len(word_dict);
118124
var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
119125

120126
Console.WriteLine("\tDONE ");
@@ -305,7 +311,7 @@ private bool Train(Session sess, Graph graph)
305311
public bool Train()
306312
{
307313
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
308-
314+
string json = JsonConvert.SerializeObject(graph, Formatting.Indented);
309315
return with(tf.Session(graph), sess => Train(sess, graph));
310316
}
311317

0 commit comments

Comments
 (0)