Skip to content

Commit 93eb56e

Browse files
committed
add apply_adam, _apply_dense for Adam. #271
1 parent 2a17b9c commit 93eb56e

File tree

8 files changed

+146
-110
lines changed

8 files changed

+146
-110
lines changed

src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ public _InitializeClustersOpFactory(Tensor[] inputs,
4747
_cluster_centers_updated = cluster_centers_updated;
4848
_cluster_centers_initialized = cluster_centers_initialized;
4949

50-
_num_selected = array_ops.shape(_cluster_centers)[0];
50+
_num_selected = array_ops.shape(_cluster_centers).slice(0);
5151
_num_remaining = _num_clusters - _num_selected;
5252

53-
_num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray());
53+
_num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i).slice(0)).ToArray());
5454
}
5555

5656
private Tensor _initialize()
@@ -68,7 +68,7 @@ private Tensor _initialize()
6868
},
6969
() =>
7070
{
71-
return control_flow_ops.no_op().output[0];
71+
return control_flow_ops.no_op().output.slice(0);
7272
});
7373
});
7474
}
@@ -90,7 +90,7 @@ private Tensor _add_new_centers()
9090
// Adds some centers and returns the number of centers remaining.
9191
var new_centers = _choose_initial_centers();
9292
if (_distance_metric == KMeans.COSINE_DISTANCE)
93-
new_centers = nn_impl.l2_normalize(new_centers[0], axis: 1);
93+
new_centers = nn_impl.l2_normalize(new_centers.slice(0), axis: 1);
9494

9595
// If cluster_centers is empty, it doesn't have the right shape for concat.
9696
var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0),
@@ -99,12 +99,12 @@ private Tensor _add_new_centers()
9999

100100
var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false);
101101

102-
return _num_clusters - array_ops.shape(a)[0];
102+
return _num_clusters - array_ops.shape(a).slice(0);
103103
}
104104

105105
private Tensor _choose_initial_centers()
106106
{
107-
return _greedy_batch_sampler()[0];
107+
return _greedy_batch_sampler().slice(0);
108108
}
109109

110110
private Tensor _greedy_batch_sampler()

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads)
156156
// For axis 0 gathers, build an appropriately shaped IndexedSlices.
157157
if((int)axis_static == 0)
158158
{
159-
var params_tail_shape = params_shape[new NumSharp.Slice(start:1)];
159+
var params_tail_shape = params_shape.slice(new NumSharp.Slice(start:1));
160160
var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0);
161161
var values = array_ops.reshape(grad, values_shape);
162162
indices = array_ops.reshape(indices, indices_size);

src/TensorFlowNET.Core/Gradients/gradients_util.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,16 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
105105
var has_out_grads = true;
106106
if (has_out_grads && !stop_ops.Contains(op))
107107
{
108+
// A grad_fn must be defined, either as a function or as None
109+
// for ops that do not have gradients.
110+
var grad_fn = ops.get_gradient_function(op);
111+
108112
if (is_func_call)
109113
{
110114

111115
}
112116
else
113117
{
114-
// A grad_fn must be defined, either as a function or as None
115-
// for ops that do not have gradients.
116-
var grad_fn = ops.get_gradient_function(op);
117-
118118
foreach (var (i, out_grad) in enumerate(out_grads))
119119
{
120120
if (out_grad == null)
@@ -322,7 +322,7 @@ private static Tensor[] _AggregatedGrads(Dictionary<string, List<List<Tensor>>>
322322
else
323323
{
324324
used = "add_n";
325-
out_grads[i] = new List<Tensor> { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) };
325+
return_grads[i] = _MultiDeviceAddN(out_grad.ToArray(), gradient_uid);
326326
}
327327
}
328328
else

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
200200

201201
var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64),
202202
array_ops.size(in_shape) - 1);
203-
var outerdim = array_ops.shape(ind_2d)[0];
203+
var outerdim = array_ops.shape(ind_2d).slice(0);
204204

205205
// Compute linear indices(flattened to 1D).
206206
var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64);

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 89 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -224,116 +224,110 @@ public TF_DataType ToTFDataType(Type type)
224224
}
225225
}
226226

