1313
1414num_features = 100
1515num_epochs = 1000
16+ eval_every = 10
1617
1718labels = np .concatenate ([np .ones (82 , dtype = np .float64 ), np .zeros (64 , dtype = np .float64 )])
1819labels = np .reshape (labels , (- 1 , 1 ))
2223
2324for fold_id , (train_idxs , test_idxs ) in enumerate (skf .split (data , labels .reshape (146 ))):
2425
25- data_fold = data [train_idxs , :]
26- labels_fold = labels [train_idxs ]
27- num_instances = [int (sum (labels_fold == 0 )), int (sum (labels_fold == 1 ))]
26+ data_train_fold = data [train_idxs , :]
27+ labels_train_fold = labels [train_idxs ]
28+ num_instances = [int (sum (labels_train_fold == 0 )), int (sum (labels_train_fold == 1 ))]
29+
30+ data_test_fold = data [test_idxs , :]
31+ labels_test_fold = labels [test_idxs ]
2832
2933 with tf .Graph ().as_default () as graph :
3034
31- model = ExperimentModel (fisher , num_features , num_instances , None , data_fold )
35+ model = ExperimentModel (fisher , num_features , num_instances , None , data_train_fold )
3236
3337 with tf .Session () as session :
3438
3741
3842 log_saver = LogSaver ('logs' , 'fisher_fold{}' .format (fold_id ), session .graph )
3943
40- selected_data = session .run (model .selection_wrapper .selected_features )
44+ train_selected_data = session .run (model .selection_wrapper .selected_data )
45+ test_selected_data = session .run (model .selection_wrapper .select (data_test_fold ))
4146
4247 tqdm_iter = tqdm (range (num_epochs ), desc = 'Epochs' )
4348
4449 for epoch in tqdm_iter :
45- feed_dict = {model .clf .x : selected_data , model .clf .y : labels_fold }
50+ feed_dict = {model .clf .x : train_selected_data , model .clf .y : labels_train_fold }
4651 loss , _ , summary = session .run ([model .clf .loss , model .clf .opt , model .clf .summary_op ], feed_dict = feed_dict )
47- log_saver .log_train (summary , epoch )
52+
53+ if epoch % eval_every == 0 :
54+ summary = session .run (model .clf .summary_op , feed_dict = feed_dict )
55+ log_saver .log_train (summary , epoch )
56+
57+ feed_dict = {model .clf .x : test_selected_data , model .clf .y : labels_test_fold }
58+ summary = session .run (model .clf .summary_op , feed_dict = feed_dict )
59+ log_saver .log_test (summary , epoch )
60+
4861 tqdm_iter .set_postfix (loss = '{:.2f}' .format (float (loss )), epoch = epoch )
0 commit comments