Skip to content

Commit aac5294

Browse files
committed
init pickle support to np.load object type of npy
1 parent e1ece66 commit aac5294

File tree

8 files changed

+215
-9
lines changed

8 files changed

+215
-9
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics.CodeAnalysis;
4+
using System.Text;
5+
using Razorvine.Pickle;
6+
7+
namespace Tensorflow.NumPy
8+
{
9+
/// <summary>
10+
///
11+
/// </summary>
12+
[SuppressMessage("ReSharper", "InconsistentNaming")]
13+
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
14+
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")]
15+
class DtypeConstructor : IObjectConstructor
16+
{
17+
public object construct(object[] args)
18+
{
19+
Console.WriteLine("DtypeConstructor");
20+
Console.WriteLine(args.Length);
21+
for (int i = 0; i < args.Length; i++)
22+
{
23+
Console.WriteLine(args[i]);
24+
}
25+
return new demo();
26+
}
27+
}
28+
class demo
29+
{
30+
public void __setstate__(object[] args)
31+
{
32+
Console.WriteLine("demo __setstate__");
33+
Console.WriteLine(args.Length);
34+
for (int i = 0; i < args.Length; i++)
35+
{
36+
Console.WriteLine(args[i]);
37+
}
38+
}
39+
}
40+
}

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Linq;
55
using System.Text;
66
using Tensorflow.Util;
7+
using Razorvine.Pickle;
78
using static Tensorflow.Binding;
89

910
namespace Tensorflow.NumPy
@@ -93,10 +94,25 @@ Array ReadValueMatrix(BinaryReader reader, Array matrix, int bytes, Type type, i
9394

9495
var buffer = reader.ReadBytes(bytes * total);
9596
System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length);
96-
9797
return matrix;
9898
}
9999

100+
NDArray ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape)
101+
{
102+
//int data = reader.ReadByte();
103+
//Console.WriteLine(data);
104+
//Console.WriteLine(reader.ReadByte());
105+
Stream stream = reader.BaseStream;
106+
Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor());
107+
Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor());
108+
109+
var unpickler = new Unpickler();
110+
111+
NDArray result = (NDArray) unpickler.load(stream);
112+
Console.WriteLine(result.dims);
113+
return result;
114+
}
115+
100116
public (NDArray, NDArray) meshgrid<T>(T[] array, bool copy = true, bool sparse = false)
101117
{
102118
var tensors = array_ops.meshgrid(array, copy: copy, sparse: sparse);

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,28 @@ public Array LoadMatrix(Stream stream)
2727
Array matrix = Array.CreateInstance(type, shape);
2828

2929
//if (type == typeof(String))
30-
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
30+
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
31+
NDArray res = ReadObjectMatrix(reader, matrix, shape);
32+
Console.WriteLine("LoadMatrix");
33+
Console.WriteLine(res.dims[0]);
34+
Console.WriteLine((int)res[0][0]);
35+
Console.WriteLine(res.dims[1]);
36+
//if (type == typeof(Object))
37+
//{
38+
39+
//}
40+
//else
3141
return ReadValueMatrix(reader, matrix, bytes, type, shape);
3242
}
43+
3344
}
3445

3546
public T Load<T>(Stream stream)
3647
where T : class,
3748
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
3849
{
3950
// if (typeof(T).IsArray && (typeof(T).GetElementType().IsArray || typeof(T).GetElementType() == typeof(string)))
40-
// return LoadJagged(stream) as T;
51+
// return LoadJagged(stream) as T;
4152
return LoadMatrix(stream) as T;
4253
}
4354

@@ -48,7 +59,7 @@ bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape
4859
shape = null;
4960

5061
// The first 6 bytes are a magic string: exactly "x93NUMPY"
51-
if (reader.ReadChar() != 63) return false;
62+
if (reader.ReadByte() != 0x93) return false;
5263
if (reader.ReadChar() != 'N') return false;
5364
if (reader.ReadChar() != 'U') return false;
5465
if (reader.ReadChar() != 'M') return false;
@@ -64,6 +75,7 @@ bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape
6475
ushort len = reader.ReadUInt16();
6576

