Skip to content

Commit a3cf7ae

Browse files
committed
Add CacheDataset.
1 parent 436afe9 commit a3cf7ae

File tree

8 files changed

+177
-0
lines changed

8 files changed

+177
-0
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow
7+
{
8+
public class CacheDataset : UnaryUnchangedStructureDataset
9+
{
10+
Tensor _filename;
11+
public CacheDataset(IDatasetV2 input_dataset,
12+
string filename = "") :
13+
base(input_dataset)
14+
{
15+
_filename = tf.convert_to_tensor(filename, dtype: TF_DataType.TF_STRING, name: "filename");
16+
variant_tensor = ops.cache_dataset_v2(input_dataset.variant_tensor,
17+
_filename,
18+
ops.dummy_memory_cache(),
19+
output_types,
20+
output_shapes);
21+
}
22+
}
23+
}

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ public class DatasetV2 : IDatasetV2
2323

2424
public TensorSpec[] element_spec => structure;
2525

26+
public IDatasetV2 cache(string filename = "")
27+
=> new CacheDataset(this, filename: filename);
28+
2629
public IDatasetV2 take(int count = -1)
2730
=> new TakeDataset(this, count: count);
2831

@@ -47,6 +50,16 @@ public IDatasetV2 skip(int count)
4750
public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
4851
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);
4952

53+
public IDatasetV2 map(Func<Tensor, Tensor> map_func,
54+
bool use_inter_op_parallelism = true,
55+
bool preserve_cardinality = false,
56+
bool use_legacy_function = false)
57+
=> new MapDataset(this,
58+
map_func,
59+
use_inter_op_parallelism: use_inter_op_parallelism,
60+
preserve_cardinality: preserve_cardinality,
61+
use_legacy_function: use_legacy_function);
62+
5063
public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
5164
=> new ModelDataset(this, algorithm, cpu_budget);
5265

src/TensorFlowNET.Core/Data/IDatasetV2.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq.Expressions;
34
using System.Text;
45
using Tensorflow.Framework.Models;
56

@@ -17,6 +18,13 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
1718

1819
TensorSpec[] structure { get; set; }
1920

21+
/// <summary>
22+
/// Caches the elements in this dataset.
23+
/// </summary>
24+
/// <param name="filename"></param>
25+
/// <returns></returns>
26+
IDatasetV2 cache(string filename="");
27+
2028
/// <summary>
2129
///
2230
/// </summary>
@@ -49,6 +57,11 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
4957

5058
IDatasetV2 optimize(string[] optimizations, string[] optimization_configs);
5159

60+
IDatasetV2 map(Func<Tensor, Tensor> map_func,
61+
bool use_inter_op_parallelism = true,
62+
bool preserve_cardinality = false,
63+
bool use_legacy_function = false);
64+
5265
IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);
5366

5467
/// <summary>
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
/// <summary>
8+
/// A `Dataset` that maps a function over elements in its input.
9+
/// </summary>
10+
public class MapDataset : UnaryDataset
11+
{
12+
public MapDataset(IDatasetV2 input_dataset,
13+
Func<Tensor, Tensor> map_func,
14+
bool use_inter_op_parallelism = true,
15+
bool preserve_cardinality = false,
16+
bool use_legacy_function = false) : base(input_dataset)
17+
{
18+
foreach(var input in input_dataset)
19+
{
20+
var data = map_func(input.Item1);
21+
}
22+
23+
variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
24+
output_types,
25+
output_shapes);
26+
}
27+
}
28+
}

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,9 @@ bool SetOpAttrScalar(Context ctx, SafeOpHandle op,
381381
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
382382
status.Check(true);
383383
break;
384+
case TF_AttrType.TF_ATTR_FUNC:
385+
c_api.TFE_OpSetAttrFunctionName(op, key, value.ToString(), value.ToString().Length);
386+
break;
384387
default:
385388
throw new NotImplementedException($"SetOpAttrScalar for {type}");
386389
}

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals
196196
[DllImport(TensorFlowLibName)]
197197
public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value);
198198

