We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5e0e8b1 commit a174a84Copy full SHA for a174a84
src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -35,6 +35,9 @@ public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null)
35
public IDatasetV2 repeat(int count = -1)
36
=> new RepeatDataset(this, count: count);
37
38
+ public IDatasetV2 shard(int num_shards, int index)
39
+ => new ShardDataset(this, num_shards, index);
40
+
41
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
42
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
43
src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -24,6 +24,14 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
24
/// <returns></returns>
25
IDatasetV2 repeat(int count = -1);
26
27
+ /// <summary>
28
+ /// Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
29
+ /// </summary>
30
+ /// <param name="num_shards">The number of shards operating in parallel</param>
31
+ /// <param name="index">The worker index</param>
32
+ /// <returns></returns>
33
+ IDatasetV2 shard(int num_shards, int index);
34
IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);
IDatasetV2 batch(int batch_size, bool drop_remainder = false);
src/TensorFlowNET.Core/Data/ShardDataset.cs
@@ -0,0 +1,31 @@
1
+using System;
2
+using System.Collections.Generic;
3
+using System.Text;
4
+using static Tensorflow.Binding;
5
6
+namespace Tensorflow
7
+{
8
9
+ /// A `Dataset` for sharding its input.
10
11
+ public class ShardDataset : UnaryUnchangedStructureDataset
12
+ {
13
+ Tensor _num_shards;
14
+ Tensor _index;
15
16
+ public ShardDataset(IDatasetV2 input_dataset,
17
+ int num_shards,
18
+ int index) : base(input_dataset)
19
20
+ _num_shards = tf.convert_to_tensor(num_shards, dtype: TF_DataType.TF_INT64, name: "num_shards");
21
+ _index = tf.convert_to_tensor(index, dtype: TF_DataType.TF_INT64, name: "index");
22
23
+ variant_tensor = ops.shard_dataset
+ (input_dataset.variant_tensor,
+ num_shards: _num_shards,
+ index: _index,
+ input_dataset.output_types,
+ input_dataset.output_shapes);
+ }
+}
src/TensorFlowNET.Core/Operations/dataset_ops.cs
@@ -65,6 +65,25 @@ public Tensor repeat_dataset(Tensor input_dataset, Tensor count, TF_DataType[] o
65
throw new NotImplementedException("");
66
}
67
68
+ public Tensor shard_dataset(Tensor input_dataset, Tensor num_shards, Tensor index,
69
+ TF_DataType[] output_types, TensorShape[] output_shapes,
70
+ bool require_non_empty = false, string name = null)
71
72
+ if (tf.Context.executing_eagerly())
73
74
+ var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
75
+ "ShardDataset", name,
76
+ null,
77
+ input_dataset, num_shards, index,
78
+ "require_non_empty", require_non_empty,
79
+ "output_types", output_types,
80
+ "output_shapes", output_shapes);
81
+ return results[0];
82
83
84
+ throw new NotImplementedException("");
85
86
87
public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size,
88
Tensor seed, Tensor seed2, Tensor seed_generator,
89
TF_DataType[] output_types, TensorShape[] output_shapes,
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
@@ -61,5 +61,28 @@ public void FromTensorSlices()
61
62
Assert.AreEqual(5, n);
63
64
+ [TestMethod]
+ public void Shard()
+ long value = 0;
+ var dataset1 = tf.data.Dataset.range(10);
+ var dataset2 = dataset1.shard(num_shards: 3, index: 0);
+ foreach (var item in dataset2)
+ Assert.AreEqual(value, (long)item.Item1);
+ value += 3;
+ value = 1;
+ var dataset3 = dataset1.shard(num_shards: 3, index: 1);
+ foreach (var item in dataset3)
0 commit comments