227-
public Tensor this[Slice slice]
227+
public Tensor slice(Slice slice)
228228
{
229-
get
230-
{
231-
var slice_spec = new int[] { slice.Start.Value };
232-
var begin = new List<int>();
233-
var end = new List<int>();
234-
var strides = new List<int>();
229+
var slice_spec = new int[] { slice.Start.Value };
230+
var begin = new List<int>();
231+
var end = new List<int>();
232+
var strides = new List<int>();
235233

236-
var index = 0;
237-
var (new_axis_mask, shrink_axis_mask) = (0, 0);
238-
var (begin_mask, end_mask) = (0, 0);
239-
var ellipsis_mask = 0;
234+
var index = 0;
235+
var (new_axis_mask, shrink_axis_mask) = (0, 0);
236+
var (begin_mask, end_mask) = (0, 0);
237+
var ellipsis_mask = 0;
240238

241-
foreach (var s in slice_spec)
239+
foreach (var s in slice_spec)
240+
{
241+
begin.Add(s);
242+
if (slice.Stop.HasValue)
242243
{
243-
begin.Add(s);
244-
if(slice.Stop.HasValue)
245-
{
246-
end.Add(slice.Stop.Value);
247-
}
248-
else
249-
{
250-
end.Add(0);
251-
end_mask |= (1 << index);
252-
}
253-
strides.Add(slice.Step);
254-
255-
index += 1;
244+
end.Add(slice.Stop.Value);
256245
}
257-
258-
return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
246+
else
259247
{
260-
string name = scope;
261-
if (begin != null)
262-
{
263-
var (packed_begin, packed_end, packed_strides) =
264-
(array_ops.stack(begin.ToArray()),
265-
array_ops.stack(end.ToArray()),
266-
array_ops.stack(strides.ToArray()));
267-
268-
return gen_array_ops.strided_slice(
269-
this,
270-
packed_begin,
271-
packed_end,
272-
packed_strides,
273-
begin_mask: begin_mask,
274-
end_mask: end_mask,
275-
shrink_axis_mask: shrink_axis_mask,
276-
new_axis_mask: new_axis_mask,
277-
ellipsis_mask: ellipsis_mask,
278-
279-
name: name);
280-
}
281-
282-
throw new NotImplementedException("");
283-
});
248+
end.Add(0);
249+
end_mask |= (1 << index);
250+
}
251+
strides.Add(slice.Step);
252+
253+
index += 1;
284254
}
255+
256+
return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
257+
{
258+
string name = scope;
259+
if (begin != null)
260+
{
261+
var (packed_begin, packed_end, packed_strides) =
262+
(array_ops.stack(begin.ToArray()),
263+
array_ops.stack(end.ToArray()),
264+
array_ops.stack(strides.ToArray()));
265+
266+
return gen_array_ops.strided_slice(
267+
this,
268+
packed_begin,
269+
packed_end,
270+
packed_strides,
271+
begin_mask: begin_mask,
272+
end_mask: end_mask,
273+
shrink_axis_mask: shrink_axis_mask,
274+
new_axis_mask: new_axis_mask,
275+
ellipsis_mask: ellipsis_mask,
276+
277+
name: name);
278+
}
279+
280+
throw new NotImplementedException("");
281+
});
285282
}
286283

287-
public Tensor this[int start]
284+
public Tensor slice(int start)
288285
{
289-
get
290-
{
291-
var slice_spec = new int[] { start };
292-
var begin = new List<int>();
293-
var end = new List<int>();
294-
var strides = new List<int>();
286+
var slice_spec = new int[] { start };
287+
var begin = new List<int>();
288+
var end = new List<int>();
289+
var strides = new List<int>();
290+
291+
var index = 0;
292+
var (new_axis_mask, shrink_axis_mask) = (0, 0);
293+
var (begin_mask, end_mask) = (0, 0);
294+
var ellipsis_mask = 0;
295295

296-
var index = 0;
297-
var (new_axis_mask, shrink_axis_mask) = (0, 0);
298-
var (begin_mask, end_mask) = (0, 0);
299-
var ellipsis_mask = 0;
296+
foreach (var s in slice_spec)
297+
{
298+
begin.Add(s);
299+
end.Add(s + 1);
300+
strides.Add(1);
301+
shrink_axis_mask |= (1 << index);
302+
index += 1;
303+
}
300304

301-
foreach (var s in slice_spec)
305+
return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
306+
{
307+
string name = scope;
308+
if (begin != null)
302309
{
303-
begin.Add(s);
304-
end.Add(s + 1);
305-
strides.Add(1);
306-
shrink_axis_mask |= (1 << index);
307-
index += 1;
310+
var (packed_begin, packed_end, packed_strides) =
311+
(array_ops.stack(begin.ToArray()),
312+
array_ops.stack(end.ToArray()),
313+
array_ops.stack(strides.ToArray()));
314+
315+
return gen_array_ops.strided_slice(
316+
this,
317+
packed_begin,
318+
packed_end,
319+
packed_strides,
320+
begin_mask: begin_mask,
321+
end_mask: end_mask,
322+
shrink_axis_mask: shrink_axis_mask,
323+
new_axis_mask: new_axis_mask,
324+
ellipsis_mask: ellipsis_mask,
325+
326+
name: name);
308327
}
309328

310-
return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
311-
{
312-
string name = scope;
313-
if (begin != null)
314-
{
315-
var (packed_begin, packed_end, packed_strides) =
316-
(array_ops.stack(begin.ToArray()),
317-
array_ops.stack(end.ToArray()),
318-
array_ops.stack(strides.ToArray()));
319-
320-
return gen_array_ops.strided_slice(
321-
this,
322-
packed_begin,
323-
packed_end,
324-
packed_strides,
325-
begin_mask: begin_mask,
326-
end_mask: end_mask,
327-
shrink_axis_mask: shrink_axis_mask,
328-
new_axis_mask: new_axis_mask,
329-
ellipsis_mask: ellipsis_mask,
330-
331-
name: name);
332-
}
333-
334-
throw new NotImplementedException("");
335-
});
336-
}
329+
throw new NotImplementedException("");
330+
});
337331
}
338332

