Skip to content

Commit 0471e28

Browse files
committed
Change RunInAutoMode to ExecuteOp
1 parent f5a3551 commit 0471e28

File tree

7 files changed

+141
-267
lines changed

7 files changed

+141
-267
lines changed

src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,49 +30,19 @@ namespace Tensorflow.Contexts
3030
public sealed partial class Context
3131
{
3232
// [DebuggerStepThrough]
33-
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args)
34-
{
35-
if (tf.Context.has_graph_arg(args))
36-
{
37-
if (executing_eagerly())
38-
{
39-
graph_mode();
40-
var result = graphAction();
41-
restore_mode();
42-
return result;
43-
}
44-
else
45-
{
46-
return graphAction();
47-
}
48-
}
49-
else
50-
{
51-
if (tf.Context.executing_eagerly())
52-
{
53-
return eagerAction();
54-
}
55-
else
56-
{
57-
return graphAction();
58-
}
59-
}
60-
}
61-
62-
// [DebuggerStepThrough]
63-
public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args)
33+
public Tensors ExecuteOp(string OpType, string Name, AutoModeArgs args)
6434
{
6535
var inputArgs = ConvertToDict(args.OpInputArgs);
6636
var attrDict = ConvertToDict(args.OpAttrs);
6737

68-
Func<Tensor> graphAction = () =>
38+
Func<Tensors> graphAction = () =>
6939
{
7040
foreach (var attr in attrDict)
7141
inputArgs[attr.Key] = attr.Value;
72-
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).output;
42+
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).outputs;
7343
};
7444

75-
Func<Tensor> eagerAction = () =>
45+
Func<Tensors> eagerAction = () =>
7646
{
7747
var attrs = new object[attrDict.Count() * 2];
7848
int i = 0;
@@ -87,7 +57,7 @@ public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args)
8757
OpType, Name,
8858
null,
8959
inputArgs.Values.ToArray(),
90-
attrs).FirstOrDefault();
60+
attrs);
9161
};
9262

9363
if (tf.Context.has_graph_arg(inputArgs.Values))

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 27 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -269,29 +269,24 @@ public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params)
269269
}
270270

271271
public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params)
272-
=> tf.Context.RunInAutoMode(()
273-
=> tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name,
274-
args: new
275-
{
276-
y_backprop = @params.YBackprop,
277-
x = @params.X,
278-
scale = @params.Scale,
279-
reserve_space_1 = @params.ReserveSpace1,
280-
reserve_space_2 = @params.ReserveSpace2,
281-
reserve_space_3 = @params.ReserveSpace3,
282-
epsilon = @params.Epsilon,
283-
data_format = @params.DataFormat,
284-
is_training = @params.IsTraining
285-
}).outputs, ()
286-
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
287-
"FusedBatchNormGradV3", @params.Name,
288-
null,
289-
@params.YBackprop, @params.X, @params.Scale,
290-
@params.ReserveSpace1, @params.ReserveSpace2, @params.ReserveSpace3,
291-
"epsilon", @params.Epsilon,
292-
"data_format", @params.DataFormat,
293-
"is_training", @params.IsTraining),
294-
@params.YBackprop);
272+
=> tf.Context.ExecuteOp("FusedBatchNormGradV3", @params.Name, new AutoModeArgs
273+
{
274+
OpInputArgs = new
275+
{
276+
y_backprop = @params.YBackprop,
277+
x = @params.X,
278+
scale = @params.Scale,
279+
reserve_space_1 = @params.ReserveSpace1,
280+
reserve_space_2 = @params.ReserveSpace2,
281+
reserve_space_3 = @params.ReserveSpace3
282+
},
283+
OpAttrs = new
284+
{
285+
epsilon = @params.Epsilon,
286+
data_format = @params.DataFormat,
287+
is_training = @params.IsTraining
288+
}
289+
});
295290

296291
public static Tensor[] fused_batch_norm(Tensor x,
297292
Tensor scale,
@@ -388,14 +383,10 @@ public static Tensor local_response_normalization(Tensor input, int depth_radius
388383
}
389384

390385
public static Tensor log_softmax(Tensor logits, string name = null)
391-
=> tf.Context.RunInAutoMode(()
392-
=> tf.OpDefLib._apply_op_helper("LogSoftmax", name: name,
393-
args: new { logits }).output, ()
394-
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
395-
"LogSoftmax", name,
396-
null,
397-
logits).FirstOrDefault(),
398-
logits);
386+
=> tf.Context.ExecuteOp("LogSoftmax", name, new AutoModeArgs
387+
{
388+
OpInputArgs = new { logits }
389+
});
399390

400391
/// <summary>
401392
/// Says whether the targets are in the top `K` predictions.
@@ -418,19 +409,11 @@ public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, strin
418409
}
419410

420411
public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)
421-
=> tf.Context.RunInAutoMode(()
422-
=> tf.OpDefLib._apply_op_helper("LeakyRelu", name: name,
423-
args: new
424-
{
425-
features,
426-
alpha
427-
}).output, ()
428-
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
429-
"LeakyRelu", name,
430-
null,
431-
features,
432-
"alpha", alpha).FirstOrDefault(),
433-
features);
412+
=> tf.Context.ExecuteOp("LeakyRelu", name, new AutoModeArgs
413+
{
414+
OpInputArgs = new { features },
415+
OpAttrs = new { alpha }
416+
});
434417

435418
public static Tensor max_pool(Tensor input,
436419
int[] ksize,

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end,
737737
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
738738
long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0,
739739
long shrink_axis_mask = 0, string name = null)
740-
=> tf.Context.RunInAutoMode2("StridedSliceGrad", name, new AutoModeArgs
740+
=> tf.Context.ExecuteOp("StridedSliceGrad", name, new AutoModeArgs
741741
{
742742
OpInputArgs = new
743743
{
@@ -960,7 +960,7 @@ public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name
960960
=> gen_array_ops.slice(input, begin, size, name: name);
961961

962962
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
963-
=> tf.Context.RunInAutoMode2("Slice", name, new AutoModeArgs
963+
=> tf.Context.ExecuteOp("Slice", name, new AutoModeArgs
964964
{
965965
OpInputArgs = new { input, begin, size },
966966
GetGradientAttrs = (op) => new

0 commit comments

Comments
 (0)