Skip to content

Commit 9d10daf

Browse files
committed
add reconstruction and setstate of NDArray for loading pickled npy file.
1 parent aac5294 commit 9d10daf

File tree

8 files changed

+178
-53
lines changed

8 files changed

+178
-53
lines changed

src/TensorFlowNET.Core/NumPy/DtypeConstructor.cs

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,50 @@ class DtypeConstructor : IObjectConstructor
1616
{
1717
/// <summary>
/// Maps the NumPy type code passed to a pickled <c>numpy.dtype</c> constructor call
/// onto a <see cref="TF_DataType"/> and wraps it.
/// </summary>
/// <param name="args">
/// Constructor arguments; only <c>args[0]</c> is read and it must be the type code
/// string (for example "i4" or "f8").
/// </param>
/// <returns>A <see cref="TF_DataType_Warpper"/> holding the mapped data type.</returns>
/// <exception cref="NotSupportedException">The type code has no mapping.</exception>
public object construct(object[] args)
{
    var typeCode = (string)args[0];

    // Type codes are machine-readable identifiers, so the prefix checks are ordinal
    // rather than culture-sensitive.
    TF_DataType dtype = typeCode switch
    {
        "b1" => np.@bool,
        "i1" => np.@byte,
        "i2" => np.int16,
        "i4" => np.int32,
        "i8" => np.int64,
        "u1" => np.ubyte,
        "u2" => np.uint16,
        "u4" => np.uint32,
        "u8" => np.uint64,
        "f4" => np.float32,
        "f8" => np.float64,
        _ when typeCode.StartsWith("S", StringComparison.Ordinal) => np.@string,
        _ when typeCode.StartsWith("O", StringComparison.Ordinal) => np.@object,
        _ => throw new NotSupportedException($"Unsupported NumPy type code: {typeCode}")
    };

    return new TF_DataType_Warpper(dtype);
}
2751
}
28-
class demo
52+
/// <summary>
/// Reference-type holder for a single <see cref="TF_DataType"/> value that converts
/// implicitly back to that value.
/// </summary>
public class TF_DataType_Warpper
{
    TF_DataType dtype { get; set; }

    /// <summary>Stores <paramref name="dtype"/> in the wrapper.</summary>
    public TF_DataType_Warpper(TF_DataType dtype) => this.dtype = dtype;

    /// <summary>
    /// Accepts a state array and ignores it; the wrapped data type is left untouched.
    /// </summary>
    // NOTE(review): presumably invoked by the unpickler after construction - confirm.
    public void __setstate__(object[] args)
    {
    }

    /// <summary>Unwraps the stored <see cref="TF_DataType"/>.</summary>
    public static implicit operator TF_DataType(TF_DataType_Warpper dtypeWarpper)
        => dtypeWarpper.dtype;
}
4065
}

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,6 @@ Array ReadValueMatrix(BinaryReader reader, Array matrix, int bytes, Type type, i
9999

100100
NDArray ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape)
101101
{
102-
//int data = reader.ReadByte();
103-
//Console.WriteLine(data);
104-
//Console.WriteLine(reader.ReadByte());
105102
Stream stream = reader.BaseStream;
106103
Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor());
107104
Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor());

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,17 @@ public Array LoadMatrix(Stream stream)
2828

2929
//if (type == typeof(String))
3030
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
31-
NDArray res = ReadObjectMatrix(reader, matrix, shape);
32-
Console.WriteLine("LoadMatrix");
33-
Console.WriteLine(res.dims[0]);
34-
Console.WriteLine((int)res[0][0]);
35-
Console.WriteLine(res.dims[1]);
36-
//if (type == typeof(Object))
37-
//{
38-
39-
//}
40-
//else
41-
return ReadValueMatrix(reader, matrix, bytes, type, shape);
31+
32+
if (type == typeof(Object))
33+
{
34+
NDArray res = ReadObjectMatrix(reader, matrix, shape);
35+
// res = res.reconstructedNDArray;
36+
return res.reconstructedArray;
37+
}
38+
else
39+
{
40+
return ReadValueMatrix(reader, matrix, bytes, type, shape);
41+
}
4242
}
4343

4444
}
@@ -133,7 +133,7 @@ Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
133133
return typeof(Double);
134134
if (typeCode.StartsWith("S"))
135135
return typeof(String);
136-
if (typeCode == "O")
136+
if (typeCode.StartsWith("O"))
137137
return typeof(Object);
138138

139139
throw new NotSupportedException();

src/TensorFlowNET.Core/NumPy/MultiArrayConstructor.cs

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Diagnostics.CodeAnalysis;
44
using System.Text;
55
using Razorvine.Pickle;
6+
using Razorvine.Pickle.Objects;
67

