
Commit f57a6fe

optimize the time complexity of Imdb dataset loader

1 parent 28c77f5 commit f57a6fe

File tree: 3 files changed, +71 −57 lines

src/TensorFlowNET.Keras/Datasets/Imdb.cs

Lines changed: 56 additions & 45 deletions
```diff
@@ -94,8 +94,6 @@ public DatasetPass load_data(
             var fileBytes = File.ReadAllBytes(path);
             var (x_train, x_test) = LoadX(fileBytes);
             var (labels_train, labels_test) = LoadY(fileBytes);
-            x_test.astype(np.int32);
-            labels_test.astype(np.int32);
 
             var indices = np.arange<int>(len(x_train));
             np.random.shuffle(indices, seed);
@@ -107,100 +105,113 @@ public DatasetPass load_data(
             x_test = x_test[indices];
             labels_test = labels_test[indices];
 
+            var x_train_array = (int[,])x_train.ToMultiDimArray<int>();
+            var x_test_array = (int[,])x_test.ToMultiDimArray<int>();
+            var labels_train_array = (long[])labels_train.ToArray<long>();
+            var labels_test_array = (long[])labels_test.ToArray<long>();
+
             if (start_char != null)
             {
-                int[,] new_x_train = new int[x_train.shape[0], x_train.shape[1] + 1];
-                for (var i = 0; i < x_train.shape[0]; i++)
+                int[,] new_x_train_array = new int[x_train_array.GetLength(0), x_train_array.GetLength(1) + 1];
+                for (var i = 0; i < x_train_array.GetLength(0); i++)
                 {
-                    new_x_train[i, 0] = (int)start_char;
-                    for (var j = 0; j < x_train.shape[1]; j++)
+                    new_x_train_array[i, 0] = (int)start_char;
+                    for (var j = 0; j < x_train_array.GetLength(1); j++)
                     {
-                        new_x_train[i, j + 1] = x_train[i][j];
+                        if (x_train_array[i, j] == 0)
+                            break;
+                        new_x_train_array[i, j + 1] = x_train_array[i, j];
                     }
                 }
-                int[,] new_x_test = new int[x_test.shape[0], x_test.shape[1] + 1];
-                for (var i = 0; i < x_test.shape[0]; i++)
+                int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
+                for (var i = 0; i < x_test_array.GetLength(0); i++)
                 {
-                    new_x_test[i, 0] = (int)start_char;
-                    for (var j = 0; j < x_test.shape[1]; j++)
+                    new_x_test_array[i, 0] = (int)start_char;
+                    for (var j = 0; j < x_test_array.GetLength(1); j++)
                     {
-                        new_x_test[i, j + 1] = x_test[i][j];
+                        if (x_test_array[i, j] == 0)
+                            break;
+                        new_x_test_array[i, j + 1] = x_test_array[i, j];
                     }
                 }
-                x_train = new NDArray(new_x_train);
-                x_test = new NDArray(new_x_test);
+                x_train_array = new_x_train_array;
+                x_test_array = new_x_test_array;
             }
             else if (index_from != 0)
            {
-                for (var i = 0; i < x_train.shape[0]; i++)
+                for (var i = 0; i < x_train_array.GetLength(0); i++)
                 {
-                    for (var j = 0; j < x_train.shape[1]; j++)
+                    for (var j = 0; j < x_train_array.GetLength(1); j++)
                     {
-                        if (x_train[i, j] != 0)
-                            x_train[i, j] += index_from;
+                        if (x_train_array[i, j] == 0)
+                            break;
+                        x_train_array[i, j] += index_from;
                     }
                 }
-                for (var i = 0; i < x_test.shape[0]; i++)
+                for (var i = 0; i < x_test_array.GetLength(0); i++)
                 {
-                    for (var j = 0; j < x_test.shape[1]; j++)
+                    for (var j = 0; j < x_test_array.GetLength(1); j++)
                     {
-                        if (x_test[i, j] != 0)
-                            x_test[i, j] += index_from;
+                        if (x_test_array[i, j] == 0)
+                            break;
+                        x_test[i, j] += index_from;
                     }
                 }
             }
 
-            if (maxlen != null)
+            if (maxlen == null)
             {
-                (x_train, labels_train) = data_utils._remove_long_seq((int)maxlen, x_train, labels_train);
-                (x_test, labels_test) = data_utils._remove_long_seq((int)maxlen, x_test, labels_test);
-                if (x_train.size == 0 || x_test.size == 0)
-                    throw new ValueError("After filtering for sequences shorter than maxlen=" +
-                        $"{maxlen}, no sequence was kept. Increase maxlen.");
+                maxlen = max(x_train_array.GetLength(1), x_test_array.GetLength(1));
             }
+            (x_train, labels_train) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array);
+            (x_test, labels_test) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array);
+            if (x_train.size == 0 || x_test.size == 0)
+                throw new ValueError("After filtering for sequences shorter than maxlen=" +
+                    $"{maxlen}, no sequence was kept. Increase maxlen.");
 
             var xs = np.concatenate(new[] { x_train, x_test });
             var labels = np.concatenate(new[] { labels_train, labels_test });
+            var xs_array = (int[,])xs.ToMultiDimArray<int>();
 
-            if(num_words == null)
+            if (num_words == null)
             {
                 num_words = 0;
-                for (var i = 0; i < xs.shape[0]; i++)
-                    for (var j = 0; j < xs.shape[1]; j++)
-                        num_words = max((int)num_words, (int)xs[i][j]);
+                for (var i = 0; i < xs_array.GetLength(0); i++)
+                    for (var j = 0; j < xs_array.GetLength(1); j++)
+                        num_words = max((int)num_words, (int)xs_array[i, j]);
             }
 
             // by convention, use 2 as OOV word
             // reserve 'index_from' (=3 by default) characters:
             // 0 (padding), 1 (start), 2 (OOV)
             if (oov_char != null)
             {
-                int[,] new_xs = new int[xs.shape[0], xs.shape[1]];
-                for(var i = 0; i < xs.shape[0]; i++)
+                int[,] new_xs_array = new int[xs_array.GetLength(0), xs_array.GetLength(1)];
+                for (var i = 0; i < xs_array.GetLength(0); i++)
                 {
-                    for(var j = 0; j < xs.shape[1]; j++)
+                    for (var j = 0; j < xs_array.GetLength(1); j++)
                     {
-                        if ((int)xs[i][j] == 0 || skip_top <= (int)xs[i][j] && (int)xs[i][j] < num_words)
-                            new_xs[i, j] = (int)xs[i][j];
+                        if (xs_array[i, j] == 0 || skip_top <= xs_array[i, j] && xs_array[i, j] < num_words)
+                            new_xs_array[i, j] = xs_array[i, j];
                         else
-                            new_xs[i, j] = (int)oov_char;
+                            new_xs_array[i, j] = (int)oov_char;
                     }
                 }
-                xs = new NDArray(new_xs);
+                xs = new NDArray(new_xs_array);
             }
             else
             {
-                int[,] new_xs = new int[xs.shape[0], xs.shape[1]];
-                for (var i = 0; i < xs.shape[0]; i++)
+                int[,] new_xs_array = new int[xs_array.GetLength(0), xs_array.GetLength(1)];
+                for (var i = 0; i < xs_array.GetLength(0); i++)
                 {
                     int k = 0;
-                    for (var j = 0; j < xs.shape[1]; j++)
+                    for (var j = 0; j < xs_array.GetLength(1); j++)
                     {
-                        if ((int)xs[i][j] == 0 || skip_top <= (int)xs[i][j] && (int)xs[i][j] < num_words)
-                            new_xs[i, k++] = (int)xs[i][j];
+                        if (xs_array[i, j] == 0 || skip_top <= xs_array[i, j] && xs_array[i, j] < num_words)
+                            new_xs_array[i, k++] = xs_array[i, j];
                     }
                 }
-                xs = new NDArray(new_xs);
+                xs = new NDArray(new_xs_array);
             }
 
             var idx = len(x_train);
```
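The recurring pattern in this diff is to materialize each NDArray once with `ToMultiDimArray<int>()` and then loop over the resulting managed `int[,]`, instead of going through the NDArray indexer on every element; because rows are zero-padded, the inner loops can also stop at the first `0`. Below is a minimal sketch of that before/after pattern, assuming the `Tensorflow.NumPy` NDArray API used above; the method names and the `offset` parameter are illustrative and not part of the commit.

```csharp
using Tensorflow.NumPy;

static class NdArrayLoopSketch
{
    // Before: every read/write goes through the NDArray indexer,
    // which is much more expensive than plain array access.
    public static void OffsetSlow(NDArray x, int offset)
    {
        for (var i = 0; i < x.shape[0]; i++)
            for (var j = 0; j < x.shape[1]; j++)
                if ((int)x[i, j] != 0)
                    x[i, j] += offset;
    }

    // After: copy once into a managed int[,], loop with plain indexing,
    // and break at the first padding zero in each row.
    public static int[,] OffsetFast(NDArray x, int offset)
    {
        var a = (int[,])x.ToMultiDimArray<int>();
        for (var i = 0; i < a.GetLength(0); i++)
        {
            for (var j = 0; j < a.GetLength(1); j++)
            {
                if (a[i, j] == 0)
                    break;
                a[i, j] += offset;
            }
        }
        return a;
    }
}
```

Breaking at the first zero also means the inner loops scale with the actual review length rather than the padded width, which is likely where much of the win comes from for short, heavily padded reviews.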

src/TensorFlowNET.Keras/Utils/data_utils.cs

Lines changed: 9 additions & 7 deletions
```diff
@@ -54,23 +54,25 @@ public static (NDArray, NDArray) _remove_long_seq(int maxlen, NDArray seq, NDAr
 
             */
             List<int[]> new_seq = new List<int[]>();
-            List<int> new_label = new List<int>();
+            List<long> new_label = new List<long>();
 
-            for (var i = 0; i < seq.shape[0]; i++)
+            var seq_array = (int[,])seq.ToMultiDimArray<int>();
+            var label_array = (long[])label.ToArray<long>();
+            for (var i = 0; i < seq_array.GetLength(0); i++)
             {
-                if (maxlen < seq.shape[1] && seq[i][maxlen] != 0)
+                if (maxlen < seq_array.GetLength(1) && seq_array[i,maxlen] != 0)
                     continue;
                 int[] sentence = new int[maxlen];
-                for (var j = 0; j < maxlen && j < seq.shape[1]; j++)
+                for (var j = 0; j < maxlen && j < seq_array.GetLength(1); j++)
                 {
-                    sentence[j] = seq[i, j];
+                    sentence[j] = seq_array[i, j];
                 }
                 new_seq.Add(sentence);
-                new_label.Add(label[i]);
+                new_label.Add(label_array[i]);
             }
 
             int[,] new_seq_array = new int[new_seq.Count, maxlen];
-            int[] new_label_array = new int[new_label.Count];
+            long[] new_label_array = new long[new_label.Count];
 
             for (var i = 0; i < new_seq.Count; i++)
             {
```
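For reference, the filtering rule `_remove_long_seq` applies after this change: a row survives only if its token at column `maxlen` is already `0` (i.e. the padded review holds at most `maxlen` tokens), surviving rows are truncated to `maxlen` columns, and labels are carried as `long`. Here is a small self-contained sketch of that rule on plain arrays; the helper name `RemoveLongSeqSketch` and the sample data are made up for illustration.

```csharp
using System;
using System.Collections.Generic;

static class RemoveLongSeqSketch
{
    public static (int[,] seqs, long[] labels) RemoveLongSeq(int maxlen, int[,] seq, long[] label)
    {
        var kept = new List<int[]>();
        var keptLabels = new List<long>();
        for (var i = 0; i < seq.GetLength(0); i++)
        {
            // A non-zero token at column `maxlen` means the row is longer than maxlen: drop it.
            if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0)
                continue;
            var row = new int[maxlen];
            for (var j = 0; j < maxlen && j < seq.GetLength(1); j++)
                row[j] = seq[i, j];
            kept.Add(row);
            keptLabels.Add(label[i]);
        }
        // Pack the surviving rows into a rectangular array of width maxlen.
        var outSeq = new int[kept.Count, maxlen];
        for (var i = 0; i < kept.Count; i++)
            for (var j = 0; j < maxlen; j++)
                outSeq[i, j] = kept[i][j];
        return (outSeq, keptLabels.ToArray());
    }

    static void Main()
    {
        // Two zero-padded rows of width 4; with maxlen = 3 the first row
        // (3 tokens) is kept and truncated, the second (4 tokens) is dropped.
        var seq = new int[,] { { 5, 6, 7, 0 }, { 5, 6, 7, 8 } };
        var labels = new long[] { 1, 0 };
        var (s, l) = RemoveLongSeq(3, seq, labels);
        Console.WriteLine($"{s.GetLength(0)} row(s) kept, label {l[0]}"); // 1 row(s) kept, label 1
    }
}
```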

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 6 additions & 5 deletions
```diff
@@ -204,7 +204,7 @@ public void GetData()
         {
             var vocab_size = 20000; // Only consider the top 20k words
             var maxlen = 200; // Only consider the first 200 words of each movie review
-            var dataset = keras.datasets.imdb.load_data(num_words: vocab_size);
+            var dataset = keras.datasets.imdb.load_data(num_words: vocab_size, maxlen: maxlen);
             var x_train = dataset.Train.Item1;
             var y_train = dataset.Train.Item2;
             var x_val = dataset.Test.Item1;
@@ -217,16 +217,17 @@ public void GetData()
         }
         IEnumerable<int[]> RemoveZeros(NDArray data)
         {
+            var data_array = (int[,])data.ToMultiDimArray<int>();
             List<int[]> new_data = new List<int[]>();
-            for (var i = 0; i < data.shape[0]; i++)
+            for (var i = 0; i < data_array.GetLength(0); i++)
             {
                 List<int> new_array = new List<int>();
-                for (var j = 0; j < data.shape[1]; j++)
+                for (var j = 0; j < data_array.GetLength(1); j++)
                 {
-                    if (data[i][j] == 0)
+                    if (data_array[i, j] == 0)
                         break;
                     else
-                        new_array.Add((int)data[i][j]);
+                        new_array.Add(data_array[i, j]);
                 }
                 new_data.Add(new_array.ToArray());
             }
```
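The updated test now passes `maxlen` through to the loader, so the length filtering above is exercised end to end. A usage sketch mirroring the test follows; it assumes the usual `using static Tensorflow.KerasApi;` import that exposes the `keras` entry point.

```csharp
using static Tensorflow.KerasApi;

var vocab_size = 20000; // Only consider the top 20k words
var maxlen = 200;       // Cap on review length passed to the loader

var dataset = keras.datasets.imdb.load_data(num_words: vocab_size, maxlen: maxlen);

var x_train = dataset.Train.Item1; // padded int sequences, one row per review
var y_train = dataset.Train.Item2; // labels
var x_val = dataset.Test.Item1;
var y_val = dataset.Test.Item2;
```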
