Skip to content

Commit 734430d

Browse files
committed
slice assign works
1 parent 7f0b9b6 commit 734430d

File tree

14 files changed

+173
-88
lines changed

14 files changed

+173
-88
lines changed

src/TensorFlowNET.Console/Tensorflow.Console.csproj

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414
<PlatformTarget>x64</PlatformTarget>
1515
</PropertyGroup>
1616

17+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
18+
<DefineConstants>DEBUG;TRACE</DefineConstants>
19+
</PropertyGroup>
20+
1721
<ItemGroup>
18-
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" />
22+
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.5.0" />
1923
</ItemGroup>
2024

2125
<ItemGroup>

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,14 +271,18 @@ public static float time()
271271
}
272272
}
273273

274-
public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2)
274+
public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2, Axis axis = null)
275275
where T : unmanaged
276276
{
277-
/*var a = t1.AsIterator<T>();
278-
var b = t2.AsIterator<T>();
279-
while (a.HasNext() && b.HasNext())
280-
yield return (a.MoveNext(), b.MoveNext());*/
281-
throw new NotImplementedException("");
277+
if (axis == null)
278+
{
279+
var a = t1.Data<T>();
280+
var b = t2.Data<T>();
281+
for (int i = 0; i < a.Length; i++)
282+
yield return (a[i], b[i]);
283+
}
284+
else
285+
throw new NotImplementedException("");
282286
}
283287

284288
public static IEnumerable<(T1, T2)> zip<T1, T2>(IList<T1> t1, IList<T2> t2)

src/TensorFlowNET.Core/Data/MnistModelLoader.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
166166
for (int row = 0; row < num_labels; row++)
167167
{
168168
var col = labels[row];
169-
labels_one_hot.SetData(1.0, row, col);
169+
labels_one_hot[row, col] = 1.0;
170170
}
171171

172172
return labels_one_hot;

src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ bool Equals(NDArray x, NDArray y)
2525
{
2626
if (x.ndim != y.ndim)
2727
return false;
28+
else if (x.size != y.size)
29+
return false;
30+
else if (x.dtype != y.dtype)
31+
return false;
2832

2933
return Enumerable.SequenceEqual(x.ToByteArray(), y.ToByteArray());
3034
}

src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ public partial class NDArray
88
{
99
public void Deconstruct(out byte blue, out byte green, out byte red)
1010
{
11-
var data = Data<byte>();
11+
var data = ToArray<byte>();
1212
blue = data[0];
1313
green = data[1];
1414
red = data[2];
@@ -17,23 +17,23 @@ public void Deconstruct(out byte blue, out byte green, out byte red)
1717
public static implicit operator NDArray(Array array)
1818
=> new NDArray(array);
1919

20-
public static implicit operator bool(NDArray nd)
21-
=> nd._tensor.ToArray<bool>()[0];
20+
public unsafe static implicit operator bool(NDArray nd)
21+
=> *(bool*)nd.data;
2222

23-
public static implicit operator byte(NDArray nd)
24-
=> nd._tensor.ToArray<byte>()[0];
23+
public unsafe static implicit operator byte(NDArray nd)
24+
=> *(byte*)nd.data;
2525

26-
public static implicit operator byte[](NDArray nd)
27-
=> nd.ToByteArray();
26+
public unsafe static implicit operator int(NDArray nd)
27+
=> *(int*)nd.data;
2828

29-
public static implicit operator int(NDArray nd)
30-
=> nd._tensor.ToArray<int>()[0];
29+
public unsafe static implicit operator long(NDArray nd)
30+
=> *(long*)nd.data;
3131

32-
public static implicit operator float(NDArray nd)
33-
=> nd._tensor.ToArray<float>()[0];
32+
public unsafe static implicit operator float(NDArray nd)
33+
=> *(float*)nd.data;
3434

35-
public static implicit operator double(NDArray nd)
36-
=> nd._tensor.ToArray<double>()[0];
35+
public unsafe static implicit operator double(NDArray nd)
36+
=> *(double*)nd.data;
3737

3838
public static implicit operator NDArray(bool value)
3939
=> new NDArray(value);

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,76 +8,78 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11-
public NDArray this[int index]
11+
public NDArray this[params int[] index]
1212
{
13-
get
13+
get => _tensor[index.Select(x => new Slice
1414
{
15-
return _tensor[index];
16-
}
15+
Start = x,
16+
Stop = x + 1,
17+
IsIndex = true
18+
}).ToArray()];
1719

18-
set
20+
set => SetData(index.Select(x => new Slice
1921
{
22+
Start = x,
23+
Stop = x + 1,
24+
IsIndex = true
25+
}), value);
26+
}
2027

21-
}
28+
public NDArray this[params Slice[] slices]
29+
{
30+
get => _tensor[slices];
31+
set => SetData(slices, value);
2232
}
2333

