Skip to content

Commit c6f48e8

Browse files
author
Tomasz Latkowski
committed
added test part
1 parent a782db7 commit c6f48e8

File tree

3 files changed

+25
-25
lines changed

3 files changed

+25
-25
lines changed

methods/selection.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,6 @@
11
import tensorflow as tf
22

33

4-
def selection_wrapper(data, num_instances, selection_method=None, num_features=None):
5-
if data is None:
6-
raise ValueError('Provide data to make selection.')
7-
8-
if selection_method is None:
9-
raise ValueError('Provide selection method.')
10-
11-
if num_features is None:
12-
data = tf.convert_to_tensor(data)
13-
num_features = data.get_shape().as_list()[-1]
14-
15-
values, indices = selection_method(data, num_instances, num_features)
16-
selected_features = tf.gather(data, indices, axis=1)
17-
return values, selected_features
18-
19-
204
def fisher(data, num_instances: list, top_k_features=2):
215
"""
226
Performs Fisher feature selection method according to the following formula:

methods/selection_wrapper.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,8 @@ def __init__(self, data, num_instances, selection_method=None, num_features=None
1414
data = tf.convert_to_tensor(data)
1515
num_features = data.get_shape().as_list()[-1]
1616

17-
self.values, indices = selection_method(data, num_instances, num_features)
18-
self.selected_features = tf.gather(data, indices, axis=1)
17+
self.values, self.indices = selection_method(data, num_instances, num_features)
18+
self.selected_data = tf.gather(data, self.indices, axis=1)
19+
20+
def select(self, data):
21+
return tf.gather(data, self.indices, axis=1)

run.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
num_features = 100
1515
num_epochs = 1000
16+
eval_every = 10
1617

1718
labels = np.concatenate([np.ones(82, dtype=np.float64), np.zeros(64, dtype=np.float64)])
1819
labels = np.reshape(labels, (-1, 1))
@@ -22,13 +23,16 @@
2223

2324
for fold_id, (train_idxs, test_idxs) in enumerate(skf.split(data, labels.reshape(146))):
2425

25-
data_fold = data[train_idxs, :]
26-
labels_fold = labels[train_idxs]
27-
num_instances = [int(sum(labels_fold == 0)), int(sum(labels_fold == 1))]
26+
data_train_fold = data[train_idxs, :]
27+
labels_train_fold = labels[train_idxs]
28+
num_instances = [int(sum(labels_train_fold == 0)), int(sum(labels_train_fold == 1))]
29+
30+
data_test_fold = data[test_idxs, :]
31+
labels_test_fold = labels[test_idxs]
2832

2933
with tf.Graph().as_default() as graph:
3034

31-
model = ExperimentModel(fisher, num_features, num_instances, None, data_fold)
35+
model = ExperimentModel(fisher, num_features, num_instances, None, data_train_fold)
3236

3337
with tf.Session() as session:
3438

@@ -37,12 +41,21 @@
3741

3842
log_saver = LogSaver('logs', 'fisher_fold{}'.format(fold_id), session.graph)
3943

40-
selected_data = session.run(model.selection_wrapper.selected_features)
44+
train_selected_data = session.run(model.selection_wrapper.selected_data)
45+
test_selected_data = session.run(model.selection_wrapper.select(data_test_fold))
4146

4247
tqdm_iter = tqdm(range(num_epochs), desc='Epochs')
4348

4449
for epoch in tqdm_iter:
45-
feed_dict = {model.clf.x: selected_data, model.clf.y: labels_fold}
50+
feed_dict = {model.clf.x: train_selected_data, model.clf.y: labels_train_fold}
4651
loss, _, summary = session.run([model.clf.loss, model.clf.opt, model.clf.summary_op], feed_dict=feed_dict)
47-
log_saver.log_train(summary, epoch)
52+
53+
if epoch % eval_every == 0:
54+
summary = session.run(model.clf.summary_op, feed_dict=feed_dict)
55+
log_saver.log_train(summary, epoch)
56+
57+
feed_dict = {model.clf.x: test_selected_data, model.clf.y: labels_test_fold}
58+
summary = session.run(model.clf.summary_op, feed_dict=feed_dict)
59+
log_saver.log_test(summary, epoch)
60+
4861
tqdm_iter.set_postfix(loss='{:.2f}'.format(float(loss)), epoch=epoch)

0 commit comments

Comments
 (0)