1+ import configparser
2+ from argparse import ArgumentParser
3+
14import tensorflow as tf
25from tqdm import tqdm
36
47from experiments .dataset import Dataset
58from experiments .experiment import Experiment
6- from methods .selection import fisher
79from utils .log_saver import LogSaver
810
9- dataset = Dataset ('data/autism.tsv' )
1011
11- num_features = 100
12- num_epochs = 1000
13- eval_every = 10
12+ def run_experiment (experiment_config ):
13+ dataset = Dataset ('data/autism.tsv' )
14+
15+ num_epochs = 1000
16+ eval_every = 10
17+
18+ for fold_id , (train_idxs , test_idxs ) in dataset .cross_validation ():
19+
20+ data_train_fold = dataset .get_data (train_idxs )
21+ num_instances , labels_train_fold = dataset .get_labels (train_idxs )
22+
23+ data_test_fold = dataset .get_data (test_idxs )
24+ _ , labels_test_fold = dataset .get_labels (test_idxs )
25+
26+ with tf .Graph ().as_default () as graph :
27+
28+ experiment = Experiment (experiment_config , num_instances , None , data_train_fold )
29+
30+ with tf .Session () as session :
1431
15- for fold_id , (train_idxs , test_idxs ) in dataset .cross_validation ():
32+ global_step = 0
33+ session .run (tf .global_variables_initializer ())
1634
17- data_train_fold = dataset .get_data (train_idxs )
18- num_instances , labels_train_fold = dataset .get_labels (train_idxs )
35+ log_saver = LogSaver ('logs' , 'fisher_fold{}' .format (fold_id ), session .graph )
1936
20- data_test_fold = dataset . get_data ( test_idxs )
21- _ , labels_test_fold = dataset . get_labels ( test_idxs )
37+ train_selected_data = session . run ( experiment . selection_wrapper . selected_data )
38+ test_selected_data = session . run ( experiment . selection_wrapper . select ( data_test_fold ) )
2239
23- with tf . Graph (). as_default () as graph :
40+ tqdm_iter = tqdm ( range ( num_epochs ), desc = 'Epochs' )
2441
25- experiment = Experiment (fisher , num_features , num_instances , None , data_train_fold )
42+ for epoch in tqdm_iter :
43+ feed_dict = {experiment .clf .x : train_selected_data , experiment .clf .y : labels_train_fold }
44+ loss , _ = session .run ([experiment .clf .loss , experiment .clf .opt ],
45+ feed_dict = feed_dict )
2646
27- with tf .Session () as session :
47+ if epoch % eval_every == 0 :
48+ summary = session .run (experiment .clf .summary_op , feed_dict = feed_dict )
49+ log_saver .log_train (summary , epoch )
2850
29- global_step = 0
30- session .run (tf .global_variables_initializer ())
51+ feed_dict = {experiment .clf .x : test_selected_data , experiment .clf .y : labels_test_fold }
52+ summary = session .run (experiment .clf .summary_op , feed_dict = feed_dict )
53+ log_saver .log_test (summary , epoch )
3154
32- log_saver = LogSaver ( 'logs' , 'fisher_fold{ }' .format (fold_id ), session . graph )
55+ tqdm_iter . set_postfix ( loss = '{:.2f }' .format (float ( loss )), epoch = epoch )
3356
34- train_selected_data = session .run (experiment .selection_wrapper .selected_data )
35- test_selected_data = session .run (experiment .selection_wrapper .select (data_test_fold ))
3657
37- tqdm_iter = tqdm (range (num_epochs ), desc = 'Epochs' )
58+ def main ():
59+ parser = ArgumentParser ()
60+ parser .add_argument ('experiment' ,
61+ default = 'simple_experiment' ,
62+ choices = ['simple_experiment' ],
63+ help = 'model used during training (default: %(default))' )
3864
39- for epoch in tqdm_iter :
40- feed_dict = {experiment .clf .x : train_selected_data , experiment .clf .y : labels_train_fold }
41- loss , _ = session .run ([experiment .clf .loss , experiment .clf .opt ],
42- feed_dict = feed_dict )
65+ args = parser .parse_args ()
66+ experiment_config = configparser .ConfigParser ()
67+ experiment_config .read ('config/{}.ini' .format (args .experiment ))
4368
44- if epoch % eval_every == 0 :
45- summary = session .run (experiment .clf .summary_op , feed_dict = feed_dict )
46- log_saver .log_train (summary , epoch )
69+ run_experiment (experiment_config )
4770
48- feed_dict = {experiment .clf .x : test_selected_data , experiment .clf .y : labels_test_fold }
49- summary = session .run (experiment .clf .summary_op , feed_dict = feed_dict )
50- log_saver .log_test (summary , epoch )
5171
52- tqdm_iter .set_postfix (loss = '{:.2f}' .format (float (loss )), epoch = epoch )
72+ if __name__ == '__main__' :
73+ main ()
0 commit comments