Skip to content

Commit 477d03d

Browse files
committed
remove order of _control_dependencies_for_inputs.
1 parent 89f305c commit 477d03d

File tree

4 files changed

+6
-12
lines changed

4 files changed

+6
-12
lines changed

src/TensorFlowNET.Core/Graphs/Graph.Control.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation
4343
ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x)));
4444
}
4545

46-
return ret.OrderBy(x => x.op.name).ToArray();
46+
return ret.ToArray();
4747
}
4848

4949
/// <summary>

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,6 @@ public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[]
248248
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
249249

250250
var input_ops = inputs.Select(x => x.op).ToArray();
251-
if (name == "loss/gradients/embedding/embedding_lookup_grad/Reshape")
252-
;
253-
254251
var control_inputs = _control_dependencies_for_inputs(input_ops);
255252

256253
var op = new Operation(node_def,

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
3030

3131
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
3232
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
33-
<DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants>
33+
<DefineConstants>TRACE;DEBUG</DefineConstants>
3434
</PropertyGroup>
3535

3636
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -54,7 +54,6 @@ Docs: https://tensorflownet.readthedocs.io</Description>
5454
<ItemGroup>
5555
<PackageReference Include="Google.Protobuf" Version="3.8.0" />
5656
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.13.0" />
57-
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
5857
<PackageReference Include="NumSharp" Version="0.10.3" />
5958
</ItemGroup>
6059

test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public class CnnTextClassification : IExample
2222
public bool Enabled { get; set; } = true;
2323
public string Name => "CNN Text Classification";
2424
public int? DataLimit = null;
25-
public bool IsImportingGraph { get; set; } = false;
25+
public bool IsImportingGraph { get; set; } = true;
2626

2727
private const string dataDir = "word_cnn";
2828
private string dataFileName = "dbpedia_csv.tar.gz";
@@ -44,9 +44,7 @@ public bool Run()
4444
{
4545
PrepareData();
4646

47-
Train();
48-
49-
return true;
47+
return Train();
5048
}
5149

5250
// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
@@ -305,13 +303,13 @@ private bool Train(Session sess, Graph graph)
305303
}
306304
}
307305

308-
return false;
306+
return max_accuracy > 0.8;
309307
}
310308

311309
public bool Train()
312310
{
313311
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
314-
string json = JsonConvert.SerializeObject(graph, Formatting.Indented);
312+
// string json = JsonConvert.SerializeObject(graph, Formatting.Indented);
315313
return with(tf.Session(graph), sess => Train(sess, graph));
316314
}
317315

0 commit comments

Comments
 (0)