Skip to content

Commit dd1b589

Browse files
committed
Tensor.Flatten
1 parent e2190c9 commit dd1b589

File tree

4 files changed

+30
-1
lines changed

4 files changed

+30
-1
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using NumSharp;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public partial class Tensor
9+
{
10+
public object[] Flatten()
11+
{
12+
return new Tensor[] { this };
13+
}
14+
}
15+
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ namespace Tensorflow
3939
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
4040
/// </summary>
4141
[SuppressMessage("ReSharper", "ConvertToAutoProperty")]
42-
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray, IPackable<Tensor>
42+
public partial class Tensor : DisposableObject,
43+
ITensorOrOperation,
44+
_TensorLike,
45+
ITensorOrTensorArray,
46+
IPackable<Tensor>,
47+
ICanBeFlattened
4348
{
4449
private readonly int _id;
4550
private readonly Operation _op;

src/TensorFlowNET.Core/Tensors/TensorArray.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,11 @@ public TensorArray unstack(Tensor value, string name = null)
6161

6262
public Tensor read(Tensor index, string name = null)
6363
=> _implementation.read(index, name: name);
64+
65+
public TensorArray write(Tensor index, Tensor value, string name = null)
66+
=> _implementation.write(index, value, name: name);
67+
68+
public Tensor stack(string name = null)
69+
=> _implementation.stack(name: name);
6470
}
6571
}

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ public static class dtypes
3333
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
3434
public static TF_DataType float16 = TF_DataType.TF_HALF;
3535
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
36+
public static TF_DataType complex = TF_DataType.TF_COMPLEX;
37+
public static TF_DataType complex64 = TF_DataType.TF_COMPLEX64;
38+
public static TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
3639
public static TF_DataType variant = TF_DataType.TF_VARIANT;
3740
public static TF_DataType resource = TF_DataType.TF_RESOURCE;
3841

0 commit comments

Comments
 (0)