Skip to content

Commit 51bf39b

Browse files
author
Tomasz Latkowski
committed
added simple experiment config
1 parent b19e580 commit 51bf39b

File tree

4 files changed

+70
-33
lines changed

4 files changed

+70
-33
lines changed

config/main.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[TRAINING]
2+
num_epochs = 1000
3+
eval_every = 10

experiments/experiment.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,23 @@
22

33
from experiments.classifier import NeuralNetworkClassifier
44
from methods.selection_wrapper import SelectionWrapper
5+
from methods.selection import fisher, feature_correlation_with_class, t_test, random
6+
7+
methods = {
8+
'fisher': fisher,
9+
'corr': feature_correlation_with_class,
10+
'ttest': t_test,
11+
'random': random
12+
}
513

614

715
class Experiment:
816

9-
def __init__(self, selection_method, num_features, num_instances, classifier, dataset):
17+
def __init__(self, experiment_config, num_features, num_instances, classifier, dataset):
18+
19+
selection_method = methods[experiment_config['SELECTION']['method']]
20+
num_features = experiment_config['SELECTION']['num_features']
21+
1022
with tf.name_scope('selection'):
1123
self.selection_wrapper = SelectionWrapper(dataset,
1224
num_instances=num_instances,

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1+
tqdm==4.15.0
12
pandas==0.22.0
3+
numpy==1.14.1
24
tensorflow==1.4.0
3-
numpy==1.13.3
4-
tqdm==4.19.5
5+
scikit_learn==0.19.1

run.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,73 @@
1+
import configparser
2+
from argparse import ArgumentParser
3+
14
import tensorflow as tf
25
from tqdm import tqdm
36

47
from experiments.dataset import Dataset
58
from experiments.experiment import Experiment
6-
from methods.selection import fisher
79
from utils.log_saver import LogSaver
810

9-
dataset = Dataset('data/autism.tsv')
1011

11-
num_features = 100
12-
num_epochs = 1000
13-
eval_every = 10
12+
def run_experiment(experiment_config):
13+
dataset = Dataset('data/autism.tsv')
14+
15+
num_epochs = 1000
16+
eval_every = 10
17+
18+
for fold_id, (train_idxs, test_idxs) in dataset.cross_validation():
19+
20+
data_train_fold = dataset.get_data(train_idxs)
21+
num_instances, labels_train_fold = dataset.get_labels(train_idxs)
22+
23+
data_test_fold = dataset.get_data(test_idxs)
24+
_, labels_test_fold = dataset.get_labels(test_idxs)
25+
26+
with tf.Graph().as_default() as graph:
27+
28+
experiment = Experiment(experiment_config, num_instances, None, data_train_fold)
29+
30+
with tf.Session() as session:
1431

15-
for fold_id, (train_idxs, test_idxs) in dataset.cross_validation():
32+
global_step = 0
33+
session.run(tf.global_variables_initializer())
1634

17-
data_train_fold = dataset.get_data(train_idxs)
18-
num_instances, labels_train_fold = dataset.get_labels(train_idxs)
35+
log_saver = LogSaver('logs', 'fisher_fold{}'.format(fold_id), session.graph)
1936

20-
data_test_fold = dataset.get_data(test_idxs)
21-
_, labels_test_fold = dataset.get_labels(test_idxs)
37+
train_selected_data = session.run(experiment.selection_wrapper.selected_data)
38+
test_selected_data = session.run(experiment.selection_wrapper.select(data_test_fold))
2239

23-
with tf.Graph().as_default() as graph:
40+
tqdm_iter = tqdm(range(num_epochs), desc='Epochs')
2441

25-
experiment = Experiment(fisher, num_features, num_instances, None, data_train_fold)
42+
for epoch in tqdm_iter:
43+
feed_dict = {experiment.clf.x: train_selected_data, experiment.clf.y: labels_train_fold}
44+
loss, _ = session.run([experiment.clf.loss, experiment.clf.opt],
45+
feed_dict=feed_dict)
2646

27-
with tf.Session() as session:
47+
if epoch % eval_every == 0:
48+
summary = session.run(experiment.clf.summary_op, feed_dict=feed_dict)
49+
log_saver.log_train(summary, epoch)
2850

29-
global_step = 0
30-
session.run(tf.global_variables_initializer())
51+
feed_dict = {experiment.clf.x: test_selected_data, experiment.clf.y: labels_test_fold}
52+
summary = session.run(experiment.clf.summary_op, feed_dict=feed_dict)
53+
log_saver.log_test(summary, epoch)
3154

32-
log_saver = LogSaver('logs', 'fisher_fold{}'.format(fold_id), session.graph)
55+
tqdm_iter.set_postfix(loss='{:.2f}'.format(float(loss)), epoch=epoch)
3356

34-
train_selected_data = session.run(experiment.selection_wrapper.selected_data)
35-
test_selected_data = session.run(experiment.selection_wrapper.select(data_test_fold))
3657

37-
tqdm_iter = tqdm(range(num_epochs), desc='Epochs')
58+
def main():
59+
parser = ArgumentParser()
60+
parser.add_argument('experiment',
61+
default='simple_experiment',
62+
choices=['simple_experiment'],
63+
help='model used during training (default: %(default))')
3864

39-
for epoch in tqdm_iter:
40-
feed_dict = {experiment.clf.x: train_selected_data, experiment.clf.y: labels_train_fold}
41-
loss, _ = session.run([experiment.clf.loss, experiment.clf.opt],
42-
feed_dict=feed_dict)
65+
args = parser.parse_args()
66+
experiment_config = configparser.ConfigParser()
67+
experiment_config.read('config/{}.ini'.format(args.experiment))
4368

44-
if epoch % eval_every == 0:
45-
summary = session.run(experiment.clf.summary_op, feed_dict=feed_dict)
46-
log_saver.log_train(summary, epoch)
69+
run_experiment(experiment_config)
4770

48-
feed_dict = {experiment.clf.x: test_selected_data, experiment.clf.y: labels_test_fold}
49-
summary = session.run(experiment.clf.summary_op, feed_dict=feed_dict)
50-
log_saver.log_test(summary, epoch)
5171

52-
tqdm_iter.set_postfix(loss='{:.2f}'.format(float(loss)), epoch=epoch)
72+
if __name__ == '__main__':
73+
main()

0 commit comments

Comments
 (0)