Skip to content

Commit d8afa8c

Browse files
committed
Fix input dtype for MapDataset. #666
1 parent c01b4dd commit d8afa8c

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/TensorFlowNET.Core/Data/MapDataset.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using Tensorflow.Functions;
3+
using static Tensorflow.Binding;
34

45
namespace Tensorflow
56
{
@@ -14,7 +15,12 @@ public MapDataset(IDatasetV2 input_dataset,
1415
bool preserve_cardinality = false,
1516
bool use_legacy_function = false) : base(input_dataset)
1617
{
17-
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);
18+
var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
19+
var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input");
20+
var output = map_func(input);
21+
func.ToGraph(input, output);
22+
23+
structure = func.OutputStructure;
1824

1925
variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
2026
func,

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ public void ToGraph(Tensors inputs, Tensors outputs)
109109
inputs,
110110
outputs,
111111
null);
112+
113+
OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
112114
}
113115

114116
public Tensors Invoke(Tensors inputs)

0 commit comments

Comments
 (0)