Skip to content

Commit ceccf40

Browse files
committed
Add concatenate in dataset.
1 parent 0273341 commit ceccf40

File tree

4 files changed

+52
-1
lines changed

4 files changed

+52
-1
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Framework;
6+
using Tensorflow.Framework.Models;
7+
using static Tensorflow.Binding;
8+
9+
namespace Tensorflow.Data
10+
{
11+
/// <summary>
12+
/// A `Dataset` that concatenates its input with given dataset.
13+
/// </summary>
14+
public class ConcatenateDataset : DatasetV2
15+
{
16+
IDatasetV2 _input_dataset;
17+
IDatasetV2 _dataset_to_concatenate;
18+
public ConcatenateDataset(IDatasetV2 input_dataset, IDatasetV2 dataset_to_concatenate)
19+
{
20+
_input_dataset = input_dataset;
21+
_dataset_to_concatenate = dataset_to_concatenate;
22+
var _structure = new List<TensorSpec>();
23+
foreach(var (i, spec) in enumerate(dataset_to_concatenate.element_spec))
24+
{
25+
var shape = _input_dataset.output_shapes[i].most_specific_compatible_shape(spec.shape);
26+
_structure.Add(new TensorSpec(shape, dtype: spec.dtype));
27+
}
28+
structure = _structure.ToArray();
29+
30+
variant_tensor = ops.concatenate_dataset(input_dataset.variant_tensor,
31+
dataset_to_concatenate.variant_tensor,
32+
output_types, output_shapes);
33+
}
34+
}
35+
}

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections;
33
using System.Collections.Generic;
44
using System.Linq;
5+
using Tensorflow.Data;
56
using Tensorflow.Framework.Models;
67
using static Tensorflow.Binding;
78

@@ -26,6 +27,9 @@ public class DatasetV2 : IDatasetV2
2627
public IDatasetV2 cache(string filename = "")
2728
=> new CacheDataset(this, filename: filename);
2829

30+
public IDatasetV2 concatenate(IDatasetV2 dataset)
31+
=> new ConcatenateDataset(this, dataset);
32+
2933
public IDatasetV2 take(int count = -1)
3034
=> new TakeDataset(this, count: count);
3135

src/TensorFlowNET.Core/Data/IDatasetV2.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
2323
/// <returns></returns>
2424
IDatasetV2 cache(string filename = "");
2525

26+
/// <summary>
27+
/// Creates a `Dataset` by concatenating the given dataset with this dataset.
28+
/// </summary>
29+
/// <param name="dataset"></param>
30+
/// <returns></returns>
31+
IDatasetV2 concatenate(IDatasetV2 dataset);
32+
2633
/// <summary>
2734
///
2835
/// </summary>

src/TensorFlowNET.Core/Data/OwnedIterator.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using System;
2+
using System.Linq;
23
using Tensorflow.Framework.Models;
4+
using static Tensorflow.Binding;
35

46
namespace Tensorflow
57
{
@@ -36,7 +38,10 @@ public Tensor[] next()
3638
{
3739
try
3840
{
39-
return ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
41+
var results = ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
42+
foreach(var (i, tensor) in enumerate(results))
43+
tensor.set_shape(_element_spec[i].shape);
44+
return results;
4045
}
4146
catch (OutOfRangeError ex)
4247
{

0 commit comments

Comments
 (0)