Skip to content

Commit b6e5899

Browse files
author
Tomasz Latkowski
committed
fixes in code
1 parent 51bf39b commit b6e5899

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[SELECTION]
22
num_features = 100
3-
method = 'fisher'
3+
method = fisher
44

55
[CLASSIFIER]
6-
hidden_sizes = 20, 10
6+
hidden_sizes = 20

experiments/experiment.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414

1515
class Experiment:
1616

17-
def __init__(self, experiment_config, num_features, num_instances, classifier, dataset):
17+
def __init__(self, experiment_config, num_instances, classifier, dataset):
1818

1919
selection_method = methods[experiment_config['SELECTION']['method']]
20-
num_features = experiment_config['SELECTION']['num_features']
20+
num_features = int(experiment_config['SELECTION']['num_features'])
21+
hidden_sizes = int(experiment_config['CLASSIFIER']['hidden_sizes'])
2122

2223
with tf.name_scope('selection'):
2324
self.selection_wrapper = SelectionWrapper(dataset,
@@ -26,4 +27,4 @@ def __init__(self, experiment_config, num_features, num_instances, classifier, d
2627
num_features=num_features)
2728

2829
with tf.name_scope('classifier'):
29-
self.clf = NeuralNetworkClassifier(num_features, 20)
30+
self.clf = classifier(num_features, hidden_sizes)

requirements.txt

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
1-
tqdm==4.15.0
2-
pandas==0.22.0
3-
numpy==1.14.1
1+
bleach==1.5.0
2+
enum34==1.1.6
3+
html5lib==0.9999999
4+
Markdown==2.6.11
5+
numpy==1.13.3
6+
pandas==0.19.0
7+
protobuf==3.5.1
8+
python-dateutil==2.6.1
9+
pytz==2018.3
10+
scikit-learn==0.19.1
11+
scipy==1.0.0
12+
six==1.11.0
13+
sklearn==0.0
414
tensorflow==1.4.0
5-
scikit_learn==0.19.1
15+
tensorflow-tensorboard==0.4.0
16+
tqdm==4.19.6
17+
Werkzeug==0.14.1

run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import tensorflow as tf
55
from tqdm import tqdm
66

7+
from experiments.classifier import NeuralNetworkClassifier
78
from experiments.dataset import Dataset
89
from experiments.experiment import Experiment
910
from utils.log_saver import LogSaver
1011

1112

1213
def run_experiment(experiment_config):
1314
dataset = Dataset('data/autism.tsv')
14-
1515
num_epochs = 1000
1616
eval_every = 10
1717

@@ -25,7 +25,7 @@ def run_experiment(experiment_config):
2525

2626
with tf.Graph().as_default() as graph:
2727

28-
experiment = Experiment(experiment_config, num_instances, None, data_train_fold)
28+
experiment = Experiment(experiment_config, num_instances, NeuralNetworkClassifier, data_train_fold)
2929

3030
with tf.Session() as session:
3131

@@ -64,7 +64,7 @@ def main():
6464

6565
args = parser.parse_args()
6666
experiment_config = configparser.ConfigParser()
67-
experiment_config.read('config/{}.ini'.format(args.experiment))
67+
experiment_config.read('config/experiments/{}.ini'.format(args.experiment))
6868

6969
run_experiment(experiment_config)
7070

0 commit comments

Comments
 (0)