78
namespace Tensorflow.NumPy
89
{
@@ -17,28 +18,36 @@ public class MultiArrayConstructor : IObjectConstructor
1718
{
1819
/// <summary>
/// Handles the pickled <c>numpy.core.multiarray._reconstruct</c> call by creating an
/// <see cref="NDArray"/> with the requested shape and data type.
/// </summary>
/// <param name="args">
/// Exactly three items: the array class descriptor (must be <c>numpy.ndarray</c>),
/// the shape as an <c>object[]</c> of ints, and the type code as a string or
/// UTF-8 encoded <c>byte[]</c>.
/// </param>
/// <returns>A new <see cref="NDArray"/> of the given shape and data type.</returns>
/// <exception cref="InvalidArgumentError">
/// The argument count is not three, or the type code is neither a string nor a byte[].
/// </exception>
/// <exception cref="RuntimeError">The first argument is not <c>numpy.ndarray</c>.</exception>
/// <exception cref="NotImplementedException">The type code has no mapping.</exception>
public object construct(object[] args)
{
    if (args.Length != 3)
        throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. Given {args.Length} arguments.");

    var types = (ClassDictConstructor)args[0];
    if (types.module != "numpy" || types.name != "ndarray")
        throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray");

    var rawDims = (object[])args[1];
    var dims = new int[rawDims.Length];
    for (var i = 0; i < rawDims.Length; i++)
    {
        dims[i] = (int)rawDims[i];
    }
    var shape = new Shape(dims);

    // The type code arrives either as a string or as raw bytes; a null or any other
    // payload is rejected explicitly instead of failing inside GetType() or a cast.
    string identifier = args[2] switch
    {
        string text => text,
        byte[] raw => Encoding.UTF8.GetString(raw),
        _ => throw new InvalidArgumentError("Invalid type code in MultiArrayConstructor._reconstruct. Expected a string or a byte array.")
    };

    // Report the decoded identifier so that a byte[] type code does not show up
    // as "System.Byte[]" in the error message.
    TF_DataType dtype = identifier switch
    {
        "u" => np.uint32,
        "c" => np.complex_,
        "f" => np.float32,
        "b" => np.@bool,
        _ => throw new NotImplementedException($"Unsupported data type: {identifier}")
    };

    return new NDArray(shape, dtype);
}
4352
}
4453
}
Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,108 @@
1-
using System;
1+
using Newtonsoft.Json.Linq;
2+
using Serilog.Debugging;
3+
using System;
4+
using System.Collections;
25
using System.Collections.Generic;
36
using System.Text;
47

