Skip to content

Commit a1ebd70

Browse files
committed
RowPartition and RaggedTensor
1 parent bbc2e98 commit a1ebd70

File tree

13 files changed

+234
-85
lines changed

13 files changed

+234
-85
lines changed

src/TensorFlowNET.Core/APIs/tf.sparse.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System;
1718
using Tensorflow.Framework;
1819

1920
namespace Tensorflow
2021
{
2122
public partial class tensorflow
2223
{
23-
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, long[] dense_shape)
24-
=> new SparseTensor<T>(indices, values, dense_shape);
24+
public SparseTensor SparseTensor(long[,] indices, Array values, long[] dense_shape)
25+
=> new SparseTensor(indices, values, dense_shape);
2526

26-
public Tensor sparse_tensor_to_dense<T>(SparseTensor<T> sp_input,
27-
T default_value = default,
27+
public Tensor sparse_tensor_to_dense(SparseTensor sp_input,
28+
Array default_value = default,
2829
bool validate_indices = true,
2930
string name = null)
3031
=> gen_sparse_ops.sparse_to_dense(sp_input.indices,

src/TensorFlowNET.Core/APIs/tf.strings.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Framework;
18+
1719
namespace Tensorflow
1820
{
1921
public partial class tensorflow
@@ -65,7 +67,7 @@ public Tensor substr(string input, int pos, int len,
6567
string name = null, string @uint = "BYTE")
6668
=> ops.substr(input, pos, len, @uint: @uint, name: name);
6769

68-
public Tensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
70+
public SparseTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
6971
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);
7072
}
7173
}

src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs

Lines changed: 0 additions & 63 deletions
This file was deleted.

src/TensorFlowNET.Core/Framework/tensor_shape.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ bool _shape_is_compatible_0dim(Shape _this, Shape _other)
4444
return true;
4545
}
4646

47-
if (other.is_sparse())
47+
if (other.IsSparseTensor)
4848
{
4949
return self.dtype.is_compatible_with(other.dtype);
5050
}
5151

5252
return self.dtype.is_compatible_with(other.dtype) &&
5353
_shape_is_compatible_0dim(self.shape, other.shape) &&
54-
!self.is_sparse();
54+
!self.IsSparseTensor;
5555
}
5656

5757
public static Dimension dimension_at_index(TensorShape shape, int index)

src/TensorFlowNET.Core/Operations/string_ops.cs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Framework;
1718
using static Tensorflow.Binding;
1819

1920
namespace Tensorflow
@@ -42,9 +43,25 @@ public Tensor substr<T>(T input, int pos, int len,
4243
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
4344
.SetAttributes(new { unit = @uint }));
4445

45-
public Tensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
46+
public SparseTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
4647
{
47-
return null;
48+
return tf_with(ops.name_scope(name, "StringSplit"), scope =>
49+
{
50+
var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING);
51+
var result = tf.Context.ExecuteOp("StringSplitV2", name,
52+
new ExecuteOpArgs(input, sep)
53+
{
54+
GetGradientAttrs = op => new
55+
{
56+
maxsplit = op.get_attr<int>("maxsplit")
57+
}
58+
}.SetAttributes(new { maxsplit }));
59+
var (indices, values, shape) = (result[0], result[1], result[2]);
60+
indices.set_shape(new TensorShape(-1, 2));
61+
values.set_shape(new TensorShape(-1));
62+
shape.set_shape(new TensorShape(2));
63+
return new SparseTensor(indices, values, shape);
64+
});
4865
}
4966
}
5067
}

