Skip to content

Commit 0142174

Browse files
committed
Add tf.string.string_length.
1 parent 5599215 commit 0142174

File tree

7 files changed

+179
-8
lines changed

7 files changed

+179
-8
lines changed

src/TensorFlowNET.Core/APIs/tf.strings.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,26 @@ public Tensor substr(string input, int pos, int len,
6767
string name = null, string @uint = "BYTE")
6868
=> ops.substr(input, pos, len, @uint: @uint, name: name);
6969

70+
/// <summary>
71+
/// String lengths of `input`.
72+
/// </summary>
73+
/// <param name="input"></param>
74+
/// <param name="name"></param>
75+
/// <param name="unit"></param>
76+
/// <returns></returns>
77+
public Tensor string_length(Tensor input, string name = null, string unit = "BYTE")
78+
=> ops.string_length(input, name: name, unit: unit);
79+
7080
public RaggedTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
7181
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);
82+
83+
public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding,
84+
string errors = "replace", int replacement_char = 0xFFFD,
85+
bool replace_control_characters = false, string name = null)
86+
=> ops.unicode_decode_with_offsets(input, input_encoding, errors,
87+
replacement_char: replacement_char,
88+
replace_control_characters: replace_control_characters,
89+
name: name);
7290
}
7391
}
7492
}

src/TensorFlowNET.Core/Operations/string_ops.cs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,22 @@ public Tensor substr<T>(T input, int pos, int len,
4444
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
4545
.SetAttributes(new { unit = @uint }));
4646

47+
/// <summary>
48+
/// Computes the length of each string given in the input tensor.
49+
/// </summary>
50+
/// <param name="input"></param>
51+
/// <param name="name"></param>
52+
/// <param name="unit"></param>
53+
/// <returns></returns>
54+
public Tensor string_length(Tensor input, string name = null, string unit = "BYTE")
55+
=> tf.Context.ExecuteOp("StringLength", name, new ExecuteOpArgs(input)
56+
{
57+
GetGradientAttrs = op => new
58+
{
59+
unit = op.get_attr<string>("unit")
60+
}
61+
}.SetAttributes(new { unit }));
62+
4763
public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
4864
{
4965
return tf_with(ops.name_scope(name, "StringSplit"), scope =>
@@ -69,5 +85,49 @@ public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit
6985
validate: false);
7086
});
7187
}
88+
89+
public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding, string errors,
90+
int replacement_char = 0xFFFD, bool replace_control_characters = false, string name = null)
91+
{
92+
return tf_with(ops.name_scope(name, "UnicodeDecodeWithOffsets"), scope =>
93+
{
94+
var (codepoints, byte_start_offsets) = _unicode_decode(input, input_encoding, errors,
95+
replacement_char, replace_control_characters,
96+
with_offsets: true, name: name);
97+
return (codepoints, byte_start_offsets);
98+
});
99+
}
100+
101+
(RaggedTensor, RaggedTensor) _unicode_decode(Tensor input, string input_encoding, string errors, int replacement_char,
102+
bool replace_control_characters, bool with_offsets, string name = null)
103+
{
104+
if (with_offsets)
105+
{
106+
var flat_result = tf.Context.ExecuteOp("UnicodeDecodeWithOffsets", name, new ExecuteOpArgs(input)
107+
{
108+
GetGradientAttrs = op => new
109+
{
110+
input_encoding = op.get_attr<string>("input_encoding"),
111+
errors = op.get_attr<string>("errors"),
112+
replacement_char = op.get_attr<int>("replacement_char"),
113+
replace_control_characters = op.get_attr<bool>("replace_control_characters"),
114+
Tsplits = op.get_attr<TF_DataType>("Tsplits")
115+
}
116+
}.SetAttributes(new
117+
{
118+
input_encoding,
119+
errors,
120+
replacement_char,
121+
replace_control_characters
122+
}));
123+
124+
var codepoints = RaggedTensor.from_row_splits(flat_result[1], flat_result[0], validate: false);
125+
126+
var offsets = RaggedTensor.from_row_splits(flat_result[2], flat_result[0], validate: false);
127+
return (codepoints, offsets);
128+
}
129+
130+
return (null, null);
131+
}
72132
}
73133
}

src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
using System.Linq;
2121
using Tensorflow.Framework;
2222
using static Tensorflow.Binding;
23+
using NumSharp;
2324

2425
namespace Tensorflow
2526
{
@@ -30,6 +31,8 @@ public class RaggedTensor : CompositeTensor
3031
{
3132
Tensor _values;
3233
RowPartition _row_partition;
34+
Tensor _row_splits => _row_partition.row_splits;
35+
3336
public TF_DataType dtype => _values.dtype;
3437
public TensorShape shape
3538
{
@@ -41,6 +44,28 @@ public TensorShape shape
4144
}
4245
}
4346

47+
public RaggedTensor this[params Slice[] slices]
48+
{
49+
get
50+
{
51+
var row_key = slices[0];
52+
var inner_keys = slices.Skip(1).ToArray();
53+
54+
var args = tensor_util.ParseSlices(slices);
55+
56+
return tf_with(ops.name_scope(null, "RaggedGetItem", args), scope =>
57+
{
58+
string name = scope;
59+
return _ragged_getitem_inner_dimensions(this, inner_keys);
60+
});
61+
}
62+
}
63+
64+
RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices)
65+
{
66+
return input;
67+
}
68+
4469
public RaggedTensor(Tensor values,
4570
bool @internal = true,
4671
RowPartition row_partition = null)
@@ -75,13 +100,44 @@ public static RaggedTensor from_value_rowids(Tensor values, Tensor value_rowids,
75100
});
76101
}
77102

