Skip to content

Commit 5599215

Browse files
committed
Fix string_split_v2 return RaggedTensor.
1 parent a1ebd70 commit 5599215

File tree

10 files changed

+144
-26
lines changed

10 files changed

+144
-26
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,28 @@ public Tensor log(Tensor x, string name = null)
3232
/// <returns></returns>
3333
public Tensor erf(Tensor x, string name = null)
3434
=> math_ops.erf(x, name);
35+
36+
/// <summary>
37+
///
38+
/// </summary>
39+
/// <param name="arr"></param>
40+
/// <param name="weights"></param>
41+
/// <param name="minlength"></param>
42+
/// <param name="maxlength"></param>
43+
/// <param name="dtype"></param>
44+
/// <param name="name"></param>
45+
/// <param name="axis"></param>
46+
/// <param name="binary_output"></param>
47+
/// <returns></returns>
48+
public Tensor bincount(Tensor arr, Tensor weights = null,
49+
Tensor minlength = null,
50+
Tensor maxlength = null,
51+
TF_DataType dtype = TF_DataType.TF_INT32,
52+
string name = null,
53+
TensorShape axis = null,
54+
bool binary_output = false)
55+
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
56+
dtype: dtype, name: name, axis: axis, binary_output: binary_output);
3557
}
3658

3759
public Tensor abs(Tensor x, string name = null)

src/TensorFlowNET.Core/APIs/tf.strings.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public Tensor substr(string input, int pos, int len,
6767
string name = null, string @uint = "BYTE")
6868
=> ops.substr(input, pos, len, @uint: @uint, name: name);
6969

70-
public SparseTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
70+
public RaggedTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
7171
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);
7272
}
7373
}

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,13 +249,6 @@ public static Tensor cosh(Tensor x, string name = null)
249249
return _op.outputs[0];
250250
}
251251

252-
public static Tensor cumsum<T>(Tensor x, T axis, bool exclusive = false, bool reverse = false, string name = null)
253-
{
254-
var _op = tf.OpDefLib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse });
255-
256-
return _op.outputs[0];
257-
}
258-
259252
/// <summary>
260253
/// Computes the sum along segments of a tensor.
261254
/// </summary>

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,12 @@ public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, st
168168
}
169169

170170
public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
171-
{
172-
return tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
173-
{
174-
name = scope;
175-
x = ops.convert_to_tensor(x, name: "x");
176-
177-
return gen_math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);
178-
});
179-
}
171+
=> tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
172+
{
173+
name = scope;
174+
return tf.Context.ExecuteOp("Cumsum", name, new ExecuteOpArgs(x, axis)
175+
.SetAttributes(new { exclusive, reverse }));
176+
});
180177

181178
/// <summary>
182179
/// Computes Psi, the derivative of Lgamma (the log of the absolute value of
@@ -807,6 +804,31 @@ public static Tensor batch_matmul(Tensor x, Tensor y,
807804
.SetAttributes(new { adj_x, adj_y }));
808805
});
809806

807+
public static Tensor bincount(Tensor arr, Tensor weights = null,
808+
Tensor minlength = null,
809+
Tensor maxlength = null,
810+
TF_DataType dtype = TF_DataType.TF_INT32,
811+
string name = null,
812+
TensorShape axis = null,
813+
bool binary_output = false)
814+
=> tf_with(ops.name_scope(name, "bincount"), scope =>
815+
{
816+
name = scope;
817+
if(!binary_output && axis == null)
818+
{
819+
var array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0;
820+
var output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (math_ops.reduce_max(arr) + 1);
821+
if (minlength != null)
822+
output_size = math_ops.maximum(minlength, output_size);
823+
if (maxlength != null)
824+
output_size = math_ops.minimum(maxlength, output_size);
825+
var weights = constant_op.constant(new long[0], dtype: dtype);
826+
return tf.Context.ExecuteOp("Bincount", name, new ExecuteOpArgs(arr, output_size, weights));
827+
}
828+
829+
throw new NotImplementedException("");
830+
});
831+
810832
/// <summary>
811833
/// Returns the complex conjugate of a complex number.
812834
/// </summary>

src/TensorFlowNET.Core/Operations/string_ops.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using NumSharp;
1718
using Tensorflow.Framework;
1819
using static Tensorflow.Binding;
1920

@@ -43,7 +44,7 @@ public Tensor substr<T>(T input, int pos, int len,
4344
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
4445
.SetAttributes(new { unit = @uint }));
4546

46-
public SparseTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
47+
public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
4748
{
4849
return tf_with(ops.name_scope(name, "StringSplit"), scope =>
4950
{
@@ -60,7 +61,12 @@ public SparseTensor string_split_v2(Tensor input, string sep = "", int maxsplit
6061
indices.set_shape(new TensorShape(-1, 2));
6162
values.set_shape(new TensorShape(-1));
6263
shape.set_shape(new TensorShape(2));
63-
return new SparseTensor(indices, values, shape);
64+
65+
var sparse_result = new SparseTensor(indices, values, shape);
66+
return RaggedTensor.from_value_rowids(sparse_result.values,
67+
value_rowids: sparse_result.indices[Slice.All, 0],
68+
nrows: sparse_result.dense_shape[0],
69+
validate: false);
6470
});
6571
}
6672
}

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.</PackageReleaseNotes
5050
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
5151
<DefineConstants>TRACE;DEBUG</DefineConstants>
5252
<PlatformTarget>x64</PlatformTarget>
53+
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
5354
</PropertyGroup>
5455

