Skip to content

Commit 28c77f5

Browse files
committed
implement Imdb dataset loader
1 parent 10f6819 commit 28c77f5

File tree

4 files changed

+198
-67
lines changed

4 files changed

+198
-67
lines changed

src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ public class RandomizedImpl
1414
public NDArray permutation(NDArray x) => new NDArray(random_ops.random_shuffle(x));
1515

1616
[AutoNumPy]
17-
public void shuffle(NDArray x)
17+
public void shuffle(NDArray x, int? seed = null)
1818
{
19-
var y = random_ops.random_shuffle(x);
19+
var y = random_ops.random_shuffle(x, seed);
2020
Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize);
2121
}
2222

src/TensorFlowNET.Keras/Datasets/Imdb.cs

Lines changed: 125 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
using System.IO;
44
using System.Text;
55
using Tensorflow.Keras.Utils;
6-
using Tensorflow.NumPy;
7-
using System.Linq;
86

97
namespace Tensorflow.Keras.Datasets
108
{
@@ -41,14 +39,14 @@ namespace Tensorflow.Keras.Datasets
4139
/// `skip_top` limits will be replaced with this character.
4240
/// index_from: int. Index actual words with this index and higher.
4341
/// Returns:
44-
/// Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
42+
/// Tuple of Numpy arrays: `(x_train, labels_train), (x_test, labels_test)`.
4543
///
4644
/// ** x_train, x_test**: lists of sequences, which are lists of indexes
4745
/// (integers). If the num_words argument was specific, the maximum
4846
/// possible index value is `num_words - 1`. If the `maxlen` argument was
4947
/// specified, the largest possible sequence length is `maxlen`.
5048
///
51-
/// ** y_train, y_test**: lists of integer labels(1 or 0).
49+
/// ** labels_train, labels_test**: lists of integer labels (1 or 0).
5250
///
5351
/// Raises:
5452
/// ValueError: in case `maxlen` is so low
@@ -63,7 +61,6 @@ namespace Tensorflow.Keras.Datasets
6361
public class Imdb
6462
{
6563
string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
66-
string file_name = "imdb.npz";
6764
string dest_folder = "imdb";
6865

6966
/// <summary>
@@ -78,43 +75,139 @@ public class Imdb
7875
/// <param name="oov_char"></param>
7976
/// <param name="index_from"></param>
8077
/// <returns></returns>
81-
public DatasetPass load_data(string? path = "imdb.npz",
82-
int num_words = -1,
78+
public DatasetPass load_data(
79+
string path = "imdb.npz",
80+
int? num_words = null,
8381
int skip_top = 0,
84-
int maxlen = -1,
82+
int? maxlen = null,
8583
int seed = 113,
86-
int start_char = 1,
87-
int oov_char= 2,
84+
int? start_char = 1,
85+
int? oov_char = 2,
8886
int index_from = 3)
8987
{
90-
if (maxlen == -1) throw new InvalidArgumentError("maxlen must be assigned.");
91-
92-
var dst = path ?? Download();
93-
var fileBytes = File.ReadAllBytes(Path.Combine(dst, file_name));
94-
var (y_train, y_test) = LoadY(fileBytes);
88+
path = data_utils.get_file(
89+
path,
90+
origin: Path.Combine(origin_folder, "imdb.npz"),
91+
file_hash: "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
92+
);
93+
path = Path.Combine(path, "imdb.npz");
94+
var fileBytes = File.ReadAllBytes(path);
9595
var (x_train, x_test) = LoadX(fileBytes);
96-
97-
/*var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
98-
var x_train_string = new string[lines.Length];
99-
var y_train = np.zeros(new int[] { lines.Length }, np.int64);
100-
for (int i = 0; i < lines.Length; i++)
96+
var (labels_train, labels_test) = LoadY(fileBytes);
97+
x_test.astype(np.int32);
98+
labels_test.astype(np.int32);
99+
100+
var indices = np.arange<int>(len(x_train));
101+
np.random.shuffle(indices, seed);
102+
x_train = x_train[indices];
103+
labels_train = labels_train[indices];
104+
105+
indices = np.arange<int>(len(x_test));
106+
np.random.shuffle(indices, seed);
107+
x_test = x_test[indices];
108+
labels_test = labels_test[indices];
109+
110+
if (start_char != null)
111+
{
112+
int[,] new_x_train = new int[x_train.shape[0], x_train.shape[1] + 1];
113+
for (var i = 0; i < x_train.shape[0]; i++)
114+
{
115+
new_x_train[i, 0] = (int)start_char;
116+
for (var j = 0; j < x_train.shape[1]; j++)
117+
{
118+
new_x_train[i, j + 1] = x_train[i][j];
119+
}
120+
}
121+
int[,] new_x_test = new int[x_test.shape[0], x_test.shape[1] + 1];
122+
for (var i = 0; i < x_test.shape[0]; i++)
123+
{
124+
new_x_test[i, 0] = (int)start_char;
125+
for (var j = 0; j < x_test.shape[1]; j++)
126+
{
127+
new_x_test[i, j + 1] = x_test[i][j];
128+
}
129+
}
130+
x_train = new NDArray(new_x_train);
131+
x_test = new NDArray(new_x_test);
132+
}
133+
else if (index_from != 0)
134+
{
135+
for (var i = 0; i < x_train.shape[0]; i++)
136+
{
137+
for (var j = 0; j < x_train.shape[1]; j++)
138+
{
139+
if (x_train[i, j] != 0)
140+
x_train[i, j] += index_from;
141+
}
142+
}
143+
for (var i = 0; i < x_test.shape[0]; i++)
144+
{
145+
for (var j = 0; j < x_test.shape[1]; j++)
146+
{
147+
if (x_test[i, j] != 0)
148+
x_test[i, j] += index_from;
149+
}
150+
}
151+
}
152+
153+
if (maxlen != null)
101154
{
102-
y_train[i] = long.Parse(lines[i].Substring(0, 1));
103-
x_train_string[i] = lines[i].Substring(2);
155+
(x_train, labels_train) = data_utils._remove_long_seq((int)maxlen, x_train, labels_train);
156+
(x_test, labels_test) = data_utils._remove_long_seq((int)maxlen, x_test, labels_test);
157+
if (x_train.size == 0 || x_test.size == 0)
158+
throw new ValueError("After filtering for sequences shorter than maxlen=" +
159+
$"{maxlen}, no sequence was kept. Increase maxlen.");
104160
}
105161

106-
var x_train = keras.preprocessing.sequence.pad_sequences(PraseData(x_train_string), maxlen: maxlen);
162+
var xs = np.concatenate(new[] { x_train, x_test });
163+
var labels = np.concatenate(new[] { labels_train, labels_test });
107164

108-
lines = File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
109-
var x_test_string = new string[lines.Length];
110-
var y_test = np.zeros(new int[] { lines.Length }, np.int64);
111-
for (int i = 0; i < lines.Length; i++)
165+
if(num_words == null)
112166
{
113-
y_test[i] = long.Parse(lines[i].Substring(0, 1));
114-
x_test_string[i] = lines[i].Substring(2);
167+
num_words = 0;
168+
for (var i = 0; i < xs.shape[0]; i++)
169+
for (var j = 0; j < xs.shape[1]; j++)
170+
num_words = max((int)num_words, (int)xs[i][j]);
115171
}
116172

117-
var x_test = np.array(x_test_string);*/
173+
// by convention, use 2 as OOV word
174+
// reserve 'index_from' (=3 by default) characters:
175+
// 0 (padding), 1 (start), 2 (OOV)
176+
if (oov_char != null)
177+
{
178+
int[,] new_xs = new int[xs.shape[0], xs.shape[1]];
179+
for(var i = 0; i < xs.shape[0]; i++)
180+
{
181+
for(var j = 0; j < xs.shape[1]; j++)
182+
{
183+
if ((int)xs[i][j] == 0 || skip_top <= (int)xs[i][j] && (int)xs[i][j] < num_words)
184+
new_xs[i, j] = (int)xs[i][j];
185+
else
186+
new_xs[i, j] = (int)oov_char;
187+
}
188+
}
189+
xs = new NDArray(new_xs);
190+
}
191+
else
192+
{
193+
int[,] new_xs = new int[xs.shape[0], xs.shape[1]];
194+
for (var i = 0; i < xs.shape[0]; i++)
195+
{
196+
int k = 0;
197+
for (var j = 0; j < xs.shape[1]; j++)
198+
{
199+
if ((int)xs[i][j] == 0 || skip_top <= (int)xs[i][j] && (int)xs[i][j] < num_words)
200+
new_xs[i, k++] = (int)xs[i][j];
201+
}
202+
}
203+
xs = new NDArray(new_xs);
204+
}
205+
206+
var idx = len(x_train);
207+
x_train = xs[$"0:{idx}"];
208+
x_test = xs[$"{idx}:"];
209+
var y_train = labels[$"0:{idx}"];
210+
var y_test = labels[$"{idx}:"];
118211

119212
return new DatasetPass
120213
{
@@ -125,43 +218,14 @@ public DatasetPass load_data(string? path = "imdb.npz",
125218

126219
(NDArray, NDArray) LoadX(byte[] bytes)
127220
{
128-
var y = np.Load_Npz<int[,]>(bytes);
129-
return (y["x_train.npy"], y["x_test.npy"]);
221+
var x = np.Load_Npz<int[,]>(bytes);
222+
return (x["x_train.npy"], x["x_test.npy"]);
130223
}
131224

132225
(NDArray, NDArray) LoadY(byte[] bytes)
133226
{
134227
var y = np.Load_Npz<long[]>(bytes);
135228
return (y["y_train.npy"], y["y_test.npy"]);
136229
}
137-
138-
string Download()
139-
{
140-
var dst = Path.Combine(Path.GetTempPath(), dest_folder);
141-
Directory.CreateDirectory(dst);
142-
143-
Web.Download(origin_folder + file_name, dst, file_name);
144-
145-
return dst;
146-
// return Path.Combine(dst, file_name);
147-
}
148-
149-
protected IEnumerable<int[]> PraseData(string[] x)
150-
{
151-
var data_list = new List<int[]>();
152-
for (int i = 0; i < len(x); i++)
153-
{
154-
var list_string = x[i];
155-
var cleaned_list_string = list_string.Replace("[", "").Replace("]", "").Replace(" ", "");
156-
string[] number_strings = cleaned_list_string.Split(',');
157-
int[] numbers = new int[number_strings.Length];
158-
for (int j = 0; j < number_strings.Length; j++)
159-
{
160-
numbers[j] = int.Parse(number_strings[j]);
161-
}
162-
data_list.Add(numbers);
163-
}
164-
return data_list;
165-
}
166230
}
167231
}

src/TensorFlowNET.Keras/Utils/data_utils.cs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,52 @@ public static string get_file(string fname, string origin,
3939

4040
return datadir;
4141
}
42+
43+
public static (NDArray, NDArray) _remove_long_seq(int maxlen, NDArray seq, NDArray label)
44+
{
45+
/*Removes sequences that exceed the maximum length.
46+
47+
Args:
48+
maxlen: Int, maximum length of the output sequences.
49+
seq: List of lists, where each sublist is a sequence.
50+
label: List where each element is an integer.
51+
52+
Returns:
53+
new_seq, new_label: shortened lists for `seq` and `label`.
54+
55+
*/
56+
List<int[]> new_seq = new List<int[]>();
57+
List<int> new_label = new List<int>();
58+
59+
for (var i = 0; i < seq.shape[0]; i++)
60+
{
61+
if (maxlen < seq.shape[1] && seq[i][maxlen] != 0)
62+
continue;
63+
int[] sentence = new int[maxlen];
64+
for (var j = 0; j < maxlen && j < seq.shape[1]; j++)
65+
{
66+
sentence[j] = seq[i, j];
67+
}
68+
new_seq.Add(sentence);
69+
new_label.Add(label[i]);
70+
}
71+
72+
int[,] new_seq_array = new int[new_seq.Count, maxlen];
73+
int[] new_label_array = new int[new_label.Count];
74+
75+
for (var i = 0; i < new_seq.Count; i++)
76+
{
77+
for (var j = 0; j < maxlen; j++)
78+
{
79+
new_seq_array[i, j] = new_seq[i][j];
80+
}
81+
}
82+
83+
for (var i = 0; i < new_label.Count; i++)
84+
{
85+
new_label_array[i] = new_label[i];
86+
}
87+
return (new_seq_array, new_label_array);
88+
}
4289
}
4390
}

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3+
using System.Collections.Generic;
34
using System.Linq;
5+
using Tensorflow.NumPy;
46
using static Tensorflow.Binding;
57
using static Tensorflow.KerasApi;
68

@@ -207,10 +209,28 @@ public void GetData()
207209
var y_train = dataset.Train.Item2;
208210
var x_val = dataset.Test.Item1;
209211
var y_val = dataset.Test.Item2;
210-
print(len(x_train) + "Training sequences");
211-
print(len(x_val) + "Validation sequences");
212-
//x_train = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_train, maxlen: maxlen);
213-
//x_val = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_val, maxlen: maxlen);
212+
213+
x_train = keras.preprocessing.sequence.pad_sequences(RemoveZeros(x_train), maxlen: maxlen);
214+
x_val = keras.preprocessing.sequence.pad_sequences(RemoveZeros(x_val), maxlen: maxlen);
215+
print(len(x_train) + " Training sequences");
216+
print(len(x_val) + " Validation sequences");
217+
}
218+
IEnumerable<int[]> RemoveZeros(NDArray data)
219+
{
220+
List<int[]> new_data = new List<int[]>();
221+
for (var i = 0; i < data.shape[0]; i++)
222+
{
223+
List<int> new_array = new List<int>();
224+
for (var j = 0; j < data.shape[1]; j++)
225+
{
226+
if (data[i][j] == 0)
227+
break;
228+
else
229+
new_array.Add((int)data[i][j]);
230+
}
231+
new_data.Add(new_array.ToArray());
232+
}
233+
return new_data;
214234
}
215235
}
216236
}

0 commit comments

Comments
 (0)