Skip to content

Commit 64c5157

Browse files
committed
tf.data.Dataset.range #446
1 parent cd43400 commit 64c5157

File tree

17 files changed

+254
-38
lines changed

17 files changed

+254
-38
lines changed

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,17 @@ public static float time()
265265
yield return (i, values[i]);
266266
}
267267

268-
public static IEnumerable<(int, T)> enumerate<T>(IEnumerable<T> values, int start = 0)
268+
public static IEnumerable<(int, T)> enumerate<T>(IEnumerable<T> values, int start = 0, int step = 1)
269269
{
270270
int i = 0;
271-
foreach(var val in values)
271+
foreach (var val in values)
272272
{
273-
if (i++ < start)
273+
i += step;
274+
275+
if (i < start)
274276
continue;
275-
276-
yield return (i - start, val);
277+
278+
yield return (i - step - start, val);
277279
}
278280
}
279281

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
using NumSharp;
2-
using System;
1+
using System;
32
using System.Collections.Generic;
43
using System.Text;
4+
using Tensorflow.Data;
55

66
namespace Tensorflow
77
{
88
public class DatasetManager
99
{
1010
public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)
1111
=> new TensorSliceDataset(features, labels);
12+
13+
public IDatasetV2 range(int count, TF_DataType output_type = TF_DataType.TF_INT64)
14+
=> new RangeDataset(count, output_type: output_type);
1215
}
1316
}

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,19 @@ public override string ToString()
7878
{
7979
var ownedIterator = new OwnedIterator(this);
8080

81-
bool stop = false;
8281
Tensor[] results = null;
83-
while (!stop)
82+
while (true)
8483
{
8584
try
8685
{
8786
results = ownedIterator.next();
8887
}
8988
catch (StopIteration)
9089
{
91-
stop = true;
90+
break;
9291
}
9392

94-
yield return (results[0], results[1]);
93+
yield return (results[0], results.Length == 1 ? null : results[1]);
9594
}
9695
}
9796

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Framework.Models;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow.Data
8+
{
9+
public class RangeDataset : DatasetSource
10+
{
11+
Tensor start;
12+
Tensor step;
13+
Tensor stop;
14+
15+
public RangeDataset(int stop,
16+
int start = 0,
17+
int step = 1,
18+
TF_DataType output_type = TF_DataType.TF_INT64)
19+
{
20+
this.start = tf.convert_to_tensor((long)start);
21+
this.step = tf.convert_to_tensor((long)step);
22+
this.stop = tf.convert_to_tensor((long)stop);
23+
24+
structure = new TensorSpec[] { new TensorSpec(new int[0], dtype: output_type) };
25+
variant_tensor = ops.range_dataset(this.start, this.stop, this.step, output_types, output_shapes);
26+
}
27+
}
28+
}

src/TensorFlowNET.Core/Data/TensorSliceDataset.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
using Tensorflow.Framework.Models;
88
using static Tensorflow.Binding;
99

10-
namespace Tensorflow
10+
namespace Tensorflow.Data
1111
{
1212
public class TensorSliceDataset : DatasetSource
1313
{
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
6+
namespace Tensorflow.Keras.ArgsDefinition
7+
{
8+
public class DataHandlerArgs
9+
{
10+
public Tensor X { get; set; }
11+
public Tensor Y { get; set; }
12+
public int BatchSize { get; set; } = 32;
13+
public int StepsPerEpoch { get; set; } = -1;
14+
public int InitialEpoch { get; set; } = 0;
15+
public int Epochs { get; set; } = 1;
16+
public bool Shuffle { get; set; } = false;
17+
public int MaxQueueSize { get; set; } = 10;
18+
public int Workers { get; set; } = 1;
19+
public bool UseMultiprocessing { get; set; } = false;
20+
public Model Model { get; set; }
21+
public IVariableV1 StepsPerExecution { get; set; }
22+
}
23+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
6+
namespace Tensorflow.Keras.ArgsDefinition
7+
{
8+
public class SequentialArgs : ModelArgs
9+
{
10+
public List<Layer> Layers { get; set; }
11+
}
12+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
6+
namespace Tensorflow.Keras.Engine.DataAdapters
7+
{
8+
/// <summary>
9+
/// Handles iterating over epoch-level `tf.data.Iterator` objects.
10+
/// </summary>
11+
public class DataHandler
12+
{
13+
DataHandlerArgs args;
14+
15+
Tensor x => args.X;
16+
Tensor y => args.Y;
17+
int batch_size => args.BatchSize;
18+
int steps_per_epoch => args.StepsPerEpoch;
19+
int initial_epoch => args.InitialEpoch;
20+
int epochs => args.Epochs;
21+
bool shuffle => args.Shuffle;
22+
int max_queue_size => args.MaxQueueSize;
23+
int workers => args.Workers;
24+
bool use_multiprocessing => args.UseMultiprocessing;
25+
Model model => args.Model;
26+
IVariableV1 steps_per_execution => args.StepsPerExecution;
27+
28+
public DataHandler(DataHandlerArgs args)
29+
{
30+
31+
}
32+
}
33+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Engine.DataAdapters
6+
{
7+
/// <summary>
8+
/// In TF 2.0, tf.data is the preferred API for user to feed in data. In order
9+
/// to simplify the training code path, all the input data object will be
10+
/// converted to `tf.data.Dataset` if possible.
11+
/// </summary>
12+
public interface IDataAdapter
13+
{
14+
/// <summary>
15+
/// Whether the current DataAdapter could handle the input x and y.
16+
/// </summary>
17+
/// <param name="x">input features</param>
18+
/// <param name="y">target labels</param>
19+
/// <returns></returns>
20+
bool CanHandle(Tensor x, Tensor y = null);
21+
}
22+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow.Keras.Engine.DataAdapters
7+
{
8+
/// <summary>
9+
/// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy.
10+
/// </summary>
11+
public class TensorLikeDataAdapter : IDataAdapter
12+
{
13+
public TensorLikeDataAdapter()
14+
{
15+
tf.data.Dataset.range(5);
16+
}
17+
18+
public bool CanHandle(Tensor x, Tensor y = null)
19+
{
20+
throw new NotImplementedException();
21+
}
22+
}
23+
}

0 commit comments

Comments
 (0)