Skip to content

Commit 48d96f4

Browse files
committed
return weights for load_weights.
1 parent c836fd6 commit 48d96f4

File tree

9 files changed

+40
-207
lines changed

9 files changed

+40
-207
lines changed

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace Tensorflow.Eager
77
{
88
public partial class EagerTensor
99
{
10-
public EagerTensor(SafeTensorHandleHandle handle) : base(IntPtr.Zero)
10+
public EagerTensor(SafeTensorHandleHandle handle)
1111
{
1212
_id = ops.uid();
1313
EagerTensorHandle = handle;

src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public static Operation assign_variable_op(Tensor resource, Tensor value, string
6363
{
6464
if (tf.Context.executing_eagerly())
6565
{
66-
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
66+
tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
6767
"AssignVariableOp", name,
6868
null,
6969
resource, value);

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 2 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -265,182 +265,8 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] f
265265

266266
private static unsafe NDArray fetchValue(IntPtr output)
267267
{
268-
NDArray ret;
269-
using (var tensor = new Tensor(output))
270-
{
271-
var ndims = tensor.shape;
272-
var srcAddress = c_api.TF_TensorData(output).ToInt64();
273-
274-
if (ndims.Length == 0)
275-
{
276-
switch (tensor.dtype)
277-
{
278-
case TF_DataType.TF_BOOL:
279-
ret = NDArray.Scalar(*(bool*)srcAddress);
280-
break;
281-
case TF_DataType.TF_STRING:
282-
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
283-
ret = new NDArray(reader.ReadBytes().ToByteArray());
284-
break;
285-
case TF_DataType.TF_UINT8:
286-
ret = NDArray.Scalar(*(byte*)srcAddress);
287-
break;
288-
case TF_DataType.TF_INT16:
289-
ret = NDArray.Scalar(*(short*)srcAddress);
290-
break;
291-
case TF_DataType.TF_INT32:
292-
ret = NDArray.Scalar(*(int*)srcAddress);
293-
break;
294-
case TF_DataType.TF_INT64:
295-
ret = NDArray.Scalar(*(long*)srcAddress);
296-
break;
297-
case TF_DataType.TF_UINT16:
298-
ret = NDArray.Scalar(*(ushort*)srcAddress);
299-
break;
300-
case TF_DataType.TF_UINT32:
301-
ret = NDArray.Scalar(*(uint*)srcAddress);
302-
break;
303-
case TF_DataType.TF_UINT64:
304-
ret = NDArray.Scalar(*(ulong*)srcAddress);
305-
break;
306-
case TF_DataType.TF_FLOAT:
307-
ret = NDArray.Scalar(*(float*)srcAddress);
308-
break;
309-
case TF_DataType.TF_DOUBLE:
310-
ret = NDArray.Scalar(*(double*)srcAddress);
311-
break;
312-
default:
313-
throw new NotImplementedException("can't fetch output");
314-
}
315-
}
316-
else
317-
{
318-
//var size = (long) tensor.size;
319-
//var itemsize = (long) tensor.itemsize;
320-
var bytesize = (long)tensor.bytesize;
321-
var src = (void*)srcAddress;
322-
323-
#if _REGEN
324-
#region Compute
325-
switch (tensor.dtype)
326-
{
327-
%foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")%
328-
case TF_DataType.#3:
329-
{
330-
ret = new NDArray(NPTypeCode.#1, ndims, false);
331-
System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize);
332-
break;
333-
}
334-
%
335-
case TF_DataType.TF_STRING:
336-
{
337-
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
338-
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
339-
ret = NDArray.FromString(reader.ReadString());
340-
break;
341-
}
342-
default:
343-
throw new NotSupportedException();
344-
}
345-
#endregion
346-
#else
347-
348-
#region Compute
349-
350-
switch (tensor.dtype)
351-
{
352-
case TF_DataType.TF_BOOL:
353-
{
354-
ret = new NDArray(NPTypeCode.Boolean, ndims, false);
355-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
356-
break;
357-
}
358-
359-
case TF_DataType.TF_UINT8:
360-
{
361-
ret = new NDArray(NPTypeCode.Byte, ndims, false);
362-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
363-
break;
364-
}
365-
366-
case TF_DataType.TF_INT16:
367-
{
368-
ret = new NDArray(NPTypeCode.Int16, ndims, false);
369-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
370-
break;
371-
}
372-
373-
case TF_DataType.TF_UINT16:
374-
{
375-
ret = new NDArray(NPTypeCode.UInt16, ndims, false);
376-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
377-
break;
378-
}
379-
380-
case TF_DataType.TF_INT32:
381-
{
382-
ret = new NDArray(NPTypeCode.Int32, ndims, false);
383-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
384-
break;
385-
}
386-
387-
case TF_DataType.TF_UINT32:
388-
{
389-
ret = new NDArray(NPTypeCode.UInt32, ndims, false);
390-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
391-
break;
392-
}
393-
394-
case TF_DataType.TF_INT64:
395-
{
396-
ret = new NDArray(NPTypeCode.Int64, ndims, false);
397-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
398-
break;
399-
}
400-
401-
case TF_DataType.TF_UINT64:
402-
{
403-
ret = new NDArray(NPTypeCode.UInt64, ndims, false);
404-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
405-
break;
406-
}
407-
408-
case TF_DataType.TF_DOUBLE:
409-
{
410-
ret = new NDArray(NPTypeCode.Double, ndims, false);
411-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
412-
break;
413-
}
414-
415-
case TF_DataType.TF_FLOAT:
416-
{
417-
ret = new NDArray(NPTypeCode.Single, ndims, false);
418-
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
419-
break;
420-
}
421-
422-
case TF_DataType.TF_STRING:
423-
{
424-
throw new NotImplementedException();
425-
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
426-
#pragma warning disable CS0162 // Unreachable code detected
427-
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
428-
#pragma warning restore CS0162 // Unreachable code detected
429-
ret = NDArray.FromString(reader.ReadString());
430-
break;
431-
}
432-
433-
default:
434-
throw new NotSupportedException();
435-
}
436-
437-
#endregion
438-
439-
#endif
440-
}
441-
}
442-
443-
return ret;
268+
var tensor = new Tensor(output);
269+
return tensor.numpy();
444270
}
445271

446272
/// <summary>

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works</PackageReleaseN
8282
<ItemGroup>
8383
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
8484
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" />
85-
<PackageReference Include="NumSharp.Lite" Version="0.1.11" />
85+
<PackageReference Include="NumSharp.Lite" Version="0.1.12" />
8686
<PackageReference Include="Protobuf.Text" Version="0.4.0" />
8787
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" />
8888
</ItemGroup>

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ public partial class Tensor
4848

4949
public IntPtr TensorDataPointer => _handle == IntPtr.Zero ? IntPtr.Zero : TF_TensorData(_handle);
5050

51+
public Tensor()
52+
{
53+
54+
}
55+
5156
/// <summary>
5257
/// Create a Tensor object from an existing TF handle
5358
/// </summary>
@@ -56,6 +61,9 @@ public Tensor(IntPtr handle)
5661
{
5762
_handle = handle;
5863
//no need to set AllocationType = AllocationType.None;
64+
#if TRACK_TENSOR_LIFE
65+
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
66+
#endif
5967
}
6068

6169
public Tensor(int value)

src/TensorFlowNET.Core/Tensors/Tensor.Value.cs

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ protected unsafe NDArray GetNDArray(TF_DataType dtype)
163163
break;
164164
case TF_DataType.TF_STRING:
165165
return np.array(StringBytes()[0]);
166+
case TF_DataType.TF_UINT8:
167+
storage = new UnmanagedStorage(NPTypeCode.Byte);
168+
break;
166169
case TF_DataType.TF_INT32:
167170
storage = new UnmanagedStorage(NPTypeCode.Int32);
168171
break;
@@ -186,31 +189,13 @@ protected unsafe NDArray GetNDArray(TF_DataType dtype)
186189
return new NDArray(storage);
187190
}
188191

189-
/*protected unsafe NDArray GetScalar(TF_DataType dtype)
190-
{
191-
switch(dtype)
192-
{
193-
case TF_DataType.TF_STRING:
194-
return (NDArray)StringData()[0];
195-
case TF_DataType.TF_INT32:
196-
return *(int*)buffer;
197-
case TF_DataType.TF_FLOAT:
198-
return *(float*)buffer;
199-
case TF_DataType.TF_DOUBLE:
200-
return *(double*)buffer;
201-
default:
202-
return BufferToArray();
203-
}
204-
}*/
205-
206192
/// <summary>
207193
/// Copies the memory of current buffer onto newly allocated array.
208194
/// </summary>
209195
/// <returns></returns>
210196
public unsafe byte[] BufferToArray()
211197
{
212198
// ReSharper disable once LocalVariableHidesMember
213-
var bytesize = (long)this.bytesize;
214199
var data = new byte[bytesize];
215200
fixed (byte* dst = data)
216201
System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize);

src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ public Tensor assign<T>(T value, bool use_locking = false, string name = null, b
100100
if (read_value)
101101
return gen_resource_variable_ops.read_variable_op(handle, dtype);
102102

103+
if (assign_op == null)
104+
return null;
105+
103106
return assign_op;
104107
}
105108

src/TensorFlowNET.Keras/Engine/Model.Training.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,35 @@
33
using System.Text;
44
using HDF.PInvoke;
55
using HDF5CSharp;
6+
using NumSharp;
67
using Tensorflow.Keras.Saving;
78

89
namespace Tensorflow.Keras.Engine
910
{
1011
public partial class Model
1112
{
12-
public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
13+
public List<(IVariableV1, NDArray)> load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
1314
{
1415
long fileId = Hdf5.OpenFile(filepath, true);
1516

1617
bool msuccess = Hdf5.GroupExists(fileId, "model_weights");
1718
bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");
1819

1920
if (!lsuccess && msuccess)
20-
{
2121
fileId = H5G.open(fileId, "model_weights");
22-
}
22+
2323
if (by_name)
24-
{
2524
//fdf5_format.load_weights_from_hdf5_group_by_name();
2625
throw new NotImplementedException("");
27-
}
2826
else
2927
{
30-
hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
28+
var weights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
29+
Hdf5.CloseFile(fileId);
30+
// return a reference to prevent GC collect Variable.
31+
return weights;
3132
}
32-
Hdf5.CloseFile(fileId);
3333
}
34+
3435
public void save_weights(string filepath, bool overwrite = true, string save_format = null, object options = null)
3536
{
3637
long fileId = Hdf5.CreateFile(filepath);

0 commit comments

Comments
 (0)