24-
public NDArray this[params int[] index]
34+
public NDArray this[NDArray mask]
2535
{
2636
get
2737
{
28-
return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()];
38+
throw new NotImplementedException("");
2939
}
3040

3141
set
3242
{
33-
var offset = ShapeHelper.GetOffset(shape, index);
34-
unsafe
35-
{
36-
if (dtype == TF_DataType.TF_BOOL)
37-
*((bool*)data + offset) = value;
38-
else if (dtype == TF_DataType.TF_UINT8)
39-
*((byte*)data + offset) = value;
40-
else if (dtype == TF_DataType.TF_INT32)
41-
*((int*)data + offset) = value;
42-
else if (dtype == TF_DataType.TF_INT64)
43-
*((long*)data + offset) = value;
44-
else if (dtype == TF_DataType.TF_FLOAT)
45-
*((float*)data + offset) = value;
46-
else if (dtype == TF_DataType.TF_DOUBLE)
47-
*((double*)data + offset) = value;
48-
}
43+
throw new NotImplementedException("");
4944
}
5045
}
5146

52-
public NDArray this[params Slice[] slices]
47+
void SetData(IEnumerable<Slice> slices, NDArray array)
48+
=> SetData(slices, array, -1, slices.Select(x => 0).ToArray());
49+
50+
void SetData(IEnumerable<Slice> slices, NDArray array, int currentNDim, int[] indices)
5351
{
54-
get
55-
{
56-
return _tensor[slices];
57-
}
52+
if (dtype != array.dtype)
53+
throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned.");
5854

59-
set
55+
if (!slices.Any())
56+
return;
57+
58+
var slice = slices.First();
59+
60+
if (slices.Count() == 1)
6061
{
61-
var pos = _tensor[slices];
62-
var len = value.bytesize;
62+
63+
if (slice.Step != 1)
64+
throw new NotImplementedException("");
65+
66+
indices[indices.Length - 1] = slice.Start ?? 0;
67+
var offset = (ulong)ShapeHelper.GetOffset(shape, indices);
68+
var bytesize = array.bytesize;
6369
unsafe
6470
{
65-
System.Buffer.MemoryCopy(value.data.ToPointer(), pos.TensorDataPointer.ToPointer(), len, len);
71+
var dst = (byte*)data + offset * dtypesize;
72+
System.Buffer.MemoryCopy(array.data.ToPointer(), dst, bytesize, bytesize);
6673
}
67-
// _tensor[slices].assign(constant_op.constant(value));
68-
}
69-
}
7074

