Skip to content

Commit c7ee230

Browse files
committed
fix ToMultiDimArray
1 parent 94601f5 commit c7ee230

File tree

5 files changed

+100
-23
lines changed

5 files changed

+100
-23
lines changed

src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,62 @@ static T Scalar<T>(long input)
3030
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
3131
_ => throw new NotImplementedException("")
3232
};
33+
34+
public static unsafe Array ToMultiDimArray<T>(NDArray nd) where T : unmanaged
35+
{
36+
var ret = Array.CreateInstance(typeof(T), nd.shape.as_int_list());
37+
38+
var addr = ret switch
39+
{
40+
T[] array => Addr(array),
41+
T[,] array => Addr(array),
42+
T[,,] array => Addr(array),
43+
T[,,,] array => Addr(array),
44+
T[,,,,] array => Addr(array),
45+
T[,,,,,] array => Addr(array),
46+
_ => throw new NotImplementedException("")
47+
};
48+
49+
System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize);
50+
return ret;
51+
}
52+
53+
#region multiple array
54+
static unsafe T* Addr<T>(T[] array) where T : unmanaged
55+
{
56+
fixed (T* a = &array[0])
57+
return a;
58+
}
59+
60+
static unsafe T* Addr<T>(T[,] array) where T : unmanaged
61+
{
62+
fixed (T* a = &array[0, 0])
63+
return a;
64+
}
65+
66+
static unsafe T* Addr<T>(T[,,] array) where T : unmanaged
67+
{
68+
fixed (T* a = &array[0, 0, 0])
69+
return a;
70+
}
71+
72+
static unsafe T* Addr<T>(T[,,,] array) where T : unmanaged
73+
{
74+
fixed (T* a = &array[0, 0, 0, 0])
75+
return a;
76+
}
77+
78+
static unsafe T* Addr<T>(T[,,,,] array) where T : unmanaged
79+
{
80+
fixed (T* a = &array[0, 0, 0, 0, 0])
81+
return a;
82+
}
83+
84+
static unsafe T* Addr<T>(T[,,,,,] array) where T : unmanaged
85+
{
86+
fixed (T* a = &array[0, 0, 0, 0, 0, 0])
87+
return a;
88+
}
89+
#endregion
3390
}
3491
}

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,28 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11-
public NDArray(bool value) : base(value) { NewEagerTensorHandle(); }
12-
public NDArray(byte value) : base(value) { NewEagerTensorHandle(); }
13-
public NDArray(short value) : base(value) { NewEagerTensorHandle(); }
14-
public NDArray(int value) : base(value) { NewEagerTensorHandle(); }
15-
public NDArray(long value) : base(value) { NewEagerTensorHandle(); }
16-
public NDArray(float value) : base(value) { NewEagerTensorHandle(); }
17-
public NDArray(double value) : base(value) { NewEagerTensorHandle(); }
11+
public NDArray(bool value) : base(value) => NewEagerTensorHandle();
12+
public NDArray(byte value) : base(value) => NewEagerTensorHandle();
13+
public NDArray(short value) : base(value) => NewEagerTensorHandle();
14+
public NDArray(int value) : base(value) => NewEagerTensorHandle();
15+
public NDArray(long value) : base(value) => NewEagerTensorHandle();
16+
public NDArray(float value) : base(value) => NewEagerTensorHandle();
17+
public NDArray(double value) : base(value) => NewEagerTensorHandle();
1818

19-
public NDArray(Array value, Shape? shape = null)
20-
: base(value, shape) { NewEagerTensorHandle(); }
19+
public NDArray(Array value, Shape? shape = null) : base(value, shape)
20+
=> NewEagerTensorHandle();
2121

22-
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
23-
: base(shape, dtype: dtype) { NewEagerTensorHandle(); }
22+
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) : base(shape, dtype: dtype)
23+
=> NewEagerTensorHandle();
2424

25-
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype)
26-
: base(bytes, shape, dtype) { NewEagerTensorHandle(); }
25+
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype)
26+
=> NewEagerTensorHandle();
2727

28-
public NDArray(long[] value, Shape? shape = null)
29-
: base(value, shape) { NewEagerTensorHandle(); }
28+
public NDArray(long[] value, Shape? shape = null) : base(value, shape)
29+
=> NewEagerTensorHandle();
3030

31-
public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
32-
: base(address, shape, dtype) { NewEagerTensorHandle(); }
31+
public NDArray(IntPtr address, Shape shape, TF_DataType dtype) : base(address, shape, dtype)
32+
=> NewEagerTensorHandle();
3333

3434
public NDArray(Tensor tensor, bool clone = false) : base(tensor.Handle, clone: clone)
3535
{

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.Collections.Generic;
2020
using System.Linq;
2121
using System.Text;
22+
using Tensorflow.Util;
2223
using static Tensorflow.Binding;
2324

2425
namespace Tensorflow.NumPy
@@ -35,7 +36,10 @@ public ValueType GetValue(params int[] indices)
3536
public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype));
3637
public NDArray ravel() => throw new NotImplementedException("");
3738
public void shuffle(NDArray nd) => np.random.shuffle(nd);
38-
public Array ToMuliDimArray<T>() => throw new NotImplementedException("");
39+
40+
public unsafe Array ToMultiDimArray<T>() where T : unmanaged
41+
=> NDArrayConverter.ToMultiDimArray<T>(this);
42+
3943
public byte[] ToByteArray() => BufferToArray();
4044
public override string ToString() => NDArrayRender.ToString(this);
4145

src/TensorFlowNET.Keras/Saving/hdf5_format.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,19 +273,19 @@ private static void WriteDataset(long f, string name, Tensor data)
273273
switch (data.dtype)
274274
{
275275
case TF_DataType.TF_FLOAT:
276-
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
276+
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMultiDimArray<float>());
277277
break;
278278
case TF_DataType.TF_DOUBLE:
279-
Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMuliDimArray<double>());
279+
Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMultiDimArray<double>());
280280
break;
281281
case TF_DataType.TF_INT32:
282-
Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMuliDimArray<int>());
282+
Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMultiDimArray<int>());
283283
break;
284284
case TF_DataType.TF_INT64:
285-
Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMuliDimArray<long>());
285+
Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMultiDimArray<long>());
286286
break;
287287
default:
288-
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
288+
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMultiDimArray<float>());
289289
break;
290290
}
291291
}

test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ public void array()
5050
AssetSequenceEqual(new[] { 1, 2, 3, 4, 5, 6 }, x.ToArray<int>());
5151
}
5252

53+
[TestMethod]
54+
public void to_multi_dim_array()
55+
{
56+
var x1 = np.arange(12);
57+
var y1 = x1.ToMultiDimArray<int>();
58+
AssetSequenceEqual((int[])y1, x1.ToArray<int>());
59+
60+
var x2 = np.arange(12).reshape((2, 6));
61+
var y2 = (int[,])x2.ToMultiDimArray<int>();
62+
Assert.AreEqual(x2[0, 5], y2[0, 5]);
63+
64+
var x3 = np.arange(12).reshape((2, 2, 3));
65+
var y3 = (int[,,])x3.ToMultiDimArray<int>();
66+
Assert.AreEqual(x3[0, 1, 2], y3[0, 1, 2]);
67+
}
68+
5369
[TestMethod]
5470
public void eye()
5571
{

0 commit comments

Comments
 (0)