Skip to content

Commit df91307

Browse files
committed
Fix namespace compile issue.
1 parent 4e78d3d commit df91307

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,13 @@ void apply_gradients((Tensor, IVariableV1) grads_and_vars,
1010
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
1111
string name = null,
1212
bool experimental_aggregate_gradients = true);
13+
14+
void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
15+
string name = null,
16+
bool experimental_aggregate_gradients = true);
17+
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
18+
string name = null,
19+
bool experimental_aggregate_gradients = true);
20+
1321
IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null);
1422
}

src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,42 @@ public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
7878
});
7979
}
8080

81+
public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
82+
string name = null,
83+
bool experimental_aggregate_gradients = true)
84+
=> apply_gradients(new[] { grads_and_vars },
85+
name: name,
86+
experimental_aggregate_gradients: experimental_aggregate_gradients);
87+
88+
/// <summary>
89+
/// Apply gradients to variables.
90+
/// </summary>
91+
/// <param name="grads_and_vars"></param>
92+
/// <param name="name"></param>
93+
/// <param name="experimental_aggregate_gradients"></param>
94+
public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
95+
string name = null,
96+
bool experimental_aggregate_gradients = true)
97+
{
98+
var var_list = grads_and_vars.Select(x => x.Item2).ToArray();
99+
tf_with(ops.name_scope(_name), delegate
100+
{
101+
ops.init_scope();
102+
_create_all_weights(var_list);
103+
if (grads_and_vars == null || grads_and_vars.Count() == 0)
104+
return control_flow_ops.no_op();
105+
106+
var apply_state = _prepare(var_list);
107+
// if(experimental_aggregate_gradients)
108+
{
109+
// var reduced_grads = _aggregate_gradients(grads_and_vars);
110+
_distributed_apply(grads_and_vars.Select(x => (x.Item1, (IVariableV1)x.Item2)), name, apply_state);
111+
}
112+
113+
return null;
114+
});
115+
}
116+
81117
void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
82118
{
83119
_resource_apply_dense(var, grad, apply_state);

test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
using System.Linq;
66
using Tensorflow;
77
using static Tensorflow.Binding;
8-
using Buffer = Tensorflow.Buffer;
9-
using TensorFlowNET.Keras.UnitTest;
8+
using Tensorflow.Keras.UnitTest;
109

1110
namespace TensorFlowNET.UnitTest.Basics
1211
{

test/TensorFlowNET.Graph.UnitTest/SignalTest.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
using System.Linq;
66
using Tensorflow;
77
using static Tensorflow.Binding;
8-
using Buffer = Tensorflow.Buffer;
9-
using TensorFlowNET.Keras.UnitTest;
8+
using Tensorflow.Keras.UnitTest;
109

1110
namespace TensorFlowNET.UnitTest.Basics
1211
{

0 commit comments

Comments
 (0)