Skip to content

Commit bbc2e98

Browse files
committed
IndexLookup, Accumulator.
1 parent a942564 commit bbc2e98

File tree

16 files changed

+202
-6
lines changed

16 files changed

+202
-6
lines changed

src/TensorFlowNET.Core/APIs/tf.strings.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ public Tensor substr(Tensor input, int pos, int len,
6464
/// <summary>
/// Returns the substring of <paramref name="input"/> starting at <paramref name="pos"/>
/// with length <paramref name="len"/>, delegating to ops.substr.
/// </summary>
/// <param name="input">Source string to slice.</param>
/// <param name="pos">Start position of the substring.</param>
/// <param name="len">Length of the substring.</param>
/// <param name="name">Optional name for the operation.</param>
/// <param name="uint">Unit used to interpret pos/len; defaults to "BYTE".</param>
public Tensor substr(string input, int pos, int len,
    string name = null, string @uint = "BYTE")
    => ops.substr(input, pos, len, @uint: @uint, name: name);
67+
68+
/// <summary>
/// Splits each element of <paramref name="input"/> around <paramref name="sep"/>,
/// forwarding all arguments to ops.string_split_v2.
/// </summary>
/// <param name="input">String tensor whose elements are split.</param>
/// <param name="sep">Separator string; defaults to empty.</param>
/// <param name="maxsplit">Maximum number of splits; -1 means no limit.</param>
/// <param name="name">Optional name for the operation.</param>
public Tensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
{
    return ops.string_split_v2(input, sep: sep, maxsplit: maxsplit, name: name);
}
6770
}
6871
}
6972
}

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ public IDatasetV2 map(Func<Tensors, Tensors> map_func,
6868
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
6969
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
7070

71+
/// <summary>
/// Creates an <see cref="OwnedIterator"/> over this dataset.
/// Only eager execution is supported; graph mode is not implemented.
/// </summary>
/// <returns>An iterator that owns its underlying iterator resource.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when not executing eagerly.
/// </exception>
public OwnedIterator make_one_shot_iterator()
{
    if (tf.Context.executing_eagerly())
    {
        // python version wraps this in ops.colocate_with(self._variant_tensor)
        return new OwnedIterator(this);
    }

    // Fix: the original threw with an empty message, which makes the
    // failure undiagnosable at the call site.
    throw new NotImplementedException("make_one_shot_iterator is only supported in eager mode.");
}
81+
7182
public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
7283
=> new FlatMapDataset(this, map_func);
7384

src/TensorFlowNET.Core/Data/IDatasetV2.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ IDatasetV2 map(Func<Tensors, Tensors> map_func,
7272
IDatasetV2 map(Func<Tensors, Tensors> map_func,
7373
int num_parallel_calls);
7474

75+
OwnedIterator make_one_shot_iterator();
76+
7577
IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);
7678

7779
IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);

src/TensorFlowNET.Core/Data/OwnedIterator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ void _create_iterator(IDatasetV2 dataset)
2626
dataset = dataset.apply_options();
2727
_dataset = dataset;
2828
_element_spec = dataset.element_spec;
29+
// _flat_output_types =
2930
(_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
3031
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
3132
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ public class TextVectorizationArgs : PreprocessingLayerArgs
1111
public int MaxTokens { get; set; } = -1;
1212
public string OutputMode { get; set; } = "int";
1313
public int OutputSequenceLength { get; set; } = -1;
14+
public string[] Vocabulary { get; set; }
1415
}
1516
}

src/TensorFlowNET.Core/Operations/string_ops.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,10 @@ public Tensor substr<T>(T input, int pos, int len,
4141
string @uint = "BYTE", string name = null)
4242
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
4343
.SetAttributes(new { unit = @uint }));
44+
45+
/// <summary>
/// Splits each element of <paramref name="input"/> into substrings around <paramref name="sep"/>.
/// NOTE(review): not implemented yet — presumably intended to wrap the
/// StringSplitV2 TF op; confirm semantics when implementing.
/// </summary>
/// <param name="input">String tensor to split.</param>
/// <param name="sep">Separator string; defaults to empty.</param>
/// <param name="maxsplit">Maximum number of splits; -1 means no limit.</param>
/// <param name="name">Optional name for the operation.</param>
/// <exception cref="NotImplementedException">
/// Always — the original stub silently returned null, which would surface
/// as a confusing NullReferenceException far from the cause; fail fast instead.
/// </exception>
public Tensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
{
    throw new NotImplementedException("string_split_v2");
}
4449
}
4550
}