src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
namespace Tensorflow
99
{
10-
public class EagerTensorV2 : DisposableObject, ITensor
10+
public class EagerTensorV2 : DisposableObject
1111
{
1212
SafeTensorHandleHandle EagerTensorHandle;
1313
public string Device

src/TensorFlowNET.Core/Tensors/ITensor.cs

Lines changed: 0 additions & 7 deletions
This file was deleted.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*****************************************************************************
2+
Copyright 2021 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using System.Collections.Generic;
19+
using System.Text;
20+
using Tensorflow.Framework;
21+
using static Tensorflow.Binding;
22+
23+
namespace Tensorflow
24+
{
25+
/// <summary>
26+
/// Represents a ragged tensor.
27+
/// </summary>
28+
public class RaggedTensor : CompositeTensor
29+
{
30+
public RaggedTensor(Tensor values, RowPartition row_partition, bool validate = true)
31+
{
32+
33+
}
34+
35+
/// <summary>
36+
/// Creates a `RaggedTensor` with rows partitioned by `value_rowids`.
37+
/// </summary>
38+
/// <param name="values"></param>
39+
/// <param name="value_rowids"></param>
40+
/// <param name="nrows"></param>
41+
/// <param name="name"></param>
42+
/// <param name="validate"></param>
43+
/// <returns></returns>
44+
public static RaggedTensor from_value_rowids(Tensor values, Tensor value_rowids,
45+
Tensor nrows = null, string name = null, bool validate = true)
46+
{
47+
return tf_with(ops.name_scope(name, "RaggedFromValueRowIds"), scope =>
48+
{
49+
var row_partition = RowPartition.from_value_rowids(value_rowids,
50+
nrows: nrows,
51+
validate: validate);
52+
return new RaggedTensor(values, row_partition, validate: validate);
53+
});
54+
}
55+
}
56+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*****************************************************************************
2+
Copyright 2021 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using System.Collections.Generic;
19+
using System.Text;
20+
using Tensorflow.Framework;
21+
using static Tensorflow.Binding;
22+
23+
namespace Tensorflow
24+
{
25+
/// <summary>
26+
/// Partitioning of a sequence of values into contiguous subsequences ("rows").
27+
/// </summary>
28+
public class RowPartition : CompositeTensor
29+
{
30+
public RowPartition(Tensor row_splits,
31+
Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null,
32+
Tensor uniform_row_length = null)
33+
{
34+
35+
}
36+
37+
/// <summary>
38+
/// Creates a `RowPartition` with rows partitioned by `value_rowids`.
39+
/// </summary>
40+
/// <param name="value_rowids"></param>
41+
/// <param name="nrows"></param>
42+
/// <param name="validate"></param>
43+
/// <param name="preferred_dtype"></param>
44+
/// <returns></returns>
45+
public static RowPartition from_value_rowids(Tensor value_rowids,
46+
Tensor nrows = null, bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
47+
{
48+
return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope =>
49+
{
50+
Tensor row_lengths = null;
51+
Tensor row_splits = null;
52+
return new RowPartition(row_splits,
53+
row_lengths: row_lengths,
54+
value_rowids: value_rowids,
55+
nrows: nrows);
56+
});
57+
}
58+
}
59+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*****************************************************************************
2+
Copyright 2021 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using System.Linq;
19+
using Tensorflow.Framework;
20+
using static Tensorflow.Binding;
21+
22+
namespace Tensorflow
23+
{
24+
/// <summary>
25+
/// Represents a sparse tensor.
26+
/// </summary>
27+
public class SparseTensor : CompositeTensor
28+
{
29+
public Tensor indices;
30+
31+
public Tensor values;
32+
33+
public Tensor dense_shape;
34+
35+
public SparseTensor(Tensor indices, Tensor values, Tensor dense_shape)
36+
{
37+
this.indices = indices;
38+
this.values = values;
39+
this.dense_shape = dense_shape;
40+
_init();
41+
}
42+
43+
public SparseTensor(long[,] indices_, Array values_, long[] dense_shape_)
44+
{
45+
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate
46+
{
47+
indices = ops.convert_to_tensor(
48+
indices_, name: "indices", dtype: dtypes.int64);
49+
values = ops.convert_to_tensor(values_, name: "values");
50+
dense_shape = ops.convert_to_tensor(
51+
dense_shape_, name: "dense_shape", dtype: dtypes.int64);
52+
});
53+
_init();
54+
}
55+
56+
void _init()
57+
{
58+
var indices_shape = indices.TensorShape.with_rank(2);
59+
var values_shape = values.TensorShape.with_rank(1);
60+
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);
61+
62+
indices_shape["0"].merge_with(values_shape[0]);
63+
indices_shape["1"].merge_with(dense_shape_shape[0]);
64+
}
65+
66+
public static implicit operator Tensor(SparseTensor indexedSlices)
67+
{
68+
return indexedSlices.values;
69+
}
70+
71+
public static implicit operator SparseTensor(Tensor tensor)
72+
{
73+
return tensor.Tag as SparseTensor;
74+
}
75+
}
76+
}

0 commit comments

Comments
 (0)