339333
public override string ToString()

src/TensorFlowNET.Core/Train/AdamOptimizer.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class AdamOptimizer : Optimizer
1616
float _beta1;
1717
float _beta2;
1818
float _epsilon;
19-
Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t;
19+
Tensor _beta1_t, _beta2_t, _epsilon_t;
2020

2121
public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
2222
: base(learning_rate, use_locking, name)
@@ -34,6 +34,25 @@ public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
3434
});
3535
}
3636

37+
public override Operation _apply_dense(Tensor grad, RefVariable var)
38+
{
39+
var m = get_slot(var, "m");
40+
var v = get_slot(var, "v");
41+
var (beta1_power, beta2_power) = _get_beta_accumulators();
42+
return gen_training_ops.apply_adam(
43+
var,
44+
m,
45+
v,
46+
math_ops.cast(beta1_power, var.dtype.as_base_dtype()),
47+
math_ops.cast(beta2_power, var.dtype.as_base_dtype()),
48+
math_ops.cast(_lr_t, var.dtype.as_base_dtype()),
49+
math_ops.cast(_beta1_t, var.dtype.as_base_dtype()),
50+
math_ops.cast(_beta2_t, var.dtype.as_base_dtype()),
51+
math_ops.cast(_epsilon_t, var.dtype.as_base_dtype()),
52+
grad,
53+
use_locking: _use_locking).op;
54+
}
55+
3756
private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func<RefVariable, Tensor, Tensor, Tensor> scatter_add)
3857
{
3958
var (beta1_power_v, beta2_power_v) = _get_beta_accumulators();

src/TensorFlowNET.Core/Train/Optimizer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var)
272272
public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices)
273273
{
274274
var (unique_indices, new_index_positions) = array_ops.unique(indices);
275-
var shape = array_ops.shape(unique_indices)[0];
275+
var shape = array_ops.shape(unique_indices).slice(0);
276276
var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape);
277277
return (summed_values, unique_indices);
278278
}

src/TensorFlowNET.Core/Train/gen_training_ops.py.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,29 @@ public class gen_training_ops
88
{
99
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
1010

11+
public static Tensor apply_adam(RefVariable var, RefVariable m, RefVariable v, Tensor beta1_power, Tensor beta2_power,
12+
Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad,
13+
bool use_locking = false, bool use_nesterov = false, string name = null)
14+
{
15+
var _op = _op_def_lib._apply_op_helper("ApplyAdam", name, new
16+
{
17+
var,
18+
m,
19+
v,
20+
beta1_power,
21+
beta2_power,
22+
lr,
23+
beta1,
24+
beta2,
25+
epsilon,
26+
grad,
27+
use_locking,
28+
use_nesterov
29+
});
30+
31+
return _op.outputs[0];
32+
}
33+
1134
public static Tensor apply_gradient_descent(RefVariable var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
1235
{
1336
var _op = _op_def_lib._apply_op_helper("ApplyGradientDescent", name, new

0 commit comments

Comments (0)