Skip to content

Commit 436afe9

Browse files
committed
tf.data.Dataset.from_tensor #446
1 parent 68df1b7 commit 436afe9

File tree

5 files changed

+95
-1
lines changed

5 files changed

+95
-1
lines changed

src/TensorFlowNET.Core/Data/DatasetManager.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
using NumSharp;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
45
using Tensorflow.Data;
@@ -7,6 +8,20 @@ namespace Tensorflow
78
{
89
public class DatasetManager
910
{
11+
public IDatasetV2 from_generator<T>(IEnumerable<T> generator, TF_DataType[] output_types, TensorShape[] output_shapes)
12+
=> new GeneratorDataset();
13+
14+
/// <summary>
15+
/// Creates a `Dataset` with a single element, comprising the given tensors.
16+
/// </summary>
17+
/// <param name="tensors"></param>
18+
/// <returns></returns>
19+
public IDatasetV2 from_tensor(NDArray tensors)
20+
=> new TensorDataset(tensors);
21+
22+
public IDatasetV2 from_tensor(Tensor tensors)
23+
=> new TensorDataset(tensors);
24+
1025
public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)
1126
=> new TensorSliceDataset(features, labels);
1227

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Data
6+
{
7+
public class GeneratorDataset : DatasetSource
8+
{
9+
10+
}
11+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using NumSharp;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow
9+
{
10+
/// <summary>
11+
/// A `Dataset` with a single element.
12+
/// </summary>
13+
public class TensorDataset : DatasetSource
14+
{
15+
public TensorDataset(Tensor element)
16+
{
17+
_tensors = new[] { element };
18+
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
19+
structure = batched_spec.Select(x => x._unbatch()).ToArray();
20+
21+
variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
22+
}
23+
24+
public TensorDataset(NDArray element)
25+
{
26+
_tensors = new[] { tf.convert_to_tensor(element) };
27+
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
28+
structure = batched_spec.ToArray();
29+
30+
variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
31+
}
32+
}
33+
}

src/TensorFlowNET.Core/Operations/dataset_ops.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@ namespace Tensorflow
88
{
99
public class dataset_ops
1010
{
11+
public Tensor tensor_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null)
12+
{
13+
if (tf.Context.executing_eagerly())
14+
{
15+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
16+
"TensorDataset", name,
17+
null,
18+
new object[]
19+
{
20+
components,
21+
"output_shapes", output_shapes
22+
});
23+
return results[0];
24+
}
25+
26+
throw new NotImplementedException("");
27+
}
28+
1129
/// <summary>
1230
/// Creates a dataset that emits each dim-0 slice of `components` once.
1331
/// </summary>

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
33
using System.Collections.Generic;
4+
using System.Linq;
45
using System.Text;
6+
using Tensorflow.Keras;
57
using Tensorflow.UnitTest;
68
using static Tensorflow.Binding;
79

@@ -62,6 +64,21 @@ public void FromTensorSlices()
6264
Assert.AreEqual(5, n);
6365
}
6466

67+
[TestMethod]
68+
public void FromTensor()
69+
{
70+
var X = new[] { 2013, 2014, 2015, 2016, 2017 };
71+
72+
var dataset = tf.data.Dataset.from_tensor(X);
73+
int n = 0;
74+
foreach (var x in dataset)
75+
{
76+
Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
77+
n += 1;
78+
}
79+
Assert.AreEqual(1, n);
80+
}
81+
6582
[TestMethod]
6683
public void Shard()
6784
{

0 commit comments

Comments
 (0)