58
namespace Tensorflow.NumPy
69
{
710
public partial class NDArray
{
    /// <summary>
    /// NDArray built from the pickled payload by <see cref="__setstate__"/>;
    /// stays unset until that call succeeds.
    /// </summary>
    public NDArray reconstructedNDArray { get; set; }

    /// <summary>
    /// Managed array (<c>int[]</c> or <c>int[,]</c>) built from the pickled payload by
    /// <see cref="__setstate__"/>; stays unset until that call succeeds.
    /// </summary>
    public Array reconstructedArray { get; set; }

    /// <summary>
    /// Restores array contents from a pickled NumPy state tuple and stores the result
    /// in <see cref="reconstructedArray"/> and <see cref="reconstructedNDArray"/>.
    /// </summary>
    /// <param name="args">
    /// Exactly five items: version (int), shape (<c>object[]</c> of ints), data type
    /// (<see cref="TF_DataType_Warpper"/>), Fortran-order flag (bool) and the data payload.
    /// </param>
    /// <exception cref="InvalidArgumentError">
    /// The argument count is not five, or the Fortran-order flag is set.
    /// </exception>
    /// <exception cref="ValueError">The version is outside 0..4.</exception>
    /// <exception cref="NotImplementedException">The payload is not an <see cref="ArrayList"/>.</exception>
    public void __setstate__(object[] args)
    {
        if (args.Length != 5)
            throw new InvalidArgumentError($"Invalid number of arguments in NDArray.__setstate__. Expected five arguments. Given {args.Length} arguments.");

        /*
         * If we ever need another pickle format, increment the version
         * number. But we should still be able to handle the old versions.
         */
        var version = (int)args[0];
        if (version < 0 || version > 4)
            throw new ValueError($"can't handle version {version} of numpy.ndarray pickle");

        // Shape and data type are parsed so malformed state fails here, but they are
        // not applied yet: the result is derived from the payload alone.
        var rawDims = (object[])args[1];
        var dims = new int[rawDims.Length];
        for (var i = 0; i < rawDims.Length; i++)
        {
            dims[i] = (int)rawDims[i];
        }
        var pickledShape = new Shape(dims);
        TF_DataType pickledDType = (TF_DataType_Warpper)args[2];

        var isFortranOrder = (bool)args[3];
        if (isFortranOrder)
            throw new InvalidArgumentError("Fortran Continuous memory layout is not supported. Please use C-continuous layout or check the data format.");

        // TODO: Implement the missing details and checks from the official Numpy C code here.
        // https://github.com/numpy/numpy/blob/2f0bd6e86a77e4401d0384d9a75edf9470c5deb6/numpy/core/src/multiarray/descriptor.c#L2761

        if (args[4] is ArrayList data)
        {
            SetState(data);
        }
        else
        {
            var payloadType = args[4] == null ? "null" : args[4].GetType().FullName;
            throw new NotImplementedException($"Unsupported pickled data payload type: {payloadType}");
        }
    }

    /// <summary>
    /// Converts a one- or two-level <see cref="ArrayList"/> of ints into a managed
    /// array and an <see cref="NDArray"/>. In the two-level case rows shorter than the
    /// longest row keep the default value 0 in the unfilled positions.
    /// </summary>
    /// <param name="arrayList">The outer list; nesting depth is read from the first elements.</param>
    /// <exception cref="InvalidArgumentError">
    /// The element type cannot be inferred (empty list or null first element), a row
    /// is null or not a list, or an element is null.
    /// </exception>
    /// <exception cref="NotImplementedException">
    /// The elements are not ints, or the nesting is deeper than two levels.
    /// </exception>
    private void SetState(ArrayList arrayList)
    {
        // Follow the first element of each level to measure the nesting depth.
        int ndim = 1;
        var innermost = arrayList;
        while (innermost.Count > 0 && innermost[0] is ArrayList nested)
        {
            innermost = nested;
            ndim += 1;
        }

        if (innermost.Count == 0 || innermost[0] == null)
            throw new InvalidArgumentError("Cannot infer the element type of a pickled list that is empty or starts with a null element.");

        if (!(innermost[0] is int))
            throw new NotImplementedException($"can't handle ArrayList with elements of type {innermost[0].GetType().FullName}.");

        if (ndim > 2)
            throw new NotImplementedException("can't handle ArrayList with more than two dimensions.");

        if (ndim == 1)
        {
            int[] flat = (int[])arrayList.ToArray(typeof(int));
            Shape shape = new Shape(new int[] { arrayList.Count });
            reconstructedArray = flat;
            reconstructedNDArray = new NDArray(flat, shape);
            return;
        }

        // ndim == 2: validate every row once and find the widest one.
        var rows = new ArrayList[arrayList.Count];
        int secondDim = 0;
        for (int i = 0; i < arrayList.Count; i++)
        {
            if (!(arrayList[i] is ArrayList row))
                throw new InvalidArgumentError($"Row {i} of the pickled list is null or not a list.");
            rows[i] = row;
            secondDim = Math.Max(secondDim, row.Count);
        }

        int[,] matrix = new int[rows.Length, secondDim];
        for (int i = 0; i < rows.Length; i++)
        {
            for (int j = 0; j < rows[i].Count; j++)
            {
                var element = rows[i][j];
                if (element == null)
                    throw new InvalidArgumentError("the element of ArrayList cannot be null.");
                matrix[i, j] = (int)element;
            }
        }

        Shape matrixShape = new Shape(new int[] { rows.Length, secondDim });
        reconstructedArray = matrix;
        reconstructedNDArray = new NDArray(matrix, matrixShape);
    }
}
19108
}

src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public class NDArrayConverter
1010
public unsafe static T Scalar<T>(NDArray nd) where T : unmanaged
1111
=> nd.dtype switch
1212
{
13+
TF_DataType.TF_BOOL => Scalar<T>(*(bool*)nd.data),
1314
TF_DataType.TF_UINT8 => Scalar<T>(*(byte*)nd.data),
1415
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
1516
TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data),

src/TensorFlowNET.Core/Numpy/Numpy.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ public partial class np
4343
public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE;
4444
public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX;
4545
public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64;
46-
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
46+
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
47+
public static readonly TF_DataType @string = TF_DataType.TF_STRING;
48+
public static readonly TF_DataType @object = TF_DataType.TF_VARIANT;
4749
#endregion
4850

4951
public static double nan => double.NaN;

src/TensorFlowNET.Keras/Datasets/Imdb.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Datasets
7070
public class Imdb
7171
{
7272
string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
73-
string file_name = "imdb.npz";
73+
string file_name = "simple.npz";
7474
string dest_folder = "imdb";
7575
/// <summary>
7676
/// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
@@ -128,13 +128,15 @@ public DatasetPass load_data(string path = "imdb.npz",
128128

129129
// Reads the npz archive in `bytes` as int[,] entries and returns the
// "x_train.npy" and "x_test.npy" entries, in that order.
(NDArray, NDArray) LoadX(byte[] bytes)
{
    var archive = np.Load_Npz<int[,]>(bytes);
    return (archive["x_train.npy"], archive["x_test.npy"]);
}
134136

135137
// Reads the npz archive in `bytes` as int[] entries and returns the
// "y_train.npy" and "y_test.npy" entries, in that order.
(NDArray, NDArray) LoadY(byte[] bytes)
{
    var archive = np.Load_Npz<int[]>(bytes);
    var y_train = archive["y_train.npy"];
    var y_test = archive["y_test.npy"];
    return (y_train, y_test);
}
140142

0 commit comments

Comments
 (0)