Skip to content

Commit a0df810

Browse files
committed
fix: training LSTM does not align with tensorflow.
1 parent 675b93a commit a0df810

File tree

14 files changed

+68
-37
lines changed

14 files changed

+68
-37
lines changed

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ public static TF_DataType GetDataType(this object data)
503503
case Tensors tensors:
504504
return tensors.dtype;
505505
case IEnumerable<Tensor> tensors:
506-
return tensors.First().dtype;
506+
return tensors.Where(x => x is not null).First().dtype;
507507
case RefVariable variable:
508508
return variable.dtype;
509509
case ResourceVariable variable:

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public Tensor[] TFE_TapeGradient(ITape tape,
6565
{
6666
outgrad_vec = output_gradients.ToList();
6767
}
68-
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
68+
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);
6969

7070

7171
bool unconnected_gradients_zero = unconnected_gradients == "zero";

src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ public override string ToString()
1010
var str = NDArrayRender.ToString(nd);
1111
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
1212
}
13-
13+
public string ToString(int maxLength)
14+
{
15+
var nd = new NDArray(this);
16+
var str = NDArrayRender.ToString(nd, maxLength);
17+
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
18+
}
1419
}
1520
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public class LSTMCellArgs : AutoSerializeLayerArgs
2929
[JsonProperty("unit_forget_bias")]
3030
public bool UnitForgetBias { get; set; } = true;
3131
[JsonProperty("implementation")]
32-
public int Implementation { get; set; } = 1;
32+
public int Implementation { get; set; } = 2;
3333

3434
}
3535
}

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ public ILayer LSTM(int units,
182182
bool unit_forget_bias = true,
183183
float dropout = 0f,
184184
float recurrent_dropout = 0f,
185-
int implementation = 1,
185+
int implementation = 2,
186186
bool return_sequences = false,
187187
bool return_state = false,
188188
bool go_backwards = false,

src/TensorFlowNET.Core/NumPy/NDArrayRender.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@ namespace Tensorflow.NumPy
77
{
88
public class NDArrayRender
99
{
10-
public static string ToString(NDArray array)
10+
public static string ToString(NDArray array, int maxLength = 10)
1111
{
1212
Shape shape = array.shape;
1313
if (shape.IsScalar)
1414
return Render(array);
1515

1616
var s = new StringBuilder();
1717
s.Append("array(");
18-
Build(s, array);
18+
Build(s, array, maxLength);
1919
s.Append(")");
2020
return s.ToString();
2121
}
2222

23-
static void Build(StringBuilder s, NDArray array)
23+
static void Build(StringBuilder s, NDArray array, int maxLength)
2424
{
2525
var shape = array.shape;
2626

@@ -35,11 +35,11 @@ static void Build(StringBuilder s, NDArray array)
3535
var len = shape[0];
3636
s.Append("[");
3737

38-
if (len <= 10)
38+
if (len <= maxLength)
3939
{
4040
for (int i = 0; i < len; i++)
4141
{
42-
Build(s, array[i]);
42+
Build(s, array[i], maxLength);
4343
if (i < len - 1)
4444
{
4545
s.Append(", ");
@@ -49,9 +49,9 @@ static void Build(StringBuilder s, NDArray array)
4949
}
5050
else
5151
{
52-
for (int i = 0; i < 5; i++)
52+
for (int i = 0; i < maxLength / 2; i++)
5353
{
54-
Build(s, array[i]);
54+
Build(s, array[i], maxLength);
5555
if (i < len - 1)
5656
{
5757
s.Append(", ");
@@ -62,9 +62,9 @@ static void Build(StringBuilder s, NDArray array)
6262
s.Append(" ... ");
6363
s.AppendLine();
6464

65-
for (int i = (int)len - 5; i < len; i++)
65+
for (int i = (int)len - maxLength / 2; i < len; i++)
6666
{
67-
Build(s, array[i]);
67+
Build(s, array[i], maxLength);
6868
if (i < len - 1)
6969
{
7070
s.Append(", ");
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.NumPy;
5+
6+
namespace Tensorflow.Operations.Initializers
7+
{
8+
/// <summary>
9+
/// An initializer specially used for debugging (to load weights from disk).
10+
/// </summary>
11+
class NpyLoadInitializer : IInitializer
12+
{
13+
string _path;
14+
public NpyLoadInitializer(string path) { _path = path; }
15+
public string ClassName => "";
16+
public IDictionary<string, object> Config => new Dictionary<string, object>();
17+
public Tensor Apply(InitializerArgs args)
18+
{
19+
return np.load(_path);
20+
}
21+
}
22+
}

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ https://tensorflownet.readthedocs.io</Description>
111111
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
112112
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
113113
<PackageReference Include="OneOf" Version="3.0.223" />
114-
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
114+
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
115115
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
116116
</ItemGroup>
117117

src/TensorFlowNET.Core/Training/Trackable.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args
179179
// handles slot variables.
180180
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
181181
{
182-
var temp = new_variable as Trackable;
183-
var res = _track_trackable(temp, args.Name, args.Overwrite);
182+
var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite);
184183
Debug.Assert(res is IVariableV1);
185184
return res as IVariableV1;
186185
}

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ public IRnnCell LSTMCell(int uints,
793793
bool unit_forget_bias = true,
794794
float dropout = 0f,
795795
float recurrent_dropout = 0f,
796-
int implementation = 1)
796+
int implementation = 2)
797797
=> new LSTMCell(new LSTMCellArgs
798798
{
799799
Units = uints,
@@ -846,7 +846,7 @@ public ILayer LSTM(int units,
846846
bool unit_forget_bias = true,
847847
float dropout = 0f,
848848
float recurrent_dropout = 0f,
849-
int implementation = 1,
849+
int implementation = 2,
850850
bool return_sequences = false,
851851
bool return_state = false,
852852
bool go_backwards = false,
@@ -869,7 +869,8 @@ public ILayer LSTM(int units,
869869
GoBackwards = go_backwards,
870870
Stateful = stateful,
871871
TimeMajor = time_major,
872-
Unroll = unroll
872+
Unroll = unroll,
873+
UnitForgetBias = unit_forget_bias
873874
});
874875

875876
/// <summary>

0 commit comments

Comments
 (0)