199+
[DllImport(TensorFlowLibName)]
200+
public static extern void TFE_OpSetAttrFunctionName(SafeOpHandle op, string attr_name, string data, int length);
201+
199202
/// <summary>
200203
///
201204
/// </summary>

src/TensorFlowNET.Core/Operations/dataset_ops.cs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,24 @@ public Tensor dummy_seed_generator(string name = null)
155155
throw new NotImplementedException("");
156156
}
157157

158+
public Tensor cache_dataset_v2(Tensor input_dataset, Tensor filename, Tensor cache,
159+
TF_DataType[] output_types, TensorShape[] output_shapes,
160+
string name = null)
161+
{
162+
if (tf.Context.executing_eagerly())
163+
{
164+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
165+
"CacheDatasetV2", name,
166+
null,
167+
input_dataset, filename, cache,
168+
"output_types", output_types,
169+
"output_shapes", output_shapes);
170+
return results[0];
171+
}
172+
173+
throw new NotImplementedException("");
174+
}
175+
158176
/// <summary>
159177
/// Creates a dataset that batches `batch_size` elements from `input_dataset`.
160178
/// </summary>
@@ -187,6 +205,24 @@ public Tensor batch_dataset_v2(Tensor input_dataset, Tensor buffer_size,
187205
throw new NotImplementedException("");
188206
}
189207

208+
/// <summary>
209+
///
210+
/// </summary>
211+
/// <param name="name"></param>
212+
/// <returns></returns>
213+
public Tensor dummy_memory_cache(string name = "")
214+
{
215+
if (tf.Context.executing_eagerly())
216+
{
217+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
218+
"DummyMemoryCache", name,
219+
null);
220+
return results[0];
221+
}
222+
223+
throw new NotImplementedException("");
224+
}
225+
190226
/// <summary>
191227
/// Creates a dataset that asynchronously prefetches elements from `input_dataset`.
192228
/// </summary>
@@ -354,6 +390,33 @@ public ITensorOrOperation make_iterator(Tensor dataset, Tensor iterator, string
354390
throw new NotImplementedException("");
355391
}
356392

393+
/// <summary>
394+
///
395+
/// </summary>
396+
/// <param name="dataset"></param>
397+
/// <param name="iterator"></param>
398+
/// <param name="name"></param>
399+
/// <returns></returns>
400+
public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShape[] output_shapes,
401+
bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null)
402+
{
403+
if (tf.Context.executing_eagerly())
404+
{
405+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
406+
"MapDataset", name,
407+
null,
408+
dataset, new Tensor[0],
409+
"f", "MapDataset",
410+
"output_types", output_types,
411+
"output_shapes", output_shapes,
412+
"use_inter_op_parallelism", use_inter_op_parallelism,
413+
"preserve_cardinality", preserve_cardinality);
414+
return results[0];
415+
}
416+
417+
throw new NotImplementedException("");
418+
}
419+
357420
/// <summary>
358421
/// A container for an iterator resource.
359422
/// </summary>

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Collections.Generic;
44
using System.Linq;
55
using System.Text;
6+
using Tensorflow;
67
using Tensorflow.Keras;
78
using Tensorflow.UnitTest;
89
using static Tensorflow.Binding;
@@ -116,5 +117,35 @@ public void Skip()
116117
value ++;
117118
}
118119
}
120+
121+
[TestMethod, Ignore]
122+
public void Map()
123+
{
124+
long value = 0;
125+
126+
var dataset = tf.data.Dataset.range(3);
127+
var dataset1 = dataset.map(x => x);
128+
129+
foreach (var item in dataset)
130+
{
131+
Assert.AreEqual(value, (long)item.Item1);
132+
value++;
133+
}
134+
}
135+
136+
[TestMethod]
137+
public void Cache()
138+
{
139+
long value = 0;
140+
141+
var dataset = tf.data.Dataset.range(5);
142+
dataset = dataset.cache();
143+
144+
foreach (var item in dataset)
145+
{
146+
Assert.AreEqual(value, (long)item.Item1);
147+
value++;
148+
}
149+
}
119150
}
120151
}

0 commit comments

Comments
 (0)