Skip to content

Commit a174a84

Browse files
committed
tf.data.Dataset shard() #446
1 parent 5e0e8b1 commit a174a84

File tree

5 files changed

+84
-0
lines changed

5 files changed

+84
-0
lines changed

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null)
3535
/// <summary>
/// Wraps this dataset in a <c>RepeatDataset</c>, forwarding <paramref name="count"/>.
/// </summary>
/// <param name="count">
/// Repetition count handed to <c>RepeatDataset</c>.
/// NOTE(review): the -1 default presumably means "repeat indefinitely" — confirm.
/// </param>
/// <returns>The newly created <c>RepeatDataset</c>.</returns>
public IDatasetV2 repeat(int count = -1)
{
    return new RepeatDataset(this, count: count);
}
3737

38+
/// <summary>
/// Wraps this dataset in a <see cref="ShardDataset"/> built from the given
/// shard count and shard index.
/// </summary>
/// <param name="num_shards">Total number of shards.</param>
/// <param name="index">Index of the shard to keep.</param>
/// <returns>The newly created <see cref="ShardDataset"/>.</returns>
public IDatasetV2 shard(int num_shards, int index)
    => new ShardDataset(this, num_shards: num_shards, index: index);
40+
3841
/// <summary>
/// Wraps this dataset in a <c>ShuffleDataset</c>, forwarding every argument unchanged.
/// </summary>
/// <param name="buffer_size">Buffer size handed to <c>ShuffleDataset</c>.</param>
/// <param name="seed">Optional seed handed to <c>ShuffleDataset</c>.</param>
/// <param name="reshuffle_each_iteration">Flag handed to <c>ShuffleDataset</c>.</param>
/// <returns>The newly created <c>ShuffleDataset</c>.</returns>
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
{
    return new ShuffleDataset(
        this,
        buffer_size,
        seed: seed,
        reshuffle_each_iteration: reshuffle_each_iteration);
}
4043

src/TensorFlowNET.Core/Data/IDatasetV2.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
2424
/// <returns></returns>
2525
IDatasetV2 repeat(int count = -1);
2626

27+
/// <summary>
28+
/// Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
29+
/// </summary>
30+
/// <param name="num_shards">The number of shards operating in parallel.</param>
31+
/// <param name="index">The worker index, i.e. which of the `num_shards` shards this dataset keeps.</param>
32+
/// <returns>A new dataset restricted to the selected shard of this dataset.</returns>
33+
IDatasetV2 shard(int num_shards, int index);
34+
2735
/// <summary>
/// Creates a `Dataset` derived from this one by shuffling.
/// NOTE(review): parameter semantics presumably mirror Python's
/// `tf.data.Dataset.shuffle` — confirm before documenting them in detail.
/// </summary>
IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);
2836

2937
/// <summary>
/// Creates a `Dataset` derived from this one by batching.
/// NOTE(review): presumably groups consecutive elements into batches of
/// `batch_size`, with `drop_remainder` controlling a short final batch — confirm.
/// </summary>
IDatasetV2 batch(int batch_size, bool drop_remainder = false);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow
7+
{
8+
/// <summary>
/// A `Dataset` for sharding its input: it converts the shard count and shard
/// index to int64 tensors and builds its variant tensor through
/// <c>ops.shard_dataset</c>.
/// </summary>
public class ShardDataset : UnaryUnchangedStructureDataset
{
    // int64 tensor holding the total number of shards.
    private Tensor _num_shards;

    // int64 tensor holding the index of the shard kept by this dataset.
    private Tensor _index;

    /// <summary>
    /// Builds the sharded dataset on top of <paramref name="input_dataset"/>.
    /// </summary>
    /// <param name="input_dataset">Dataset whose elements are sharded.</param>
    /// <param name="num_shards">Total number of shards.</param>
    /// <param name="index">Index of the shard to keep.</param>
    public ShardDataset(IDatasetV2 input_dataset, int num_shards, int index)
        : base(input_dataset)
    {
        _num_shards = tf.convert_to_tensor(num_shards, dtype: TF_DataType.TF_INT64, name: "num_shards");
        _index = tf.convert_to_tensor(index, dtype: TF_DataType.TF_INT64, name: "index");

        // Output types and shapes are taken straight from the input dataset.
        var source_variant = input_dataset.variant_tensor;
        variant_tensor = ops.shard_dataset(
            source_variant,
            _num_shards,
            _index,
            input_dataset.output_types,
            input_dataset.output_shapes);
    }
}
31+
}

src/TensorFlowNET.Core/Operations/dataset_ops.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,25 @@ public Tensor repeat_dataset(Tensor input_dataset, Tensor count, TF_DataType[] o
6565
throw new NotImplementedException("");
6666
}
6767

68+
/// <summary>
/// Executes the "ShardDataset" op through the eager fast path and returns
/// its first output.
/// </summary>
/// <param name="input_dataset">Tensor passed as the op's dataset input.</param>
/// <param name="num_shards">Tensor passed as the op's shard-count input.</param>
/// <param name="index">Tensor passed as the op's shard-index input.</param>
/// <param name="output_types">Value of the op's "output_types" attribute.</param>
/// <param name="output_shapes">Value of the op's "output_shapes" attribute.</param>
/// <param name="require_non_empty">Value of the op's "require_non_empty" attribute.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The first tensor returned by the op.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when the context is not executing eagerly.
/// </exception>
public Tensor shard_dataset(Tensor input_dataset, Tensor num_shards, Tensor index,
    TF_DataType[] output_types, TensorShape[] output_shapes,
    bool require_non_empty = false, string name = null)
{
    // Only the eager path is implemented; reject every other mode up front.
    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("");
    }

    var outputs = tf.Runner.TFE_FastPathExecute(
        tf.Context, tf.Context.DeviceName,
        "ShardDataset", name,
        null,
        input_dataset, num_shards, index,
        "require_non_empty", require_non_empty,
        "output_types", output_types,
        "output_shapes", output_shapes);

    return outputs[0];
}
86+
6887
public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size,
6988
Tensor seed, Tensor seed2, Tensor seed_generator,
7089
TF_DataType[] output_types, TensorShape[] output_shapes,

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,28 @@ public void FromTensorSlices()
6161
}
6262
Assert.AreEqual(5, n);
6363
}
64+
65+
/// <summary>
/// Verifies that sharding range(10) into 3 shards yields every third element
/// starting at the shard index, and that each shard has the expected number
/// of elements.
/// </summary>
[TestMethod]
public void Shard()
{
    var dataset1 = tf.data.Dataset.range(10);

    // Shard 0 of 3 must yield 0, 3, 6, 9.
    long value = 0;
    int count = 0;
    var dataset2 = dataset1.shard(num_shards: 3, index: 0);

    foreach (var item in dataset2)
    {
        Assert.AreEqual(value, (long)item.Item1);
        value += 3;
        count++;
    }
    // Without a count check the loop passes vacuously when the shard yields nothing.
    Assert.AreEqual(4, count);

    // Shard 1 of 3 must yield 1, 4, 7.
    value = 1;
    count = 0;
    var dataset3 = dataset1.shard(num_shards: 3, index: 1);

    foreach (var item in dataset3)
    {
        Assert.AreEqual(value, (long)item.Item1);
        value += 3;
        count++;
    }
    Assert.AreEqual(3, count);
}
6487
}
6588
}

0 commit comments

Comments
 (0)