Skip to content

Commit 94601f5

Browse files
committed
throw exception from SetData
1 parent 35bcda9 commit 94601f5

File tree

4 files changed

+43
-9
lines changed

4 files changed

+43
-9
lines changed

src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,22 @@ public static implicit operator NDArray(Array array)
1818
=> new NDArray(array);
1919

2020
public unsafe static implicit operator bool(NDArray nd)
21-
=> *(bool*)nd.data;
21+
=> nd.dtype == TF_DataType.TF_BOOL ? *(bool*)nd.data : NDArrayConverter.Scalar<bool>(nd);
2222

2323
public unsafe static implicit operator byte(NDArray nd)
24-
=> *(byte*)nd.data;
24+
=> nd.dtype == TF_DataType.TF_INT8 ? *(byte*)nd.data : NDArrayConverter.Scalar<byte>(nd);
2525

2626
public unsafe static implicit operator int(NDArray nd)
27-
=> *(int*)nd.data;
27+
=> nd.dtype == TF_DataType.TF_INT32 ? *(int*)nd.data : NDArrayConverter.Scalar<int>(nd);
2828

2929
public unsafe static implicit operator long(NDArray nd)
30-
=> *(long*)nd.data;
30+
=> nd.dtype == TF_DataType.TF_INT64 ? *(long*)nd.data : NDArrayConverter.Scalar<long>(nd);
3131

3232
public unsafe static implicit operator float(NDArray nd)
33-
=> *(float*)nd.data;
33+
=> nd.dtype == TF_DataType.TF_FLOAT ? *(float*)nd.data : NDArrayConverter.Scalar<float>(nd);
3434

3535
public unsafe static implicit operator double(NDArray nd)
36-
=> *(double*)nd.data;
36+
=> nd.dtype == TF_DataType.TF_DOUBLE ? *(double*)nd.data : NDArrayConverter.Scalar<double>(nd);
3737

3838
public static implicit operator NDArray(bool value)
3939
=> new NDArray(value);

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ void SetData(IEnumerable<Slice> slices, NDArray array)
175175
unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim)
176176
{
177177
if (dtype != src.dtype)
178-
src = src.astype(dtype);
179-
// throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned.");
178+
// src = src.astype(dtype);
179+
throw new ArrayTypeMismatchException($"Required dtype {dtype} but {src.dtype} is assigned.");
180180

181181
if (!slices.Any())
182182
return;
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.NumPy
7+
{
8+
public class NDArrayConverter
9+
{
10+
public unsafe static T Scalar<T>(NDArray nd) where T : unmanaged
11+
=> nd.dtype switch
12+
{
13+
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
14+
TF_DataType.TF_INT64 => Scalar<T>(*(long*)nd.data),
15+
_ => throw new NotImplementedException("")
16+
};
17+
18+
static T Scalar<T>(float input)
19+
=> Type.GetTypeCode(typeof(T)) switch
20+
{
21+
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
22+
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
23+
_ => throw new NotImplementedException("")
24+
};
25+
26+
static T Scalar<T>(long input)
27+
=> Type.GetTypeCode(typeof(T)) switch
28+
{
29+
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
30+
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
31+
_ => throw new NotImplementedException("")
32+
};
33+
}
34+
}

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public DataHandler(DataHandlerArgs args)
7878
_insufficient_data = false;
7979
}
8080

81-
int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
81+
long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
8282
{
8383
if (steps_per_epoch > -1)
8484
return steps_per_epoch;

0 commit comments

Comments
 (0)