@@ -9,6 +9,7 @@
 using NumSharp;
 using Tensorflow;
 using Tensorflow.Sessions;
+using TensorFlowNET.Examples.Text;
 using TensorFlowNET.Examples.Utility;
 using static Tensorflow.Python;
 
@@ -24,24 +25,27 @@ public class CnnTextClassification : IExample
         public int? DataLimit = null;
         public bool IsImportingGraph { get; set; } = false;
 
-        private const string dataDir = "word_cnn";
-        private string dataFileName = "dbpedia_csv.tar.gz";
+        const string dataDir = "cnn_text";
+        string dataFileName = "dbpedia_csv.tar.gz";
 
-        private const string TRAIN_PATH = "word_cnn/dbpedia_csv/train.csv";
-        private const string TEST_PATH = "word_cnn/dbpedia_csv/test.csv";
+        string TRAIN_PATH = $"{dataDir}/dbpedia_csv/train.csv";
+        string TEST_PATH = $"{dataDir}/dbpedia_csv/test.csv";
 
-        private const int NUM_CLASS = 14;
-        private const int BATCH_SIZE = 64;
-        private const int NUM_EPOCHS = 10;
-        private const int WORD_MAX_LEN = 100;
-        private const int CHAR_MAX_LEN = 1014;
+        int NUM_CLASS = 14;
+        int BATCH_SIZE = 64;
+        int NUM_EPOCHS = 10;
+        int WORD_MAX_LEN = 100;
+        int CHAR_MAX_LEN = 1014;
 
-        protected float loss_value = 0;
+        float loss_value = 0;
         double max_accuracy = 0;
 
-        int vocabulary_size = 50000;
+        int vocabulary_size = -1;
         NDArray train_x, valid_x, train_y, valid_y;
 
+        ITextModel textModel;
+        public string ModelName = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
+
         public bool Run()
         {
             PrepareData();
@@ -68,7 +72,7 @@ public bool Run()
             return (train_x, valid_x, train_y, valid_y);
         }
 
-        private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
+        private void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
         {
             int i = 0;
             var label_keys = labels.Keys.ToArray();
@@ -114,10 +118,8 @@ public void PrepareData()
 
             Console.WriteLine("Building dataset...");
 
-            int alphabet_size = 0;
-
             var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
-            // vocabulary_size = len(word_dict);
+            vocabulary_size = len(word_dict);
             var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
 
             Console.WriteLine("\tDONE");
@@ -155,83 +157,19 @@ public Graph BuildGraph()
         {
             var graph = tf.Graph().as_default();
 
-            var embedding_size = 128;
-            var learning_rate = 0.001f;
-            var filter_sizes = new int[3, 4, 5];
-            var num_filters = 100;
-            var document_max_len = 100;
-
-            var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
-            var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
-            var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
-            var global_step = tf.Variable(0, trainable: false);
-            var keep_prob = tf.where(is_training, 0.5f, 1.0f);
-            Tensor x_emb = null;
-
-            with(tf.name_scope("embedding"), scope =>
-            {
-                var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
-                var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
-                x_emb = tf.nn.embedding_lookup(embeddings, x);
-                x_emb = tf.expand_dims(x_emb, -1);
-            });
-
-            var pooled_outputs = new List<Tensor>();
-            for (int len = 0; len < filter_sizes.Rank; len++)
+            switch (ModelName)
             {
-                int filter_size = filter_sizes.GetLength(len);
-                var conv = tf.layers.conv2d(
-                    x_emb,
-                    filters: num_filters,
-                    kernel_size: new int[] { filter_size, embedding_size },
-                    strides: new int[] { 1, 1 },
-                    padding: "VALID",
-                    activation: tf.nn.relu());
-
-                var pool = tf.layers.max_pooling2d(
-                    conv,
-                    pool_size: new[] { document_max_len - filter_size + 1, 1 },
-                    strides: new[] { 1, 1 },
-                    padding: "VALID");
-
-                pooled_outputs.Add(pool);
+                case "word_cnn":
+                    textModel = new WordCnn(vocabulary_size, WORD_MAX_LEN, NUM_CLASS);
+                    break;
             }
 
-            var h_pool = tf.concat(pooled_outputs, 3);
-            var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank));
-            Tensor h_drop = null;
-            with(tf.name_scope("dropout"), delegate
-            {
-                h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
-            });
-
-            Tensor logits = null;
-            Tensor predictions = null;
-            with(tf.name_scope("output"), delegate
-            {
-                logits = tf.layers.dense(h_drop, NUM_CLASS);
-                predictions = tf.argmax(logits, -1, output_type: tf.int32);
-            });
-
-            with(tf.name_scope("loss"), delegate
-            {
-                var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
-                var loss = tf.reduce_mean(sscel);
-                var adam = tf.train.AdamOptimizer(learning_rate);
-                var optimizer = adam.minimize(loss, global_step: global_step);
-            });
-
-            with(tf.name_scope("accuracy"), delegate
-            {
-                var correct_predictions = tf.equal(predictions, y);
-                var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
-            });
-
             return graph;
         }
 
-        private bool Train(Session sess, Graph graph)
+        public void Train(Session sess)
         {
+            var graph = tf.get_default_graph();
             var stopwatch = Stopwatch.StartNew();
 
             sess.run(tf.global_variables_initializer());
@@ -263,10 +201,7 @@ private bool Train(Session sess, Graph graph)
                     loss_value = result[2];
                     var step = (int)result[1];
                     if (step % 10 == 0)
-                    {
-                        var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
-                        Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}");
-                    }
+                        Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value.ToString("0.0000")}.");
 
                     if (step % 100 == 0)
                     {
@@ -289,7 +224,7 @@ private bool Train(Session sess, Graph graph)
 
                         var valid_accuracy = sum_accuracy / cnt;
 
-                        print($"\nValidation Accuracy = {valid_accuracy}\n");
+                        print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n");
 
                         // Save model
                         if (valid_accuracy > max_accuracy)
@@ -300,13 +235,6 @@ private bool Train(Session sess, Graph graph)
                     }
                 }
             }
-
-            return max_accuracy > 0.9;
-        }
-
-        public void Train(Session sess)
-        {
-            Train(sess, sess.graph);
         }
 
         public void Predict(Session sess)
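After this change, BuildGraph() no longer assembles the word-level CNN inline: it selects an ITextModel implementation from ModelName and delegates graph construction to it, with only the word_cnn case wired up in this commit. Below is a minimal sketch of that selection pattern; the empty ITextModel surface and the factory helper are illustrative assumptions for this note, not the repository's actual interface.

using System;

// Sketch only: the real ITextModel in TensorFlowNET.Examples.Text may expose
// different members; this just illustrates the switch-based model selection.
public interface ITextModel { }

public class WordCnn : ITextModel
{
    // Mirrors the constructor call in the diff; the embedding, conv2d/max_pool
    // and dense pipeline removed from BuildGraph() is assumed to live in here.
    public WordCnn(int vocabularySize, int documentMaxLen, int numClass) { }
}

public static class TextModelFactory
{
    // Hypothetical helper equivalent to the switch statement in BuildGraph().
    public static ITextModel Create(string modelName, int vocabSize, int maxLen, int numClass)
    {
        switch (modelName)
        {
            case "word_cnn":
                return new WordCnn(vocabSize, maxLen, numClass);
            default:
                // char_cnn, vd_cnn, word_rnn, att_rnn and rcnn are named in the
                // ModelName comment but not implemented by this commit.
                throw new NotSupportedException($"Unknown model: {modelName}");
        }
    }
}

Separately, the switch from printing raw values to loss_value.ToString("0.0000") and valid_accuracy.ToString("P") uses standard .NET numeric format specifiers: fixed four decimal places for the loss, and a percentage rendering for the validation accuracy.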