1- import numpy as np
21import tensorflow as tf
3- from utils .log_saver import LogSaver
4- from experiments .experiment import ExperimentModel
5- from methods .selection import fisher
62from tqdm import tqdm
7- from sklearn .model_selection import StratifiedKFold
8- from utils .data_reader import read
93
4+ from experiments .dataset import Dataset
5+ from experiments .experiment import Experiment
6+ from methods .selection import fisher
7+ from utils .log_saver import LogSaver
108
11- data_fn = 'data/autism.tsv'
12- data = read (data_fn )
9+ dataset = Dataset ('data/autism.tsv' )
1310
1411num_features = 100
1512num_epochs = 1000
1613eval_every = 10
1714
18- labels = np .concatenate ([np .ones (82 , dtype = np .float64 ), np .zeros (64 , dtype = np .float64 )])
19- labels = np .reshape (labels , (- 1 , 1 ))
20-
21-
22- skf = StratifiedKFold (n_splits = 10 )
23-
24- for fold_id , (train_idxs , test_idxs ) in enumerate (skf .split (data , labels .reshape (146 ))):
15+ for fold_id , (train_idxs , test_idxs ) in dataset .cross_validation ():
2516
26- data_train_fold = data [train_idxs , :]
27- labels_train_fold = labels [train_idxs ]
28- num_instances = [int (sum (labels_train_fold == 0 )), int (sum (labels_train_fold == 1 ))]
17+ data_train_fold = dataset .get_data (train_idxs )
18+ num_instances , labels_train_fold = dataset .get_labels (train_idxs )
2919
30- data_test_fold = data [ test_idxs , :]
31- labels_test_fold = labels [ test_idxs ]
20+ data_test_fold = dataset . get_data ( test_idxs )
21+ _ , labels_test_fold = dataset . get_labels ( test_idxs )
3222
3323 with tf .Graph ().as_default () as graph :
3424
35- model = ExperimentModel (fisher , num_features , num_instances , None , data_train_fold )
25+ experiment = Experiment (fisher , num_features , num_instances , None , data_train_fold )
3626
3727 with tf .Session () as session :
3828
4131
4232 log_saver = LogSaver ('logs' , 'fisher_fold{}' .format (fold_id ), session .graph )
4333
44- train_selected_data = session .run (model .selection_wrapper .selected_data )
45- test_selected_data = session .run (model .selection_wrapper .select (data_test_fold ))
34+ train_selected_data = session .run (experiment .selection_wrapper .selected_data )
35+ test_selected_data = session .run (experiment .selection_wrapper .select (data_test_fold ))
4636
4737 tqdm_iter = tqdm (range (num_epochs ), desc = 'Epochs' )
4838
4939 for epoch in tqdm_iter :
50- feed_dict = {model .clf .x : train_selected_data , model .clf .y : labels_train_fold }
51- loss , _ , summary = session .run ([model .clf .loss , model .clf .opt , model .clf .summary_op ], feed_dict = feed_dict )
40+ feed_dict = {experiment .clf .x : train_selected_data , experiment .clf .y : labels_train_fold }
41+ loss , _ = session .run ([experiment .clf .loss , experiment .clf .opt ],
42+ feed_dict = feed_dict )
5243
5344 if epoch % eval_every == 0 :
54- summary = session .run (model .clf .summary_op , feed_dict = feed_dict )
45+ summary = session .run (experiment .clf .summary_op , feed_dict = feed_dict )
5546 log_saver .log_train (summary , epoch )
5647
57- feed_dict = {model .clf .x : test_selected_data , model .clf .y : labels_test_fold }
58- summary = session .run (model .clf .summary_op , feed_dict = feed_dict )
48+ feed_dict = {experiment .clf .x : test_selected_data , experiment .clf .y : labels_test_fold }
49+ summary = session .run (experiment .clf .summary_op , feed_dict = feed_dict )
5950 log_saver .log_test (summary , epoch )
6051
61- tqdm_iter .set_postfix (loss = '{:.2f}' .format (float (loss )), epoch = epoch )
52+ tqdm_iter .set_postfix (loss = '{:.2f}' .format (float (loss )), epoch = epoch )
0 commit comments