Skip to content

Commit 3c020dc

Browse files
committed
fix memory crash when index < 0.
1 parent 9849829 commit 3c020dc

File tree

4 files changed

+38
-9
lines changed

4 files changed

+38
-9
lines changed

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,32 @@ public partial class NDArray
1010
{
1111
public NDArray this[params int[] index]
1212
{
13-
get => _tensor[index.Select(x => new Slice
13+
get => GetData(index.Select(x => new Slice
1414
{
1515
Start = x,
1616
Stop = x + 1,
1717
IsIndex = true
18-
}).ToArray()];
18+
}));
1919

20-
set => SetData(index.Select(x => new Slice
20+
set => SetData(index.Select(x =>
2121
{
22-
Start = x,
23-
Stop = x + 1,
24-
IsIndex = true
22+
if(x < 0)
23+
x = (int)dims[0] + x;
24+
25+
var slice = new Slice
26+
{
27+
Start = x,
28+
Stop = x + 1,
29+
IsIndex = true
30+
};
31+
32+
return slice;
2533
}), value);
2634
}
2735

2836
public NDArray this[params Slice[] slices]
2937
{
30-
get => _tensor[slices];
38+
get => GetData(slices);
3139
set => SetData(slices, value);
3240
}
3341

@@ -44,6 +52,11 @@ public NDArray this[NDArray mask]
4452
}
4553
}
4654

55+
NDArray GetData(IEnumerable<Slice> slices)
56+
{
57+
return _tensor[slices.ToArray()];
58+
}
59+
4760
void SetData(IEnumerable<Slice> slices, NDArray array)
4861
=> SetData(slices, array, -1, slices.Select(x => 0).ToArray());
4962

@@ -61,7 +74,10 @@ void SetData(IEnumerable<Slice> slices, NDArray array, int currentNDim, int[] in
6174
{
6275

6376
if (slice.Step != 1)
64-
throw new NotImplementedException("");
77+
throw new NotImplementedException("slice.step should == 1");
78+
79+
if (slice.Start < 0)
80+
throw new NotImplementedException("slice.start should > -1");
6581

6682
indices[indices.Length - 1] = slice.Start ?? 0;
6783
var offset = (ulong)ShapeHelper.GetOffset(shape, indices);

src/TensorFlowNET.Core/NumPy/ShapeHelper.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ public static long GetOffset(Shape shape, params int[] indices)
8181
for (int i = 0; i < indices.Length; i++)
8282
offset += strides[i] * indices[i];
8383

84+
if (offset < 0)
85+
throw new NotImplementedException("");
86+
8487
return offset;
8588
}
8689
}

src/TensorFlowNET.Core/Tensors/Tensor.String.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public IntPtr StringTensor(byte[][] buffer, Shape shape)
2929

3030
var tstr = c_api.TF_TensorData(handle);
3131
#if TRACK_TENSOR_LIFE
32-
print($"New TString 0x{handle.ToString("x16")} {AllocationType} Data: 0x{tstr.ToString("x16")}");
32+
print($"New TString 0x{handle.ToString("x16")} Data: 0x{tstr.ToString("x16")}");
3333
#endif
3434
for (int i = 0; i < buffer.Length; i++)
3535
{

test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Text;
66
using Tensorflow;
77
using Tensorflow.NumPy;
8+
using static Tensorflow.Binding;
89

910
namespace TensorFlowNET.UnitTest.NumPy
1011
{
@@ -53,5 +54,14 @@ public void slice_string_params()
5354
Assert.AreEqual(y.shape, (1, 2));
5455
Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2)));
5556
}
57+
58+
[TestMethod]
59+
public void slice_out_bound()
60+
{
61+
var input_shape = tf.constant(new int[] { 1, 1 });
62+
var input_shape_val = input_shape.numpy();
63+
input_shape_val[(int)input_shape.size - 1] = 1;
64+
input_shape.Dispose();
65+
}
5666
}
5767
}

0 commit comments

Comments
 (0)