33using System . IO ;
44using System . Text ;
55using Tensorflow . Keras . Utils ;
6- using Tensorflow . NumPy ;
7- using System . Linq ;
86
97namespace Tensorflow . Keras . Datasets
108{
@@ -41,14 +39,14 @@ namespace Tensorflow.Keras.Datasets
4139 /// `skip_top` limits will be replaced with this character.
4240 /// index_from: int. Index actual words with this index and higher.
4341 /// Returns:
44- /// Tuple of Numpy arrays: `(x_train, y_train ), (x_test, y_test )`.
42+ /// Tuple of Numpy arrays: `(x_train, labels_train ), (x_test, labels_test )`.
4543 ///
4644 /// ** x_train, x_test**: lists of sequences, which are lists of indexes
4745 /// (integers). If the num_words argument was specific, the maximum
4846 /// possible index value is `num_words - 1`. If the `maxlen` argument was
4947 /// specified, the largest possible sequence length is `maxlen`.
5048 ///
51- /// ** y_train, y_test **: lists of integer labels(1 or 0).
49+ /// ** labels_train, labels_test **: lists of integer labels(1 or 0).
5250 ///
5351 /// Raises:
5452 /// ValueError: in case `maxlen` is so low
@@ -63,7 +61,6 @@ namespace Tensorflow.Keras.Datasets
6361 public class Imdb
6462 {
6563 string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" ;
66- string file_name = "imdb.npz" ;
6764 string dest_folder = "imdb" ;
6865
6966 /// <summary>
@@ -78,43 +75,139 @@ public class Imdb
7875 /// <param name="oov_char"></param>
7976 /// <param name="index_from"></param>
8077 /// <returns></returns>
81- public DatasetPass load_data ( string ? path = "imdb.npz" ,
82- int num_words = - 1 ,
78+ public DatasetPass load_data (
79+ string path = "imdb.npz" ,
80+ int ? num_words = null ,
8381 int skip_top = 0 ,
84- int maxlen = - 1 ,
82+ int ? maxlen = null ,
8583 int seed = 113 ,
86- int start_char = 1 ,
87- int oov_char = 2 ,
84+ int ? start_char = 1 ,
85+ int ? oov_char = 2 ,
8886 int index_from = 3 )
8987 {
90- if ( maxlen == - 1 ) throw new InvalidArgumentError ( "maxlen must be assigned." ) ;
91-
92- var dst = path ?? Download ( ) ;
93- var fileBytes = File . ReadAllBytes ( Path . Combine ( dst , file_name ) ) ;
94- var ( y_train , y_test ) = LoadY ( fileBytes ) ;
88+ path = data_utils . get_file (
89+ path ,
90+ origin : Path . Combine ( origin_folder , "imdb.npz" ) ,
91+ file_hash : "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
92+ ) ;
93+ path = Path . Combine ( path , "imdb.npz" ) ;
94+ var fileBytes = File . ReadAllBytes ( path ) ;
9595 var ( x_train , x_test ) = LoadX ( fileBytes ) ;
96-
97- /*var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
98- var x_train_string = new string[lines.Length];
99- var y_train = np.zeros(new int[] { lines.Length }, np.int64);
100- for (int i = 0; i < lines.Length; i++)
96+ var ( labels_train , labels_test ) = LoadY ( fileBytes ) ;
97+ x_test . astype ( np . int32 ) ;
98+ labels_test . astype ( np . int32 ) ;
99+
100+ var indices = np . arange < int > ( len ( x_train ) ) ;
101+ np . random . shuffle ( indices , seed ) ;
102+ x_train = x_train [ indices ] ;
103+ labels_train = labels_train [ indices ] ;
104+
105+ indices = np . arange < int > ( len ( x_test ) ) ;
106+ np . random . shuffle ( indices , seed ) ;
107+ x_test = x_test [ indices ] ;
108+ labels_test = labels_test [ indices ] ;
109+
110+ if ( start_char != null )
111+ {
112+ int [ , ] new_x_train = new int [ x_train . shape [ 0 ] , x_train . shape [ 1 ] + 1 ] ;
113+ for ( var i = 0 ; i < x_train . shape [ 0 ] ; i ++ )
114+ {
115+ new_x_train [ i , 0 ] = ( int ) start_char ;
116+ for ( var j = 0 ; j < x_train . shape [ 1 ] ; j ++ )
117+ {
118+ new_x_train [ i , j + 1 ] = x_train [ i ] [ j ] ;
119+ }
120+ }
121+ int [ , ] new_x_test = new int [ x_test . shape [ 0 ] , x_test . shape [ 1 ] + 1 ] ;
122+ for ( var i = 0 ; i < x_test . shape [ 0 ] ; i ++ )
123+ {
124+ new_x_test [ i , 0 ] = ( int ) start_char ;
125+ for ( var j = 0 ; j < x_test . shape [ 1 ] ; j ++ )
126+ {
127+ new_x_test [ i , j + 1 ] = x_test [ i ] [ j ] ;
128+ }
129+ }
130+ x_train = new NDArray ( new_x_train ) ;
131+ x_test = new NDArray ( new_x_test ) ;
132+ }
133+ else if ( index_from != 0 )
134+ {
135+ for ( var i = 0 ; i < x_train . shape [ 0 ] ; i ++ )
136+ {
137+ for ( var j = 0 ; j < x_train . shape [ 1 ] ; j ++ )
138+ {
139+ if ( x_train [ i , j ] != 0 )
140+ x_train [ i , j ] += index_from ;
141+ }
142+ }
143+ for ( var i = 0 ; i < x_test . shape [ 0 ] ; i ++ )
144+ {
145+ for ( var j = 0 ; j < x_test . shape [ 1 ] ; j ++ )
146+ {
147+ if ( x_test [ i , j ] != 0 )
148+ x_test [ i , j ] += index_from ;
149+ }
150+ }
151+ }
152+
153+ if ( maxlen != null )
101154 {
102- y_train[i] = long.Parse(lines[i].Substring(0, 1));
103- x_train_string[i] = lines[i].Substring(2);
155+ ( x_train , labels_train ) = data_utils . _remove_long_seq ( ( int ) maxlen , x_train , labels_train ) ;
156+ ( x_test , labels_test ) = data_utils . _remove_long_seq ( ( int ) maxlen , x_test , labels_test ) ;
157+ if ( x_train . size == 0 || x_test . size == 0 )
158+ throw new ValueError ( "After filtering for sequences shorter than maxlen=" +
159+ $ "{ maxlen } , no sequence was kept. Increase maxlen.") ;
104160 }
105161
106- var x_train = keras.preprocessing.sequence.pad_sequences(PraseData(x_train_string), maxlen: maxlen);
162+ var xs = np . concatenate ( new [ ] { x_train , x_test } ) ;
163+ var labels = np . concatenate ( new [ ] { labels_train , labels_test } ) ;
107164
108- lines = File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
109- var x_test_string = new string[lines.Length];
110- var y_test = np.zeros(new int[] { lines.Length }, np.int64);
111- for (int i = 0; i < lines.Length; i++)
165+ if ( num_words == null )
112166 {
113- y_test[i] = long.Parse(lines[i].Substring(0, 1));
114- x_test_string[i] = lines[i].Substring(2);
167+ num_words = 0 ;
168+ for ( var i = 0 ; i < xs . shape [ 0 ] ; i ++ )
169+ for ( var j = 0 ; j < xs . shape [ 1 ] ; j ++ )
170+ num_words = max ( ( int ) num_words , ( int ) xs [ i ] [ j ] ) ;
115171 }
116172
117- var x_test = np.array(x_test_string);*/
173+ // by convention, use 2 as OOV word
174+ // reserve 'index_from' (=3 by default) characters:
175+ // 0 (padding), 1 (start), 2 (OOV)
176+ if ( oov_char != null )
177+ {
178+ int [ , ] new_xs = new int [ xs . shape [ 0 ] , xs . shape [ 1 ] ] ;
179+ for ( var i = 0 ; i < xs . shape [ 0 ] ; i ++ )
180+ {
181+ for ( var j = 0 ; j < xs . shape [ 1 ] ; j ++ )
182+ {
183+ if ( ( int ) xs [ i ] [ j ] == 0 || skip_top <= ( int ) xs [ i ] [ j ] && ( int ) xs [ i ] [ j ] < num_words )
184+ new_xs [ i , j ] = ( int ) xs [ i ] [ j ] ;
185+ else
186+ new_xs [ i , j ] = ( int ) oov_char ;
187+ }
188+ }
189+ xs = new NDArray ( new_xs ) ;
190+ }
191+ else
192+ {
193+ int [ , ] new_xs = new int [ xs . shape [ 0 ] , xs . shape [ 1 ] ] ;
194+ for ( var i = 0 ; i < xs . shape [ 0 ] ; i ++ )
195+ {
196+ int k = 0 ;
197+ for ( var j = 0 ; j < xs . shape [ 1 ] ; j ++ )
198+ {
199+ if ( ( int ) xs [ i ] [ j ] == 0 || skip_top <= ( int ) xs [ i ] [ j ] && ( int ) xs [ i ] [ j ] < num_words )
200+ new_xs [ i , k ++ ] = ( int ) xs [ i ] [ j ] ;
201+ }
202+ }
203+ xs = new NDArray ( new_xs ) ;
204+ }
205+
206+ var idx = len ( x_train ) ;
207+ x_train = xs [ $ "0:{ idx } "] ;
208+ x_test = xs [ $ "{ idx } :"] ;
209+ var y_train = labels [ $ "0:{ idx } "] ;
210+ var y_test = labels [ $ "{ idx } :"] ;
118211
119212 return new DatasetPass
120213 {
@@ -125,43 +218,14 @@ public DatasetPass load_data(string? path = "imdb.npz",
125218
126219 ( NDArray , NDArray ) LoadX ( byte [ ] bytes )
127220 {
128- var y = np . Load_Npz < int [ , ] > ( bytes ) ;
129- return ( y [ "x_train.npy" ] , y [ "x_test.npy" ] ) ;
221+ var x = np . Load_Npz < int [ , ] > ( bytes ) ;
222+ return ( x [ "x_train.npy" ] , x [ "x_test.npy" ] ) ;
130223 }
131224
132225 ( NDArray , NDArray ) LoadY ( byte [ ] bytes )
133226 {
134227 var y = np . Load_Npz < long [ ] > ( bytes ) ;
135228 return ( y [ "y_train.npy" ] , y [ "y_test.npy" ] ) ;
136229 }
137-
138- string Download ( )
139- {
140- var dst = Path . Combine ( Path . GetTempPath ( ) , dest_folder ) ;
141- Directory . CreateDirectory ( dst ) ;
142-
143- Web . Download ( origin_folder + file_name , dst , file_name ) ;
144-
145- return dst ;
146- // return Path.Combine(dst, file_name);
147- }
148-
149- protected IEnumerable < int [ ] > PraseData ( string [ ] x )
150- {
151- var data_list = new List < int [ ] > ( ) ;
152- for ( int i = 0 ; i < len ( x ) ; i ++ )
153- {
154- var list_string = x [ i ] ;
155- var cleaned_list_string = list_string . Replace ( "[" , "" ) . Replace ( "]" , "" ) . Replace ( " " , "" ) ;
156- string [ ] number_strings = cleaned_list_string . Split ( ',' ) ;
157- int [ ] numbers = new int [ number_strings . Length ] ;
158- for ( int j = 0 ; j < number_strings . Length ; j ++ )
159- {
160- numbers [ j ] = int . Parse ( number_strings [ j ] ) ;
161- }
162- data_list . Add ( numbers ) ;
163- }
164- return data_list ;
165- }
166230 }
167231}
0 commit comments