@@ -49,8 +49,6 @@ public bool Run()
4949
5050 public Graph BuildGraph ( )
5151 {
52- var g = tf . Graph ( ) ;
53-
5452 // Placeholders for inputs (x) and outputs(y)
5553 x = tf . placeholder ( tf . float32 , shape : ( - 1 , img_size_flat ) , name : "X" ) ;
5654 y = tf . placeholder ( tf . float32 , shape : ( - 1 , n_classes ) , name : "Y" ) ;
@@ -60,15 +58,16 @@ public Graph BuildGraph()
6058 // Create a fully-connected layer with n_classes nodes as output layer
6159 var output_logits = fc_layer ( fc1 , n_classes , "OUT" , use_relu : false ) ;
6260 // Define the loss function, optimizer, and accuracy
63- loss = tf . reduce_mean ( tf . nn . softmax_cross_entropy_with_logits_v2 ( labels : y , logits : output_logits ) , name : "loss" ) ;
61+ var logits = tf . nn . softmax_cross_entropy_with_logits ( labels : y , logits : output_logits ) ;
62+ loss = tf . reduce_mean ( logits , name : "loss" ) ;
6463 optimizer = tf . train . AdamOptimizer ( learning_rate : learning_rate , name : "Adam-op" ) . minimize ( loss ) ;
6564 var correct_prediction = tf . equal ( tf . argmax ( output_logits , 1 ) , tf . argmax ( y , 1 ) , name : "correct_pred" ) ;
6665 accuracy = tf . reduce_mean ( tf . cast ( correct_prediction , tf . float32 ) , name : "accuracy" ) ;
6766
6867 // Network predictions
6968 var cls_prediction = tf . argmax ( output_logits , axis : 1 , name : "predictions" ) ;
7069
71- return g ;
70+ return tf . get_default_graph ( ) ;
7271 }
7372
7473 private Tensor fc_layer ( Tensor x , int num_units , string name , bool use_relu = true )
@@ -93,16 +92,10 @@ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = tr
9392 return layer ;
9493 }
9594
96- public Graph ImportGraph ( )
97- {
98- throw new NotImplementedException ( ) ;
99- }
100-
101- public bool Predict ( )
102- {
103- throw new NotImplementedException ( ) ;
104- }
95+ public Graph ImportGraph ( ) => throw new NotImplementedException ( ) ;
10596
97+ public bool Predict ( ) => throw new NotImplementedException ( ) ;
98+
10699 public void PrepareData ( )
107100 {
108101 mnist = MnistDataSet . read_data_sets ( "mnist" , one_hot : true ) ;
@@ -112,7 +105,6 @@ public bool Train()
112105 {
113106 // Number of training iterations in each epoch
114107 var num_tr_iter = mnist . train . labels . len / batch_size ;
115-
116108 return with ( tf . Session ( ) , sess =>
117109 {
118110 var init = tf . global_variables_initializer ( ) ;
@@ -153,10 +145,9 @@ public bool Train()
153145 print ( "---------------------------------------------------------" ) ;
154146 print ( $ "Epoch: { epoch + 1 } , validation loss: { loss_val . ToString ( "0.0000" ) } , validation accuracy: { accuracy_val . ToString ( "P" ) } ") ;
155147 print ( "---------------------------------------------------------" ) ;
156-
157148 }
158149
159- return accuracy_val > 0.9 ;
150+ return accuracy_val > 0.95 ;
160151 } ) ;
161152 }
162153
0 commit comments