71-
public NDArray this[NDArray mask]
72-
{
73-
get
74-
{
75-
throw new NotImplementedException("");
75+
return;
7676
}
7777

78-
set
78+
currentNDim++;
79+
for (var i = slice.Start ?? 0; i < slice.Stop; i++)
7980
{
80-
81+
indices[currentNDim] = i;
82+
SetData(slices.Skip(1), array, currentNDim, indices);
8183
}
8284
}
8385
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow.NumPy
8+
{
9+
public partial class NDArray
10+
{
11+
public static NDArray operator +(NDArray lhs, NDArray rhs) => lhs.Tensor + rhs.Tensor;
12+
public static NDArray operator -(NDArray lhs, NDArray rhs) => lhs.Tensor - rhs.Tensor;
13+
public static NDArray operator *(NDArray lhs, NDArray rhs) => lhs.Tensor * rhs.Tensor;
14+
public static NDArray operator /(NDArray lhs, NDArray rhs) => lhs.Tensor / rhs.Tensor;
15+
}
16+
}

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace Tensorflow.NumPy
2525
public partial class NDArray
2626
{
2727
Tensor _tensor;
28+
public Tensor Tensor => _tensor;
2829
public TF_DataType dtype => _tensor.dtype;
2930
public ulong size => _tensor.size;
3031
public ulong dtypesize => _tensor.dtypesize;
@@ -47,15 +48,12 @@ public NDArray[] GetNDArrays()
4748
public ValueType GetValue(params int[] indices)
4849
=> throw new NotImplementedException("");
4950

50-
public void SetData(object value, params int[] indices)
51-
=> throw new NotImplementedException("");
52-
5351
public NDIterator<T> AsIterator<T>(bool autoreset = false) where T : unmanaged
5452
=> throw new NotImplementedException("");
5553

5654
public bool HasNext() => throw new NotImplementedException("");
5755
public T MoveNext<T>() => throw new NotImplementedException("");
58-
public NDArray reshape(Shape newshape) => new NDArray(_tensor, newshape);
56+
public NDArray reshape(Shape newshape) => new NDArray(tf.reshape(_tensor, newshape));
5957
public NDArray astype(Type type) => new NDArray(math_ops.cast(_tensor, type.as_tf_dtype()));
6058
public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(_tensor, dtype));
6159
public NDArray ravel() => throw new NotImplementedException("");

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ public class DataHandler
1414
IDataAdapter _adapter;
1515
public IDataAdapter DataAdapter => _adapter;
1616
IDatasetV2 _dataset;
17-
int _inferred_steps;
18-
public int Inferredsteps => _inferred_steps;
19-
int _current_step;
20-
int _step_increment;
21-
public int StepIncrement => _step_increment;
17+
long _inferred_steps;
18+
public long Inferredsteps => _inferred_steps;
19+
long _current_step;
20+
long _step_increment;
21+
public long StepIncrement => _step_increment;
2222
bool _insufficient_data;
23-
int _steps_per_execution_value;
23+
long _steps_per_execution_value;
2424
int _initial_epoch => args.InitialEpoch;
2525
int _epochs => args.Epochs;
2626
IVariableV1 _steps_per_execution;
@@ -30,8 +30,8 @@ public DataHandler(DataHandlerArgs args)
3030
this.args = args;
3131
if (args.StepsPerExecution == null)
3232
{
33-
_steps_per_execution = tf.Variable(1);
34-
_steps_per_execution_value = 1;
33+
_steps_per_execution = tf.Variable(1L);
34+
_steps_per_execution_value = 1L;
3535
}
3636
else
3737
{
@@ -103,7 +103,7 @@ int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
103103
// _adapter.on_epoch_end()
104104
}
105105

106-
public IEnumerable<int> steps()
106+
public IEnumerable<long> steps()
107107
{
108108
_current_step = 0;
109109
while (_current_step < _inferred_steps)

test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ public void TokenizeTextsToSequencesWithOOVPresent()
229229
Assert.AreEqual(9, oov_count);
230230
}
231231

232-
[TestMethod, Ignore("slice assign doesn't work")]
232+
[TestMethod]
233233
public void PadSequencesWithDefaults()
234234
{
235235
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
@@ -241,12 +241,12 @@ public void PadSequencesWithDefaults()
241241
Assert.AreEqual(4, padded.dims[0]);
242242
Assert.AreEqual(22, padded.dims[1]);
243243

244-
Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 19]);
244+
Assert.AreEqual(padded[0, 19], tokenizer.word_index["worst"]);
245245
for (var i = 0; i < 8; i++)
246-
Assert.AreEqual(0, padded[0, i]);
247-
Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10]);
246+
Assert.AreEqual(padded[0, i], 0);
247+
Assert.AreEqual(padded[1, 10], tokenizer.word_index["proud"]);
248248
for (var i = 0; i < 20; i++)
249-
Assert.AreNotEqual(0, padded[1, i]);
249+
Assert.AreNotEqual(padded[1, i], 0);
250250
}
251251

252252
[TestMethod, Ignore("slice assign doesn't work")]

0 commit comments

Comments
 (0)