Skip to content

Commit 1f67488

Browse files
committed
Refactor string tensor.
1 parent 48d96f4 commit 1f67488

File tree

11 files changed

+150
-211
lines changed

11 files changed

+150
-211
lines changed

src/TensorFlowNET.Console/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ static void Main(string[] args)
2525

2626
FuncGraph(mm);
2727

28-
// 85M
28+
// 65M
2929
Console.WriteLine("Finished.");
3030
Console.ReadLine();
3131
}

src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -29,66 +29,6 @@ namespace Tensorflow
2929
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
3030
public partial class Tensor
3131
{
32-
public T ToScalar<T>()
33-
{
34-
unsafe
35-
{
36-
if (typeof(T).as_dtype() == this.dtype && this.dtype != TF_DataType.TF_STRING)
37-
return Unsafe.Read<T>(this.buffer.ToPointer());
38-
39-
switch (this.dtype)
40-
{
41-
#if _REGEN
42-
%foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase%
43-
case TF_DataType.#1:
44-
return Converts.ChangeType<T>(*(#3*) this.buffer);
45-
%
46-
#else
47-
48-
case TF_DataType.TF_UINT8:
49-
return Converts.ChangeType<T>(*(byte*)this.buffer);
50-
case TF_DataType.TF_INT16:
51-
return Converts.ChangeType<T>(*(short*)this.buffer);
52-
case TF_DataType.TF_UINT16:
53-
return Converts.ChangeType<T>(*(ushort*)this.buffer);
54-
case TF_DataType.TF_INT32:
55-
return Converts.ChangeType<T>(*(int*)this.buffer);
56-
case TF_DataType.TF_UINT32:
57-
return Converts.ChangeType<T>(*(uint*)this.buffer);
58-
case TF_DataType.TF_INT64:
59-
return Converts.ChangeType<T>(*(long*)this.buffer);
60-
case TF_DataType.TF_UINT64:
61-
return Converts.ChangeType<T>(*(ulong*)this.buffer);
62-
case TF_DataType.TF_DOUBLE:
63-
return Converts.ChangeType<T>(*(double*)this.buffer);
64-
case TF_DataType.TF_FLOAT:
65-
return Converts.ChangeType<T>(*(float*)this.buffer);
66-
#endif
67-
case TF_DataType.TF_STRING:
68-
if (this.NDims != 0)
69-
throw new ArgumentException($"{nameof(Tensor)} can only be scalar.");
70-
71-
IntPtr stringStartAddress = IntPtr.Zero;
72-
ulong dstLen = 0;
73-
74-
c_api.TF_StringDecode((byte*)this.buffer + 8, this.bytesize, (byte**)&stringStartAddress, ref dstLen, tf.Status.Handle);
75-
tf.Status.Check(true);
76-
77-
var dstLenInt = checked((int)dstLen);
78-
var value = Encoding.UTF8.GetString((byte*)stringStartAddress, dstLenInt);
79-
if (typeof(T) == typeof(string))
80-
return (T)(object)value;
81-
else
82-
return Converts.ChangeType<T>(value);
83-
84-
case TF_DataType.TF_COMPLEX64:
85-
case TF_DataType.TF_COMPLEX128:
86-
default:
87-
throw new NotSupportedException();
88-
}
89-
}
90-
}
91-
9232
public unsafe void CopyTo(NDArray nd)
9333
{
9434
if (!nd.Shape.IsContiguous)

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 13 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -457,53 +457,15 @@ public unsafe Tensor(Complex value, TF_DataType? dType = null)
457457
/// </summary>
458458
public unsafe Tensor(string str)
459459
{
460-
var buffer = Encoding.UTF8.GetBytes(str);
461-
var size = c_api.TF_StringEncodedSize((ulong)buffer.Length);
462-
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, null, 0, size + sizeof(ulong));
463-
AllocationType = AllocationType.Tensorflow;
464-
465-
IntPtr tensor = c_api.TF_TensorData(handle);
466-
Marshal.WriteInt64(tensor, 0);
467-
fixed (byte* src = buffer)
468-
c_api.TF_StringEncode(src, (ulong)buffer.Length, (byte*)(tensor + sizeof(long)), size, tf.Status.Handle);
469-
_handle = handle;
470-
tf.Status.Check(true);
460+
_handle = StringTensor(new string[] { str }, TensorShape.Scalar);
461+
#if TRACK_TENSOR_LIFE
462+
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
463+
#endif
471464
}
472465

473466
public unsafe Tensor(string[] strings)
474467
{
475-
// convert string array to byte[][]
476-
var buffer = new byte[strings.Length][];
477-
for (var i = 0; i < strings.Length; i++)
478-
buffer[i] = Encoding.UTF8.GetBytes(strings[i]);
479-
long[] shape = new long[] { strings.Length };
480-
481-
ulong size = 0;
482-
foreach (var b in buffer)
483-
size += TF_StringEncodedSize((ulong)b.Length);
484-
485-
ulong src_size = size + (ulong)buffer.Length * sizeof(ulong);
486-
_handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, src_size);
487-
AllocationType = AllocationType.Tensorflow;
488-
489-
// Clear offset table
490-
IntPtr input = TensorDataPointer;
491-
IntPtr data_start = input + buffer.Length * sizeof(ulong);
492-
IntPtr limit = input + (int)src_size;
493-
ulong offset = 0;
494-
for (int i = 0; i < buffer.Length; i++)
495-
{
496-
Marshal.WriteInt64(input, i * sizeof(ulong), (long)offset);
497-
fixed (byte* src = &buffer[i][0])
498-
{
499-
var written = TF_StringEncode(src, (ulong)buffer[i].Length, (byte*)data_start, (ulong)(limit.ToInt64() - data_start.ToInt64()), tf.Status.Handle);
500-
tf.Status.Check(true);
501-
//input += 8;
502-
data_start += (int)written;
503-
offset += written;
504-
}
505-
}
506-
468+
_handle = StringTensor(strings, new TensorShape(strings.Length));
507469
#if TRACK_TENSOR_LIFE
508470
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
509471
#endif
@@ -515,12 +477,12 @@ public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
515477
tensorDType = nd.dtype.as_dtype();
516478

517479
// todo: handle nd of type "String" here too
518-
if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte)
480+
/*if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte)
519481
{
520482
if (nd.Unsafe.Storage.Shape.IsContiguous)
521483
{
522484
var bytesLength = (ulong)nd.size;
523-
var size = c_api.TF_StringEncodedSize(bytesLength);
485+
var size = bytesLength + 1;
524486
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, null, 0, size + 8);
525487
AllocationType = AllocationType.Tensorflow;
526488
@@ -534,7 +496,7 @@ public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
534496
else
535497
{
536498
var buffer = nd.ToArray<byte>();
537-
var size = c_api.TF_StringEncodedSize((ulong)buffer.Length);
499+
var size = (ulong)buffer.Length + 1;
538500
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, null, 0, size + 8);
539501
AllocationType = AllocationType.Tensorflow;
540502
@@ -549,9 +511,12 @@ public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
549511
}
550512
551513
return;
552-
}
514+
}*/
553515

554516
CreateTensorFromNDArray(nd, tensorDType);
517+
#if TRACK_TENSOR_LIFE
518+
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} Data: 0x{TensorDataPointer.ToString("x16")}");
519+
#endif
555520
}
556521

