
Commit 5a0648a

Merge pull request #428 from harshithapv/GrandientDescentOptimizerChange
Fixed a bug that overwrites the learning rate when it is sent as a Tensor.
2 parents: 203b0e2 + f0ff82a · commit 5a0648a

File tree: 1 file changed (+9, -2 lines)


src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs

Lines changed: 9 additions & 2 deletions
@@ -35,22 +35,29 @@ public class GradientDescentOptimizer : Optimizer
         /// for changing these values across different invocations of optimizer
         /// functions.
         /// </remarks>
+        private bool _useTensor;
         public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
            : base(learning_rate, use_locking, name)
         {
             _lr = learning_rate;
+            _useTensor = false;
         }

         public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent")
            : base(learning_rate, use_locking, name)
         {
             _lr_t = learning_rate;
+            _useTensor = true;
         }

         public override void _prepare()
         {
-            var lr = _call_if_callable(_lr);
-            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
+            if(!_useTensor)
+            {
+                var lr = _call_if_callable(_lr);
+                _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
+            }
+
         }
     }
 }
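
For context on the bug: the optimizer has two constructors. The float constructor stores the rate in `_lr`, and `_prepare()` lazily converts it into the tensor `_lr_t`. The Tensor constructor stores the caller's tensor directly in `_lr_t`, but before this commit `_prepare()` still ran the conversion unconditionally, rebuilding `_lr_t` from the (unset) float field and overwriting the supplied tensor. The new `_useTensor` flag skips that conversion on the Tensor path. Below is a minimal sketch of the affected scenario; it is an illustration, not code from this repository: the `Tensorflow.Train` namespace, the `tf.constant` binding style, and the `minimize` call are assumptions about the TensorFlow.NET API, and `loss` is a hypothetical tensor.

```csharp
using Tensorflow;                 // Tensor, ops
using Tensorflow.Train;           // GradientDescentOptimizer (assumed namespace)
using static Tensorflow.Binding;  // tf.* entry point (assumed binding style)

class LearningRateTensorExample
{
    static void Main()
    {
        // Supply the learning rate as a Tensor, e.g. so a schedule can drive it.
        Tensor lr = tf.constant(0.01f, name: "learning_rate");

        // Calls the Tensor constructor, which now records _useTensor = true.
        var optimizer = new GradientDescentOptimizer(lr);

        // Before the fix, minimize() -> _prepare() would run
        //   _lr_t = ops.convert_to_tensor(_lr, name: "learning_rate");
        // and clobber `lr` with a tensor built from the unset float field.
        // After the fix, _prepare() leaves _lr_t untouched on this path.
        // var train_op = optimizer.minimize(loss);  // `loss`: hypothetical, defined elsewhere
    }
}
```

The flag keeps the float path's lazy conversion intact while treating a caller-supplied Tensor as authoritative.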
