Skip to content

Commit 2093577

Browse files
committed
Implement FuncGraph.capture_eager_tensor.
1 parent 58d2dae commit 2093577

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,32 @@ public Tensor capture(Tensor tensor, string name = null, TensorShape shape = nul
140140
}
141141

142142
Tensor capture_eager_tensor(Tensor tensor, string name)
143-
=> throw new NotImplementedException("");
143+
{
144+
Tensor graph_const = null;
145+
if (!_captures.ContainsKey(tensor.Id))
146+
{
147+
graph_const = tf_with(ops.control_dependencies(null), ctl
148+
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name));
149+
add_capture(tensor, graph_const);
150+
}
151+
else
152+
{
153+
graph_const = _captures[tensor.Id].Item2;
154+
}
155+
156+
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
157+
{
158+
return output_grads;
159+
};
160+
161+
tf.Runner.RecordGradient("captured_value",
162+
new[] { graph_const }, null,
163+
new[] { tensor },
164+
getBackwardFunction: () => _backward_function_wrapper
165+
/*getForwardFunction: forward_function*/);
166+
167+
return graph_const;
168+
}
144169

145170
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
146171
{

0 commit comments

Comments
 (0)