
Commit 59b7eb0

MaybeCreateControlFlowState
1 parent 6794925 commit 59b7eb0


5 files changed: +232 -108 lines changed


src/TensorFlowNET.Core/Gradients/gradients_util.cs

Lines changed: 95 additions & 83 deletions
```diff
@@ -55,6 +55,9 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
      * is more than one.
      **/
     var grads = new Dictionary<string, List<List<Tensor>>>();
+    Operation[] reachable_to_ops = null;
+    ControlFlowState loop_state = null;
+    Dictionary<string, int> pending_count = null;
 
     tf_with(ops.name_scope(name, "gradients",
         values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
@@ -81,7 +84,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
         var to_ops = ys.Select(x => x.op).ToList();
         var from_ops = xs.Select(x => x.op).ToList();
         var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
-        var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
+        (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
 
         // Add the initial gradients for the ys.
         foreach (var (y, grad_y) in zip(ys, grad_ys))
@@ -120,126 +123,135 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
         {
             // generate gradient subgraph for op.
             var op = queue.Dequeue();
-            if(op.name == "rnn/while/basic_rnn_cell/Tanh")
+            if(op.name == "rnn/while/Exit")
             {
 
             }
             _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
-            //if (loop_state != null)
-            //loop_state.EnterGradWhileContext(op, before: true);
-            var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
-
-            Tensor[] in_grads = null;
-            var is_partitioned_call = _IsPartitionedCall(op);
-            var is_func_call = false;
-            var has_out_grads = out_grads.Exists(x => x != null);
-            if (has_out_grads && !stop_ops.Contains(op))
             {
-                // A grad_fn must be defined, either as a function or as None
-                // for ops that do not have gradients.
+                if (loop_state != null)
+                    loop_state.EnterGradWhileContext(op, before: true);
+                var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
+                if (loop_state != null)
+                    loop_state.ExitGradWhileContext(op, before: true);
 
-                Func<Operation, Tensor[], Tensor[]> grad_fn = null;
-                try
-                {
-                    grad_fn = ops.get_gradient_function(op);
-                }
-                catch (LookupError)
+                Tensor[] in_grads = null;
+                var is_partitioned_call = _IsPartitionedCall(op);
+                var is_func_call = false;
+                var has_out_grads = out_grads.Exists(x => x != null);
+                if (has_out_grads && !stop_ops.Contains(op))
                 {
-                    if (is_func_call)
+                    // A grad_fn must be defined, either as a function or as None
+                    // for ops that do not have gradients.
+
+                    Func<Operation, Tensor[], Tensor[]> grad_fn = null;
+                    try
                     {
-                        if (is_partitioned_call)
+                        grad_fn = ops.get_gradient_function(op);
+                    }
+                    catch (LookupError)
+                    {
+                        if (is_func_call)
                         {
+                            if (is_partitioned_call)
+                            {
+
+                            }
+                            else
+                            {
 
+                            }
                         }
                         else
                         {
-
+                            throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                         }
                     }
-                    else
-                    {
-                        throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
-                    }
-                }
 
-                if (loop_state != null)
-                    loop_state.EnterGradWhileContext(op, before: false);
+                    if (loop_state != null)
+                        loop_state.EnterGradWhileContext(op, before: false);
 
-                if ((is_func_call || grad_fn != null) && has_out_grads)
-                {
-                    // NOTE: If _AggregatedGrads didn't compute a value for the i'th
-                    // output, it means that the cost does not depend on output[i],
-                    // therefore dC/doutput[i] is 0.
-                    foreach (var (i, out_grad) in enumerate(out_grads))
+                    if ((is_func_call || grad_fn != null) && has_out_grads)
                     {
-                        if (out_grad == null &&
-                            (grad_fn == null || _IsTrainable(op.outputs[i])))
+                        // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+                        // output, it means that the cost does not depend on output[i],
+                        // therefore dC/doutput[i] is 0.
+                        foreach (var (i, out_grad) in enumerate(out_grads))
                         {
-                            // Only trainable outputs or outputs for a function call that
-                            // will use SymbolicGradient get a zero gradient. Gradient
-                            // functions should ignore the gradient for other outputs.
-                            if (loop_state != null)
-                                out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
-                            else
-                                out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                            if (out_grad == null &&
+                                (grad_fn == null || _IsTrainable(op.outputs[i])))
+                            {
+                                // Only trainable outputs or outputs for a function call that
+                                // will use SymbolicGradient get a zero gradient. Gradient
+                                // functions should ignore the gradient for other outputs.
+                                if (loop_state != null)
+                                    out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
+                                else
+                                    out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                            }
                         }
-                    }
 
-                    tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
-                    {
-                        if (grad_fn != null)
+                        tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                         {
-                            in_grads = _MaybeCompile(grad_scope,
-                                op,
-                                out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
-                                null,
-                                grad_fn);
-                        }
-                        else
-                        {
-                            throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
-                        }
-                        _VerifyGeneratedGradients(in_grads, op);
-                        if (gate_gradients && in_grads.Count(x => x != null) > 1)
-                        {
-                            ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
-                            in_grads = control_flow_ops.tuple(in_grads);
-                        }
-                    });
+                            if (grad_fn != null)
+                            {
+                                in_grads = _MaybeCompile(grad_scope,
+                                    op,
+                                    out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                                    null,
+                                    grad_fn);
+                            }
+                            else
+                            {
+                                throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
+                            }
+                            _VerifyGeneratedGradients(in_grads, op);
+                            if (gate_gradients && in_grads.Count(x => x != null) > 1)
+                            {
+                                ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
+                                in_grads = control_flow_ops.tuple(in_grads);
+                            }
+                        });
+                    }
+                    else
+                    {
+                        // If no grad_fn is defined or none of out_grads is available,
+                        // just propagate a list of None backwards.
+                        in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
+                    }
                 }
                 else
                 {
-                    // If no grad_fn is defined or none of out_grads is available,
-                    // just propagate a list of None backwards.
                     in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                 }
-            }
-            else
-            {
-                in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
-            }
 
-            var inputs = _NonEagerInputs(op, xs).ToList();
-            foreach (var (t_in, in_grad) in zip(inputs, in_grads))
-            {
-                if (in_grad != null)
+                var inputs = _NonEagerInputs(op, xs).ToList();
+                foreach (var (t_in, in_grad) in zip(inputs, in_grads))
                 {
-                    if (!(in_grad is null) &&
-                        in_grad.Tag == null && // maybe a IndexedSlice
-                        t_in.dtype != TF_DataType.TF_RESOURCE)
+                    if (in_grad != null)
                     {
-                        in_grad.set_shape(t_in.TensorShape);
-                    }
+                        if (!(in_grad is null) &&
+                            in_grad.Tag == null && // maybe a IndexedSlice
+                            t_in.dtype != TF_DataType.TF_RESOURCE)
+                        {
+                            in_grad.set_shape(t_in.TensorShape);
+                        }
 
-                    _SetGrad(grads, t_in, in_grad);
+                        _SetGrad(grads, t_in, in_grad);
+                    }
                 }
-            }
 
+                if (loop_state != null)
+                    loop_state.ExitGradWhileContext(op, before: false);
+            }
+
             // Update pending count for the inputs of op and enqueue ready ops.
             _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
         }
     });
 
+    if (loop_state != null)
+        loop_state.PostProcessing();
     return xs.Select(x => _GetGrad(grads, x)).ToArray();
 }
 
```
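Note on the control-flow wiring above: the declarations of reachable_to_ops, loop_state and pending_count are hoisted out of the name_scope lambda so that loop_state (the ControlFlowState produced by _PendingCount when the graph contains a while loop) is still reachable for PostProcessing() after the scope closes, and each dequeued op is now bracketed by EnterGradWhileContext/ExitGradWhileContext twice: once around gradient aggregation (before: true) and once around the grad-function call (before: false). Below is a minimal sketch of that call order only; the Op and ControlFlowStateSketch stubs are illustrative stand-ins, not types from this repository.

```csharp
using System;
using System.Collections.Generic;

class Op { public string name = ""; }

// Stub that only records the calls; the real ControlFlowState lives in the library.
class ControlFlowStateSketch
{
    public void EnterGradWhileContext(Op op, bool before) =>
        Console.WriteLine($"EnterGradWhileContext({op.name}, before: {before})");
    public void ExitGradWhileContext(Op op, bool before) =>
        Console.WriteLine($"ExitGradWhileContext({op.name}, before: {before})");
    public void PostProcessing() =>
        Console.WriteLine("PostProcessing()");
}

static class GradientLoopSketch
{
    static void Main()
    {
        // loop_state exists only when the forward graph has a while loop;
        // otherwise it stays null and every guarded call below is skipped.
        ControlFlowStateSketch loop_state = new ControlFlowStateSketch();
        var queue = new Queue<Op>(new[] { new Op { name = "rnn/while/Exit" } });

        while (queue.Count > 0)
        {
            var op = queue.Dequeue();

            // 1. Aggregate the op's output gradients inside its grad while-context.
            loop_state?.EnterGradWhileContext(op, before: true);
            // ... _AggregatedGrads(grads, op, ...) runs here in the real helper ...
            loop_state?.ExitGradWhileContext(op, before: true);

            // 2. Re-enter the context to run grad_fn and _SetGrad on the op's inputs.
            loop_state?.EnterGradWhileContext(op, before: false);
            // ... grad_fn(op, out_grads) and _SetGrad(...) run here ...
            loop_state?.ExitGradWhileContext(op, before: false);
        }

        // 3. After the queue drains, the state gets a chance to finish the loop gradients.
        loop_state?.PostProcessing();
    }
}
```

Running the sketch just prints the Enter/Exit/PostProcessing sequence, which mirrors the order _GradientsHelper now follows for every op it pops from the queue.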

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 4 additions & 3 deletions
```diff
@@ -50,10 +50,11 @@ public Layer(bool trainable = true,
 
     public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
     {
-        return __call__(inputs, training: training);
+        var results = __call__(inputs, training: training);
+        return (results[0], results[1]);
     }
 
-    public (Tensor, Tensor) __call__(Tensor inputs,
+    public Tensor[] __call__(Tensor inputs,
         Tensor training = null,
         Tensor state = null,
         VariableScope scope = null)
@@ -73,7 +74,7 @@ public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
             auxiliary_name_scope: false);
         }
 
-        (Tensor, Tensor) outputs = (null, null);
+        Tensor[] outputs = null;
         tf_with(scope_context_manager, scope2 =>
         {
             _current_scope = scope2;
```
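The change above turns __call__ from a (Tensor, Tensor)-returning method into a Tensor[]-returning one, with apply() unpacking the first two entries to keep its old tuple shape. A minimal sketch of the resulting calling convention, using illustrative stub types (Tensor, LayerSketch and Demo below are stand-ins, not the library's classes):

```csharp
using System;

// Stand-in for TensorFlow.NET's tensor type, just enough to show the shape of the API.
class Tensor
{
    public string Name;
    public Tensor(string name) => Name = name;
}

class LayerSketch
{
    // __call__ now hands back an array, so a layer that produces more than two
    // tensors is no longer forced into a fixed-size tuple.
    public Tensor[] __call__(Tensor inputs, Tensor training = null, Tensor state = null)
        => new[] { new Tensor("output"), new Tensor("state") };

    // apply() keeps the old two-value surface by unpacking the first two entries.
    public (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
    {
        var results = __call__(inputs, training: training);
        return (results[0], results[1]);
    }
}

class Demo
{
    static void Main()
    {
        var layer = new LayerSketch();
        var (output, state) = layer.apply(new Tensor("x"));
        Console.WriteLine($"{output.Name}, {state.Name}");
    }
}
```

Presumably the array form leaves room for layers such as RNN cells to return an output plus one or more state tensors without another signature change, while callers that only need two values keep using apply().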

0 commit comments