6677
string header = new String(reader.ReadChars(len));
78+
Console.WriteLine(header);
6779
string mark = "'descr': '";
6880
int s = header.IndexOf(mark) + mark.Length;
6981
int e = header.IndexOf("'", s + 1);
@@ -93,7 +105,7 @@ bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape
93105
Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
94106
{
95107
isLittleEndian = IsLittleEndian(dtype);
96-
bytes = Int32.Parse(dtype.Substring(2));
108+
bytes = dtype.Length > 2 ? Int32.Parse(dtype.Substring(2)) : 0;
97109

98110
string typeCode = dtype.Substring(1);
99111

@@ -121,6 +133,8 @@ Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
121133
return typeof(Double);
122134
if (typeCode.StartsWith("S"))
123135
return typeof(String);
136+
if (typeCode == "O")
137+
return typeof(Object);
124138

125139
throw new NotSupportedException();
126140
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics.CodeAnalysis;
4+
using System.Text;
5+
using Razorvine.Pickle;
6+
7+
namespace Tensorflow.NumPy
8+
{
9+
/// <summary>
10+
/// Creates multiarrays of objects. Returns a primitive type multiarray such as int[][] if
11+
/// the objects are ints, etc.
12+
/// </summary>
13+
[SuppressMessage("ReSharper", "InconsistentNaming")]
14+
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
15+
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")]
16+
public class MultiArrayConstructor : IObjectConstructor
17+
{
18+
public object construct(object[] args)
19+
{
20+
//Console.WriteLine(args.Length);
21+
//for (int i = 0; i < args.Length; i++)
22+
//{
23+
// Console.WriteLine(args[i]);
24+
//}
25+
Console.WriteLine("MultiArrayConstructor");
26+
27+
var arg1 = (Object[])args[1];
28+
var dims = new int[arg1.Length];
29+
for (var i = 0; i < arg1.Length; i++)
30+
{
31+
dims[i] = (int)arg1[i];
32+
}
33+
34+
var dtype = TF_DataType.DtInvalid;
35+
switch (args[2])
36+
{
37+
case "b": dtype = TF_DataType.DtUint8Ref; break;
38+
default: throw new NotImplementedException("cannot parse" + args[2]);
39+
}
40+
return new NDArray(new Shape(dims), dtype);
41+
42+
}
43+
}
44+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.NumPy
6+
{
7+
public partial class NDArray
8+
{
9+
public void __setstate__(object[] args)
10+
{
11+
Console.WriteLine("NDArray __setstate__");
12+
Console.WriteLine(args.Length);
13+
for (int i = 0; i < args.Length; i++)
14+
{
15+
Console.WriteLine(args[i]);
16+
}
17+
}
18+
}
19+
}

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ https://tensorflownet.readthedocs.io</Description>
112112
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
113113
<PackageReference Include="OneOf" Version="3.0.223" />
114114
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
115+
<PackageReference Include="Razorvine.Pickle" Version="1.4.0" />
115116
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
116117
</ItemGroup>
117118
</Project>

src/TensorFlowNET.Keras/Datasets/Imdb.cs

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,73 @@
55
using Tensorflow.Keras.Utils;
66
using Tensorflow.NumPy;
77
using System.Linq;
8+
using Google.Protobuf.Collections;
9+
using Microsoft.VisualBasic;
10+
using OneOf.Types;
11+
using static HDF.PInvoke.H5;
12+
using System.Data;
13+
using System.Reflection.Emit;
14+
using System.Xml.Linq;
815

916
namespace Tensorflow.Keras.Datasets
1017
{
1118
/// <summary>
1219
/// This is a dataset of 25,000 movies reviews from IMDB, labeled by sentiment
1320
/// (positive/negative). Reviews have been preprocessed, and each review is
1421
/// encoded as a list of word indexes(integers).
22+
/// For convenience, words are indexed by overall frequency in the dataset,
23+
/// so that for instance the integer "3" encodes the 3rd most frequent word in
24+
/// the data.This allows for quick filtering operations such as:
25+
/// "only consider the top 10,000 most
26+
/// common words, but eliminate the top 20 most common words".
27+
/// As a convention, "0" does not stand for a specific word, but instead is used
28+
/// to encode the pad token.
29+
/// Args:
30+
/// path: where to cache the data (relative to %TEMP%/imdb/imdb.npz).
31+
/// num_words: integer or None.Words are
32+
/// ranked by how often they occur(in the training set) and only
33+
/// the `num_words` most frequent words are kept.Any less frequent word
34+
/// will appear as `oov_char` value in the sequence data.If None,
35+
/// all words are kept.Defaults to `None`.
36+
/// skip_top: skip the top N most frequently occurring words
37+
/// (which may not be informative). These words will appear as
38+
/// `oov_char` value in the dataset.When 0, no words are
39+
/// skipped. Defaults to `0`.
40+
/// maxlen: int or None.Maximum sequence length.
41+
/// Any longer sequence will be truncated. None, means no truncation.
42+
/// Defaults to `None`.
43+
/// seed: int. Seed for reproducible data shuffling.
44+
/// start_char: int. The start of a sequence will be marked with this
45+
/// character. 0 is usually the padding character. Defaults to `1`.
46+
/// oov_char: int. The out-of-vocabulary character.
47+
/// Words that were cut out because of the `num_words` or
48+
/// `skip_top` limits will be replaced with this character.
49+
/// index_from: int. Index actual words with this index and higher.
50+
/// Returns:
51+
/// Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
52+
///
53+
/// ** x_train, x_test**: lists of sequences, which are lists of indexes
54+
/// (integers). If the num_words argument was specific, the maximum
55+
/// possible index value is `num_words - 1`. If the `maxlen` argument was
56+
/// specified, the largest possible sequence length is `maxlen`.
57+
///
58+
/// ** y_train, y_test**: lists of integer labels(1 or 0).
59+
///
60+
/// Raises:
61+
/// ValueError: in case `maxlen` is so low
62+
/// that no input sequence could be kept.
63+
/// Note that the 'out of vocabulary' character is only used for
64+
/// words that were present in the training set but are not included
65+
/// because they're not making the `num_words` cut here.
66+
/// Words that were not seen in the training set but are in the test set
67+
/// have simply been skipped.
1568
/// </summary>
69+
/// """Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
1670
public class Imdb
1771
{
1872
string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
1973
string file_name = "imdb.npz";
2074
string dest_folder = "imdb";
21-
2275
/// <summary>
2376
/// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
2477
/// </summary>
@@ -41,8 +94,10 @@ public DatasetPass load_data(string path = "imdb.npz",
4194
int index_from = 3)
4295
{
4396
var dst = Download();
44-
45-
var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
97+
var fileBytes = File.ReadAllBytes(Path.Combine(dst, file_name));
98+
var (x_train, x_test) = LoadX(fileBytes);
99+
var (y_train, y_test) = LoadY(fileBytes);
100+
/*var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
46101
var x_train_string = new string[lines.Length];
47102
var y_train = np.zeros(new int[] { lines.Length }, np.int64);
48103
for (int i = 0; i < lines.Length; i++)
@@ -62,7 +117,7 @@ public DatasetPass load_data(string path = "imdb.npz",
62117
x_test_string[i] = lines[i].Substring(2);
63118
}
64119
65-
var x_test = np.array(x_test_string);
120+
var x_test = np.array(x_test_string);*/
66121

67122
return new DatasetPass
68123
{

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3+
using System.Collections.Generic;
34
using System.Linq;
45
using static Tensorflow.Binding;
6+
using static Tensorflow.KerasApi;
57

68
namespace TensorFlowNET.UnitTest.Dataset
79
{
@@ -195,5 +197,20 @@ public void Shuffle()
195197

196198
Assert.IsFalse(allEqual);
197199
}
200+
[TestMethod]
201+
public void GetData()
202+
{
203+
var vocab_size = 20000; // Only consider the top 20k words
204+
var maxlen = 200; // Only consider the first 200 words of each movie review
205+
var dataset = keras.datasets.imdb.load_data(num_words: vocab_size);
206+
var x_train = dataset.Train.Item1;
207+
var y_train = dataset.Train.Item2;
208+
var x_val = dataset.Test.Item1;
209+
var y_val = dataset.Test.Item2;
210+
print(len(x_train) + "Training sequences");
211+
print(len(x_val) + "Validation sequences");
212+
x_train = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_train, maxlen: maxlen);
213+
x_val = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_val, maxlen: maxlen);
214+
}
198215
}
199216
}

0 commit comments

Comments
 (0)