Skip to content

Commit c6f9ec6

Browse files
committed
gradients_impl. not finish yet.
1 parent bcb803d commit c6f9ec6

File tree

7 files changed

+109
-5
lines changed

7 files changed

+109
-5
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class AggregationMethod
8+
{
9+
public static int ADD_N = 0;
10+
public static int DEFAULT = ADD_N;
11+
// The following are experimental and may not be supported in future releases.
12+
public static int EXPERIMENTAL_TREE = 1;
13+
public static int EXPERIMENTAL_ACCUMULATE_N = 2;
14+
}
15+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class gradients_impl
8+
{
9+
public static void gradients(object ys,
10+
object xs,
11+
List<Tensor> grad_ys = null,
12+
string name = "gradients",
13+
bool colocate_gradients_with_ops = false,
14+
bool gate_gradients = false,
15+
int? aggregation_method = null)
16+
{
17+
_GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients);
18+
}
19+
20+
public static void _GradientsHelper(object ys,
21+
object xs,
22+
List<Tensor> grad_ys = null,
23+
string name = "gradients",
24+
bool colocate_gradients_with_ops = false,
25+
bool gate_gradients = false,
26+
Graph src_graph = null)
27+
{
28+
if (src_graph == null)
29+
src_graph = ops.get_default_graph();
30+
}
31+
}
32+
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ public override string ToString()
222222
}
223223
}
224224

225-
return "";
225+
return $"{name} {dtype} {rank} {string.Join(",", shape)}";
226226
}
227227

228228
public void Dispose()

src/TensorFlowNET.Core/Train/Optimizer.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using distribute_lib = Tensorflow.Distribute;
56

@@ -48,8 +49,10 @@ public Optimizer minimize(Tensor loss,
4849
/// <param name="gate_gradients"></param>
4950
public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
5051
List<RefVariable> var_list = null,
52+
int? aggregation_method = null,
5153
GateGradientType gate_gradients = GateGradientType.GATE_OP,
52-
bool colocate_gradients_with_ops = false)
54+
bool colocate_gradients_with_ops = false,
55+
List<Tensor> grad_loss = null)
5356
{
5457
int num_towers = 1;
5558
if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
@@ -65,10 +68,13 @@ public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
6568
break;
6669
}
6770

68-
foreach(var v in var_list)
69-
{
71+
var processors = var_list.Select(v => optimizer._get_processor(v));
72+
var var_refs = processors.Select(x => x.target()).ToList();
7073

71-
}
74+
gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss,
75+
gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
76+
aggregation_method: aggregation_method,
77+
colocate_gradients_with_ops: colocate_gradients_with_ops);
7278

7379
return null;
7480
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public interface _OptimizableVariable
8+
{
9+
Tensor target();
10+
void update_op(Graph g);
11+
}
12+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class optimizer
8+
{
9+
public static _OptimizableVariable _get_processor(RefVariable v)
10+
{
11+
return new _RefVariableProcessor(v);
12+
}
13+
}
14+
15+
public class _RefVariableProcessor : _OptimizableVariable
16+
{
17+
private RefVariable _v;
18+
19+
public _RefVariableProcessor(RefVariable v)
20+
{
21+
_v = v;
22+
}
23+
24+
public Tensor target()
25+
{
26+
return _v._ref();
27+
}
28+
29+
public void update_op(Graph g)
30+
{
31+
32+
}
33+
}
34+
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ private void _init_from_args(object initial_value,
7878
ops.add_to_collections(collections, this);
7979
}
8080

81+
public Tensor _ref()
82+
{
83+
return _variable;
84+
}
85+
8186
public static implicit operator _VariableScopeStore(RefVariable variable)
8287
{
8388
return null;

0 commit comments

Comments
 (0)