103+
public static RaggedTensor from_row_splits(Tensor values, Tensor row_splits,
104+
string name = null, bool validate = true)
105+
{
106+
return tf_with(ops.name_scope(name, "RaggedFromRowSplits"), scope =>
107+
{
108+
var row_partition = RowPartition.from_row_splits(row_splits,
109+
validate: validate);
110+
return from_row_partition(values, row_partition, validate: validate);
111+
});
112+
}
113+
114+
Tensor _to_variant(bool batched_input = false, string name = null)
115+
=> tf_with(ops.name_scope(name, "RaggedToVariant"), scope =>
116+
{
117+
return tf.Context.ExecuteOp("RaggedTensorToVariant", name,
118+
new ExecuteOpArgs(nested_row_splits, flat_values)
119+
{
120+
GetGradientAttrs = op => new
121+
{
122+
RAGGED_RANK = op.get_attr<int>("RAGGED_RANK"),
123+
Tvalues = op.get_attr<TF_DataType>("Tvalues"),
124+
Tsplits = op.get_attr<TF_DataType>("Tsplits"),
125+
batched_input = op.get_attr<bool>("batched_input")
126+
}
127+
}.SetAttributes(new { batched_input }));
128+
});
129+
130+
Tensor flat_values
131+
=> _values;
132+
133+
Tensor[] nested_row_splits
134+
=> new[] { _row_splits };
135+
78136
public override string ToString()
79137
=> $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]";
80138

81139
public static implicit operator Tensor(RaggedTensor indexedSlices)
82-
{
83-
return indexedSlices._values;
84-
}
140+
=> indexedSlices._to_variant();
85141

86142
public static implicit operator RaggedTensor(Tensor tensor)
87143
{

src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace Tensorflow
2828
public class RowPartition : CompositeTensor
2929
{
3030
Tensor _row_splits;
31+
public Tensor row_splits => _row_splits;
3132
Tensor _row_lengths;
3233
Tensor _value_rowids;
3334
Tensor _nrows;
@@ -89,5 +90,14 @@ public static RowPartition from_value_rowids(Tensor value_rowids,
8990
nrows: nrows);
9091
});
9192
}
93+
94+
public static RowPartition from_row_splits(Tensor row_splits,
95+
bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
96+
{
97+
return tf_with(ops.name_scope(null, "RowPartitionFromRowSplits"), scope =>
98+
{
99+
return new RowPartition(row_splits);
100+
});
101+
}
92102
}
93103
}

src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,9 @@ Tensors _preprocess(Tensors inputs)
5555
if (inputs.shape.ndim > 1)
5656
input_tensor = array_ops.squeeze(inputs, axis: new[] { -1 });
5757
if (args.Split == "whitespace")
58-
input_tensor = tf.strings.split(inputs);
59-
58+
input_tensor = tf.strings.split(input_tensor);
6059
}
61-
return inputs;
60+
return input_tensor;
6261
}
6362
}
6463
}

src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
using System;
1+
using NumSharp;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
5+
using static Tensorflow.Binding;
46

57
namespace Tensorflow.Text.Tokenizers
68
{
@@ -13,7 +15,31 @@ public class WhitespaceTokenizer : ITokenizer
1315
/// <returns></returns>
1416
public Tensor tokenize(Tensor input)
1517
{
18+
tokenize_with_offsets(input);
1619
throw new NotImplementedException("");
1720
}
21+
22+
Tensor[] tokenize_with_offsets(Tensor input)
23+
{
24+
tf_with(ops.name_scope(null, "WhitespaceTokenize"), scope =>
25+
{
26+
_whitespace_tokenize_with_offsets_encode_decode_wrapper(input);
27+
});
28+
throw new NotImplementedException("");
29+
}
30+
31+
Tensor _whitespace_tokenize_with_offsets_encode_decode_wrapper(Tensor input_tensor)
32+
{
33+
// Decode the strings and get byte offsets
34+
var (codepoints, byte_start_offsets) = tf.strings.unicode_decode_with_offsets(input_tensor, "UTF-8");
35+
var byte_end_offsets = array_ops.concat(new Tensor[]
36+
{
37+
byte_start_offsets[Slice.All, new Slice(1)],
38+
math_ops.cast(
39+
array_ops.expand_dims(tf.strings.string_length(input_tensor), 1),
40+
dtypes.int64)
41+
}, 1);
42+
return input_tensor;
43+
}
1844
}
1945
}

test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ namespace TensorFlowNET.UnitTest.Text
1010
[TestClass]
1111
public class TokenizerTest
1212
{
13-
[TestMethod]
13+
[TestMethod, Ignore]
1414
public void Tokenize()
1515
{
1616
var docs = tf.constant(new[] { "Everything not saved will be lost." });
17+
var tokenizer = text.WhitespaceTokenizer();
18+
var tokens = tokenizer.tokenize(docs);
1719
}
1820
}
1921
}

0 commit comments

Comments
 (0)