src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,23 @@ namespace Tensorflow.Keras.Engine
88
/// <summary>
/// Base class for preprocessing layers that fit their state to a dataset
/// via an <see cref="ICombiner"/>.
/// </summary>
public class CombinerPreprocessingLayer : Layer
{
    // Layer construction arguments, retained for derived layers.
    PreprocessingLayerArgs args;
    // Combiner that accumulates statistics over the dataset; assigned by subclasses.
    protected ICombiner combiner;
    // Whether a previous adapt() call has already updated state.
    protected bool _previously_updated;

    public CombinerPreprocessingLayer(PreprocessingLayerArgs args)
        : base(args)
    {
        // Fix: the args field was declared but never assigned, so it stayed
        // null for the lifetime of the layer.
        this.args = args;
        _previously_updated = false;
    }

    /// <summary>
    /// Fits this layer's state to the given dataset.
    /// NOTE(review): work in progress — the restored accumulator is not yet
    /// consumed, and only the first element is drawn from the iterator.
    /// </summary>
    /// <param name="data">Dataset to adapt to.</param>
    /// <param name="reset_state">When false, continue from the combiner's restored state.</param>
    public virtual void adapt(IDatasetV2 data, bool reset_state = true)
    {
        IAccumulator accumulator;
        if (!reset_state)
            accumulator = combiner.Restore();

        var next_data = data.make_one_shot_iterator();
        var data_element = next_data.next();
    }
}
1830
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
    /// <summary>
    /// Marker interface for the intermediate state produced and consumed by an
    /// ICombiner while a preprocessing layer adapts to a dataset.
    /// No members yet — presumably each concrete accumulator defines its own
    /// storage; TODO confirm as combiner implementations land.
    /// </summary>
    public interface IAccumulator
    {
    }
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
    /// <summary>
    /// Functional object that defines a shardable computation.
    /// </summary>
    public interface ICombiner
    {
        /// <summary>
        /// Folds <paramref name="values"/> into <paramref name="accumulator"/>
        /// (presumably a fresh one when null — TODO confirm once implemented).
        /// </summary>
        void Compute(Tensor values, IAccumulator accumulator = null);
        /// <summary>Merges accumulators. NOTE(review): takes no arguments yet; signature looks incomplete.</summary>
        void Merge();
        /// <summary>Extracts the final result from an accumulator. NOTE(review): takes no arguments yet; signature looks incomplete.</summary>
        void Extract();
        /// <summary>Restores a previously saved accumulator.</summary>
        IAccumulator Restore();
        /// <summary>Serializes an accumulator. NOTE(review): returns void; signature looks incomplete.</summary>
        void Serialize();
        /// <summary>Deserializes an accumulator. NOTE(review): returns void; signature looks incomplete.</summary>
        void Deserialize();
    }
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Preprocessing layer that maps tokens to indices in a vocabulary,
    /// reserving slots for out-of-vocabulary and mask tokens.
    /// </summary>
    public class IndexLookup : CombinerPreprocessingLayer
    {
        public IndexLookup(int max_tokens = -1,
            int num_oov_indices = 1,
            string mask_token = "",
            string oov_token = "[UNK]",
            string encoding = "utf-8",
            bool invert = false) : base(new PreprocessingLayerArgs())
        {
            // A non-null mask token (including the empty string) takes one slot.
            var mask_slots = mask_token == null ? 0 : 1;
            // Capacity left for real tokens after the reserved slots.
            // NOTE(review): with the default max_tokens == -1 this goes negative;
            // presumably the combiner treats that as "unlimited" — TODO confirm.
            var reserved = num_oov_indices + mask_slots;
            combiner = new IndexLookupCombiner(max_tokens - reserved, mask_token);
        }

        /// <summary>
        /// Fits this layer's vocabulary to the dataset. Streaming adapts
        /// (reset_state == false) are not supported.
        /// </summary>
        public override void adapt(IDatasetV2 data, bool reset_state = true)
        {
            if (!reset_state)
                throw new ValueError("IndexLookup does not support streaming adapts.");

            base.adapt(data, reset_state);
        }
    }
}

0 commit comments

Comments
 (0)