557522
private unsafe void CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
@@ -576,10 +541,6 @@ private unsafe void CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype
576541
}
577542
else
578543
AllocationType = AllocationType.Tensorflow;
579-
580-
#if TRACK_TENSOR_LIFE
581-
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} Data: 0x{TensorDataPointer.ToString("x16")}");
582-
#endif
583544
}
584545

585546
public Tensor(Operation op, int value_index, TF_DataType dtype)
@@ -608,26 +569,10 @@ public Tensor(Operation op, int value_index, TF_DataType dtype)
608569
protected IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int element_size)
609570
{
610571
if (dt == TF_DataType.TF_STRING && data is byte[] buffer)
611-
return CreateStringTensorFromBytes(buffer, shape);
572+
return StringTensor(new byte[][] { buffer }, TensorShape.Scalar);
612573
return CreateTensorFromArray(dt, shape, data, 0, data.Length, element_size);
613574
}
614575

615-
protected unsafe IntPtr CreateStringTensorFromBytes(byte[] buffer, long[] shape)
616-
{
617-
var size = c_api.TF_StringEncodedSize((ulong)buffer.Length);
618-
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, 0, size + sizeof(long));
619-
AllocationType = AllocationType.Tensorflow;
620-
621-
IntPtr tensor = c_api.TF_TensorData(handle);
622-
Marshal.WriteInt64(tensor, 0);
623-
624-
fixed (byte* src = buffer)
625-
c_api.TF_StringEncode(src, (ulong)buffer.Length, (byte*)(tensor + sizeof(long)), size, tf.Status.Handle);
626-
627-
tf.Status.Check(true);
628-
return handle;
629-
}
630-
631576
/// <summary>
632577
/// Creates a new tensor from a subsection of the given array without copying memory. The array is pinned down and the pointer passed on.
633578
/// </summary>
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
using System;
2+
using System.Linq;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow
8+
{
9+
public partial class Tensor
10+
{
11+
public unsafe IntPtr StringTensor(string[] strings, TensorShape shape)
12+
{
13+
// convert string array to byte[][]
14+
var buffer = new byte[strings.Length][];
15+
for (var i = 0; i < strings.Length; i++)
16+
buffer[i] = Encoding.UTF8.GetBytes(strings[i]);
17+
18+
return StringTensor(buffer, shape);
19+
}
20+
21+
public unsafe IntPtr StringTensor(byte[][] buffer, TensorShape shape)
22+
{
23+
ulong size = 0;
24+
foreach (var b in buffer)
25+
size += c_api.TF_StringEncodedSize((ulong)b.Length);
26+
27+
var src_size = size + (ulong)buffer.Length * sizeof(ulong);
28+
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
29+
shape.dims.Select(x => (long)x).ToArray(),
30+
shape.ndim,
31+
src_size);
32+
AllocationType = AllocationType.Tensorflow;
33+
34+
IntPtr data_start = c_api.TF_TensorData(handle);
35+
IntPtr string_start = data_start + buffer.Length * sizeof(ulong);
36+
IntPtr limit = data_start + (int)src_size;
37+
ulong offset = 0;
38+
for (int i = 0; i < buffer.Length; i++)
39+
{
40+
Marshal.WriteInt64(data_start, i * sizeof(ulong), (long)offset);
41+
if (buffer[i].Length == 0)
42+
{
43+
Marshal.WriteByte(string_start, 0);
44+
break;
45+
}
46+
47+
fixed (byte* src = &buffer[i][0])
48+
{
49+
/*Marshal.WriteByte(string_start, Convert.ToByte(buffer[i].Length));
50+
tf.memcpy((string_start + 1).ToPointer(), src, (ulong)buffer[i].Length);
51+
string_start += buffer[i].Length + 1;
52+
offset += buffer[i].Length + 1;*/
53+
54+
var written = c_api.TF_StringEncode(src, (ulong)buffer[i].Length, (byte*)string_start, (ulong)(limit.ToInt64() - string_start.ToInt64()), tf.Status.Handle);
55+
tf.Status.Check(true);
56+
string_start += (int)written;
57+
offset += written;
58+
}
59+
}
60+
61+
return handle;
62+
}
63+
64+
/// <summary>
65+
/// Extracts string array from current Tensor.
66+
/// </summary>
67+
/// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception>
68+
public unsafe string[] StringData()
69+
{
70+
var buffer = StringBytes();
71+
72+
var _str = new string[buffer.Length];
73+
for (int i = 0; i < _str.Length; i++)
74+
_str[i] = Encoding.UTF8.GetString(buffer[i]);
75+
76+
return _str;
77+
}
78+
79+
public unsafe byte[][] StringBytes()
80+
{
81+
if (dtype != TF_DataType.TF_STRING)
82+
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");
83+
84+
//
85+
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
86+
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
87+
//
88+
long size = 1;
89+
foreach (var s in TensorShape.dims)
90+
size *= s;
91+
92+
var buffer = new byte[size][];
93+
var data_start = c_api.TF_TensorData(_handle);
94+
var string_start = data_start + (int)(size * sizeof(ulong));
95+
for (int i = 0; i < buffer.Length; i++)
96+
{
97+
var len = *(byte*)string_start;
98+
buffer[i] = new byte[len];
99+
string_start += 1;
100+
Marshal.Copy(string_start, buffer[i], 0, len);
101+
string_start += len;
102+
}
103+
104+
return buffer;
105+
}
106+
}
107+
}

0 commit comments

Comments
 (0)