5556
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">

src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using System;
1818
using System.Collections.Generic;
1919
using System.Text;
20+
using System.Linq;
2021
using Tensorflow.Framework;
2122
using static Tensorflow.Binding;
2223

@@ -27,9 +28,30 @@ namespace Tensorflow
2728
/// </summary>
2829
public class RaggedTensor : CompositeTensor
2930
{
30-
public RaggedTensor(Tensor values, RowPartition row_partition, bool validate = true)
31+
Tensor _values;
32+
RowPartition _row_partition;
33+
public TF_DataType dtype => _values.dtype;
34+
public TensorShape shape
3135
{
36+
get
37+
{
38+
var nrows = _row_partition.static_nrows;
39+
var ncols = _row_partition.static_uniform_row_length;
40+
return new TensorShape(nrows, ncols);
41+
}
42+
}
3243

44+
public RaggedTensor(Tensor values,
45+
bool @internal = true,
46+
RowPartition row_partition = null)
47+
{
48+
_values = values;
49+
_row_partition = row_partition;
50+
}
51+
52+
public static RaggedTensor from_row_partition(Tensor values, RowPartition row_partition, bool validate = true)
53+
{
54+
return new RaggedTensor(values, @internal: true, row_partition: row_partition);
3355
}
3456

3557
/// <summary>
@@ -49,8 +71,21 @@ public static RaggedTensor from_value_rowids(Tensor values, Tensor value_rowids,
4971
var row_partition = RowPartition.from_value_rowids(value_rowids,
5072
nrows: nrows,
5173
validate: validate);
52-
return new RaggedTensor(values, row_partition, validate: validate);
74+
return from_row_partition(values, row_partition, validate: validate);
5375
});
5476
}
77+
78+
public override string ToString()
79+
=> $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]";
80+
81+
public static implicit operator Tensor(RaggedTensor indexedSlices)
82+
{
83+
return indexedSlices._values;
84+
}
85+
86+
public static implicit operator RaggedTensor(Tensor tensor)
87+
{
88+
return tensor.Tag as RaggedTensor;
89+
}
5590
}
5691
}

src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,35 @@ namespace Tensorflow
2727
/// </summary>
2828
public class RowPartition : CompositeTensor
2929
{
30+
Tensor _row_splits;
31+
Tensor _row_lengths;
32+
Tensor _value_rowids;
33+
Tensor _nrows;
34+
35+
public int static_nrows
36+
{
37+
get
38+
{
39+
return _row_splits.shape[0] - 1;
40+
}
41+
}
42+
43+
public int static_uniform_row_length
44+
{
45+
get
46+
{
47+
return -1;
48+
}
49+
}
50+
3051
public RowPartition(Tensor row_splits,
3152
Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null,
3253
Tensor uniform_row_length = null)
3354
{
34-
55+
_row_splits = row_splits;
56+
_row_lengths = row_lengths;
57+
_value_rowids = value_rowids;
58+
_nrows = nrows;
3559
}
3660

3761
/// <summary>
@@ -47,8 +71,18 @@ public static RowPartition from_value_rowids(Tensor value_rowids,
4771
{
4872
return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope =>
4973
{
50-
Tensor row_lengths = null;
51-
Tensor row_splits = null;
74+
var value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32);
75+
var nrows_int32 = math_ops.cast(nrows, dtypes.int32);
76+
var row_lengths = tf.math.bincount(value_rowids_int32,
77+
minlength: nrows_int32,
78+
maxlength: nrows_int32,
79+
dtype: value_rowids.dtype);
80+
var row_splits = array_ops.concat(new object[]
81+
{
82+
ops.convert_to_tensor(new long[] { 0 }),
83+
tf.cumsum(row_lengths)
84+
}, axis: 0);
85+
5286
return new RowPartition(row_splits,
5387
row_lengths: row_lengths,
5488
value_rowids: value_rowids,

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
4949
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
5050
</PropertyGroup>
5151

52+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
53+
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile>
54+
</PropertyGroup>
55+
5256
<ItemGroup>
5357
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
5458
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />

test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ public void StringArray()
6262
[TestMethod]
6363
public void StringSplit()
6464
{
65-
var tensor = tf.constant(new[] { "hello world", "tensorflow .net" });
66-
tf.strings.split(tensor);
65+
var tensor = tf.constant(new[] { "hello world", "tensorflow .net csharp", "fsharp" });
66+
var ragged_tensor = tf.strings.split(tensor);
67+
Assert.AreEqual((3, -1), ragged_tensor.shape);
6768
}
6869
}
6970
}

0 commit comments

Comments
 (0)