Skip to content

Commit e73ed66

Browse files
committed
add SafeTensorHandle to manage tensor handle reference.
1 parent f3cbd85 commit e73ed66

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+271
-322
lines changed

src/TensorFlowNET.Console/Tensorflow.Console.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
</PropertyGroup>
2020

2121
<ItemGroup>
22-
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.5.0" />
22+
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.5.0" />
2323
</ItemGroup>
2424

2525
<ItemGroup>

src/TensorFlowNET.Core/Attributes/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public partial class c_api
9999
public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values);
100100

101101
[DllImport(TensorFlowLibName)]
102-
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, SafeStatusHandle status);
102+
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, SafeTensorHandle value, SafeStatusHandle status);
103103

104104
[DllImport(TensorFlowLibName)]
105105
public static extern void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value);

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,6 @@ public static int len(object a)
164164
return arr.Count;
165165
case ICollection arr:
166166
return arr.Count;
167-
case NDArray ndArray:
168-
return ndArray.ndim == 0 ? 1 : (int)ndArray.dims[0];
169167
case IEnumerable enumerable:
170168
return enumerable.OfType<object>().Count();
171169
case Shape arr:

src/TensorFlowNET.Core/Data/MnistDataSet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class MnistDataSet : DataSetBase
1010
public int EpochsCompleted { get; private set; }
1111
public int IndexInEpoch { get; private set; }
1212

13-
public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
13+
public MnistDataSet(NDArray images, NDArray labels, TF_DataType dataType, bool reshape)
1414
{
1515
EpochsCompleted = 0;
1616
IndexInEpoch = 0;

src/TensorFlowNET.Core/Data/ModelLoadSetting.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ public class ModelLoadSetting
66
{
77
public string TrainDir { get; set; }
88
public bool OneHot { get; set; }
9-
public Type DataType { get; set; } = typeof(float);
9+
public TF_DataType DataType { get; set; } = TF_DataType.TF_FLOAT;
1010
public bool ReShape { get; set; }
1111
public int ValidationSize { get; set; } = 5000;
1212
public int? TrainSize { get; set; }

src/TensorFlowNET.Core/DisposableObject.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private void Dispose(bool disposing)
4848
}
4949

5050
// free unmanaged memory
51-
if (_handle != IntPtr.Zero)
51+
// if (_handle != IntPtr.Zero)
5252
{
5353
// Call the appropriate methods to clean up
5454
// unmanaged resources here.

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public EagerTensor(Array array, Shape shape) : base(array, shape)
5656
public EagerTensor(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype)
5757
=> NewEagerTensorHandle(_handle);
5858

59-
void NewEagerTensorHandle(IntPtr h)
59+
void NewEagerTensorHandle(SafeTensorHandle h)
6060
{
6161
_id = ops.uid();
6262
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals
303303
/// <param name="t">const tensorflow::Tensor&amp;</param>
304304
/// <returns>TFE_TensorHandle*</returns>
305305
[DllImport(TensorFlowLibName)]
306-
public static extern SafeTensorHandleHandle TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status);
306+
public static extern SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);
307307

308308
[DllImport(TensorFlowLibName)]
309309
public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t);
@@ -334,7 +334,7 @@ public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals
334334
/// <param name="status">TF_Status*</param>
335335
/// <returns></returns>
336336
[DllImport(TensorFlowLibName)]
337-
public static extern IntPtr TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status);
337+
public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status);
338338

339339

340340
/// <summary>

src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,5 @@ public static implicit operator NDArray(float value)
4646

4747
public static implicit operator NDArray(double value)
4848
=> new NDArray(value);
49-
50-
public static implicit operator Tensor(NDArray nd)
51-
=> nd?._tensor;
52-
53-
public static implicit operator NDArray(Tensor tensor)
54-
=> new NDArray(tensor);
5549
}
5650
}

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11-
public NDArray this[params int[] index]
11+
public NDArray this[params int[] indices]
1212
{
13-
get => GetData(index.Select(x => new Slice
13+
get => GetData(indices.Select(x => new Slice
1414
{
1515
Start = x,
1616
Stop = x + 1,
1717
IsIndex = true
1818
}));
1919

20-
set => SetData(index.Select(x =>
20+
set => SetData(indices.Select(x =>
2121
{
2222
if(x < 0)
2323
x = (int)dims[0] + x;
@@ -57,12 +57,37 @@ public NDArray this[NDArray mask]
5757

5858
NDArray GetData(IEnumerable<Slice> slices)
5959
{
60-
var tensor = _tensor[slices.ToArray()];
61-
return new NDArray(tensor);
60+
if (shape.IsScalar)
61+
return GetScalar();
62+
63+
var tensor = base[slices.ToArray()];
64+
if (tensor.Handle == null)
65+
tensor = tf.defaultSession.eval(tensor);
66+
return new NDArray(tensor.Handle);
67+
}
68+
69+
unsafe T GetAtIndex<T>(params int[] indices) where T : unmanaged
70+
{
71+
var offset = (ulong)ShapeHelper.GetOffset(shape, indices);
72+
return *((T*)data + offset);
73+
}
74+
75+
NDArray GetScalar()
76+
{
77+
var array = new NDArray(Shape.Scalar, dtype: dtype);
78+
unsafe
79+
{
80+
var src = (byte*)data + dtypesize;
81+
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), bytesize, bytesize);
82+
}
83+
return array;
6284
}
6385

6486
NDArray GetData(int[] indices, int axis = 0)
6587
{
88+
if (shape.IsScalar)
89+
return GetScalar();
90+
6691
if(axis == 0)
6792
{
6893
var dims = shape.as_int_list();

0 commit comments

Comments
 (0)