Skip to content

Commit 400cde2

Browse files
committed
Fix MapDataset.
1 parent d3f19f4 commit 400cde2

File tree

6 files changed

+29
-5
lines changed

6 files changed

+29
-5
lines changed

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Collections.Generic;
44
using System.Linq;
55
using Tensorflow.Framework.Models;
6+
using static Tensorflow.Binding;
67

78
namespace Tensorflow
89
{
@@ -98,6 +99,20 @@ public IDatasetV2 apply_options()
9899
return dataset;
99100
}
100101

102+
public Tensor dataset_cardinality(string name = null)
103+
{
104+
if (tf.Context.executing_eagerly())
105+
{
106+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
107+
"DatasetCardinality", name,
108+
null,
109+
variant_tensor);
110+
return results[0];
111+
}
112+
113+
throw new NotImplementedException("");
114+
}
115+
101116
public override string ToString()
102117
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}";
103118

@@ -117,7 +132,9 @@ public override string ToString()
117132
break;
118133
}
119134

120-
yield return (results[0], results.Length == 1 ? null : results[1]);
135+
yield return results.Length == 2
136+
? (results[0], results[1])
137+
: (null, results[0]);
121138
}
122139
}
123140

src/TensorFlowNET.Core/Data/IDatasetV2.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,7 @@ IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
7474
/// </summary>
7575
/// <returns></returns>
7676
IDatasetV2 apply_options();
77+
78+
Tensor dataset_cardinality(string name = null);
7779
}
7880
}

src/TensorFlowNET.Core/Data/MapDataset.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ public MapDataset(IDatasetV2 input_dataset,
1515
bool preserve_cardinality = false,
1616
bool use_legacy_function = false) : base(input_dataset)
1717
{
18-
var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
19-
var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input");
18+
using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
19+
var input = tf.placeholder(input_dataset.element_spec[0].dtype);
2020
var output = map_func(input);
2121
func.ToGraph(input, output);
22-
22+
2323
structure = func.OutputStructure;
2424

2525
variant_tensor = ops.map_dataset(input_dataset.variant_tensor,

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible
130130
return new ForwardBackwardCall(functions, args, tape_watching: true);
131131
}
132132

133+
public override string ToString()
134+
=> Name;
135+
133136
public void Dispose()
134137
{
135138
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);

src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs renamed to src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
namespace Tensorflow.Keras.ArgsDefinition
44
{
5-
public class TensorLikeDataAdapterArgs
5+
public class DataAdapterArgs
66
{
77
public Tensor X { get; set; }
88
public Tensor Y { get; set; }
9+
public IDatasetV2 Dataset { get; set; }
910
public int BatchSize { get; set; } = 32;
1011
public int Steps { get; set; }
1112
public int Epochs { get; set; }

src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ public class DataHandlerArgs
66
{
77
public Tensor X { get; set; }
88
public Tensor Y { get; set; }
9+
public IDatasetV2 Dataset { get; set; }
910
public int BatchSize { get; set; } = 32;
1011
public int StepsPerEpoch { get; set; } = -1;
1112
public int InitialEpoch { get; set; } = 0;

0 commit comments

Comments
 (0)