Skip to content

Commit f5a3551

Browse files
committed
Refactor RunInAutoMode2.
1 parent 90638a8 commit f5a3551

File tree

8 files changed

+155
-174
lines changed

8 files changed

+155
-174
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class AutoModeArgs
8+
{
9+
public Func<Operation, object> GetGradientAttrs { get; set; }
10+
public object OpInputArgs { get; set; }
11+
public object OpAttrs { get; set; }
12+
}
13+
}

src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
using Tensorflow.Eager;
2121
using static Tensorflow.Binding;
2222
using Google.Protobuf;
23+
using System.Collections.Generic;
2324

2425
namespace Tensorflow.Contexts
2526
{
@@ -57,14 +58,39 @@ public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params objec
5758
}
5859
}
5960
}
60-
61+
6162
// [DebuggerStepThrough]
62-
public Tensors RunInAutoMode2(Func<Tensors> graphAction,
63-
Func<Tensors> eagerAction,
64-
Action<Operation> recordGradient,
65-
Tensors tensors)
63+
public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args)
6664
{
67-
if (tf.Context.has_graph_arg(tensors))
65+
var inputArgs = ConvertToDict(args.OpInputArgs);
66+
var attrDict = ConvertToDict(args.OpAttrs);
67+
68+
Func<Tensor> graphAction = () =>
69+
{
70+
foreach (var attr in attrDict)
71+
inputArgs[attr.Key] = attr.Value;
72+
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).output;
73+
};
74+
75+
Func<Tensor> eagerAction = () =>
76+
{
77+
var attrs = new object[attrDict.Count() * 2];
78+
int i = 0;
79+
foreach(var arg in attrDict)
80+
{
81+
attrs[i]= arg.Key;
82+
attrs[i + 1] = arg.Value;
83+
i += 2;
84+
}
85+
86+
return tf.Runner.TFE_FastPathExecute2(tf.Context, tf.Context.DeviceName,
87+
OpType, Name,
88+
null,
89+
inputArgs.Values.ToArray(),
90+
attrs).FirstOrDefault();
91+
};
92+
93+
if (tf.Context.has_graph_arg(inputArgs.Values))
6894
{
6995
if (executing_eagerly())
7096
{
@@ -77,7 +103,28 @@ public Tensors RunInAutoMode2(Func<Tensors> graphAction,
77103
{
78104
var result = graphAction();
79105
if (tf.Runner.MustRecordGradient())
80-
recordGradient(result[0].op);
106+
{
107+
var op = result[0].op;
108+
Dictionary<string, object> attrs;
109+
if (args.GetGradientAttrs == null)
110+
{
111+
attrs = new Dictionary<string, object>();
112+
attrs["T"] = op.get_attr<TF_DataType>("T");
113+
}
114+
else
115+
{
116+
attrs = ConvertToDict(args.GetGradientAttrs(op));
117+
}
118+
var args1 = new object[attrs.Count() * 2];
119+
int i = 0;
120+
foreach (var arg in attrs)
121+
{
122+
args1[i] = arg.Key;
123+
args1[i + 1] = arg.Value;
124+
i += 2;
125+
}
126+
tf.Runner.RecordGradient(OpType, op.inputs, args1, op.outputs);
127+
}
81128
return result;
82129
}
83130
}

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@ public partial class EagerRunner
1818
int kFastPathExecuteInputStartIndex = 0;
1919
UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>();
2020

21+
public Tensor[] TFE_FastPathExecute2(Context ctx,
22+
string device_name,
23+
string opName,
24+
string name,
25+
Action callbacks,
26+
object[] inputArgs,
27+
object[] attrs)
28+
{
29+
var args = new List<object>();
30+
args.AddRange(inputArgs);
31+
if (attrs != null)
32+
args.AddRange(attrs);
33+
return TFE_FastPathExecute(ctx, device_name, opName, name, callbacks, args.ToArray());
34+
}
35+
2136
public Tensor[] TFE_FastPathExecute(Context ctx,
2237
string device_name,
2338
string opName,

src/TensorFlowNET.Core/Eager/IEagerRunner.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ Tensor[] Execute(Context ctx, string op_name,
1616
TF_DataType default_dtype = TF_DataType.DtInvalid,
1717
object[] args = null);
1818

19+
Tensor[] TFE_FastPathExecute2(Context ctx,
20+
string device_name,
21+
string opName,
22+
string name,
23+
Action callbacks,
24+
object[] inputArgs,
25+
object[] attrs);
26+
1927
Tensor[] TFE_FastPathExecute(Context ctx,
2028
string device_name,
2129
string opName,

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 26 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -737,44 +737,35 @@ public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end,
737737
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
738738
long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0,
739739
long shrink_axis_mask = 0, string name = null)
740-
=> tf.Context.RunInAutoMode2(
741-
() => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new
740+
=> tf.Context.RunInAutoMode2("StridedSliceGrad", name, new AutoModeArgs
741+
{
742+
OpInputArgs = new
742743
{
743744
shape,
744745
begin,
745746
end,
746747
strides,
747-
dy,
748+
dy
749+
},
750+
OpAttrs = new
751+
{
748752
begin_mask,
749753
end_mask,
750754
ellipsis_mask,
751755
new_axis_mask,
752756
shrink_axis_mask
753-
}).output,
754-
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
755-
"StridedSliceGrad", name,
756-
null,
757-
shape, begin, end, strides, dy,
758-
"begin_mask", begin_mask,
759-
"end_mask", end_mask,
760-
"ellipsis_mask", ellipsis_mask,
761-
"new_axis_mask", new_axis_mask,
762-
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
763-
(op) =>
764-
{
765-
var attrs = new object[]
766-
{
767-
"T", op.get_attr<TF_DataType>("T"),
768-
"Index", op.get_attr<TF_DataType>("Index"),
769-
"begin_mask", op.get_attr<long>("begin_mask"),
770-
"end_mask", op.get_attr<long>("end_mask"),
771-
"ellipsis_mask", op.get_attr<long>("ellipsis_mask"),
772-
"new_axis_mask", op.get_attr<long>("new_axis_mask"),
773-
"shrink_axis_mask", op.get_attr<long>("shrink_axis_mask")
774-
};
775-
tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs);
776757
},
777-
new Tensors(shape, begin, end, strides, dy));
758+
GetGradientAttrs = (op) => new
759+
{
760+
T = op.get_attr<TF_DataType>("T"),
761+
Index = op.get_attr<TF_DataType>("Index"),
762+
begin_mask = op.get_attr<long>("begin_mask"),
763+
end_mask = op.get_attr<long>("end_mask"),
764+
ellipsis_mask = op.get_attr<long>("ellipsis_mask"),
765+
new_axis_mask = op.get_attr<long>("new_axis_mask"),
766+
shrink_axis_mask = op.get_attr<long>("shrink_axis_mask")
767+
}
768+
});
778769

779770
/// <summary>
780771
/// Removes dimensions of size 1 from the shape of a tensor.
@@ -969,27 +960,15 @@ public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name
969960
=> gen_array_ops.slice(input, begin, size, name: name);
970961

971962
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
972-
=> tf.Context.RunInAutoMode2(
973-
() => tf.OpDefLib._apply_op_helper("Slice", name, new
974-
{
975-
input,
976-
begin,
977-
size
978-
}).output,
979-
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
980-
"Slice", name,
981-
null,
982-
input, begin, size).FirstOrDefault(),
983-
(op) =>
963+
=> tf.Context.RunInAutoMode2("Slice", name, new AutoModeArgs
964+
{
965+
OpInputArgs = new { input, begin, size },
966+
GetGradientAttrs = (op) => new
984967
{
985-
var attrs = new object[]
986-
{
987-
"T", op.get_attr<TF_DataType>("T"),
988-
"Index", op.get_attr<int>("Index")
989-
};
990-
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
991-
},
992-
new Tensors(input, begin, size));
968+
T = op.get_attr<TF_DataType>("T"),
969+
Index = op.get_attr<int>("Index")
970+
}
971+
});
993972

994973
public static Tensor stack(object values, int axis = 0, string name = "stack")
995974
{

src/TensorFlowNET.Core/Operations/gen_image_ops.cs

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -240,30 +240,16 @@ public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, b
240240

241241
public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false,
242242
bool half_pixel_centers = false, string name = null)
243-
=> tf.Context.RunInAutoMode2(
244-
() => tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name, new
243+
=> tf.Context.RunInAutoMode2("ResizeNearestNeighborGrad", name, new AutoModeArgs
245244
{
246-
grads,
247-
size,
248-
align_corners,
249-
half_pixel_centers
250-
}).output,
251-
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
252-
"ResizeNearestNeighborGrad", name,
253-
null,
254-
grads, size,
255-
"align_corners", align_corners,
256-
"half_pixel_centers", half_pixel_centers).FirstOrDefault(),
257-
(op) =>
258-
{
259-
var attrs = new object[]
245+
OpInputArgs = new { grads, size },
246+
OpAttrs = new { align_corners, half_pixel_centers },
247+
GetGradientAttrs = (op) => new
260248
{
261-
"T", op.get_attr<TF_DataType>("T"),
262-
"align_corners", op.get_attr<bool>("align_corners"),
263-
"half_pixel_centers", op.get_attr<bool>("half_pixel_centers")
264-
};
265-
tf.Runner.RecordGradient("ResizeNearestNeighborGrad", op.inputs, attrs, op.outputs);
266-
},
267-
new Tensors(grads, size));
249+
T = op.get_attr<TF_DataType>("T"),
250+
align_corners = op.get_attr<bool>("align_corners"),
251+
half_pixel_centers = op.get_attr<bool>("half_pixel_centers")
252+
}
253+
});
268254
}
269255
}

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 18 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -141,29 +141,17 @@ public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string
141141
/// <param name="name"> A name for the operation (optional).</param>
142142
/// <returns> A `Tensor`. Has the same type as `input`.</returns>
143143
public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null)
144-
=> tf.Context.RunInAutoMode2(
145-
() => tf.OpDefLib._apply_op_helper("Mean", name, new
146-
{
147-
input,
148-
reduction_indices = axis,
149-
keep_dims = keep_dims
150-
}).output,
151-
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
152-
"Mean", name,
153-
null,
154-
input, axis,
155-
"keep_dims", keep_dims).FirstOrDefault(),
156-
(op) =>
144+
=> tf.Context.RunInAutoMode2("Mean", name, new AutoModeArgs
145+
{
146+
OpInputArgs = new { input, axis },
147+
OpAttrs = new { keep_dims, reduction_indices = axis },
148+
GetGradientAttrs = (op) => new
157149
{
158-
var attrs = new object[]
159-
{
160-
"T", op.get_attr<TF_DataType>("T"),
161-
"Tidx", op.get_attr<TF_DataType>("Tidx"),
162-
"keep_dims", op.get_attr<bool>("keep_dims")
163-
};
164-
tf.Runner.RecordGradient("Mean", op.inputs, attrs, op.outputs);
165-
},
166-
new Tensors(input, axis));
150+
T = op.get_attr<TF_DataType>("T"),
151+
Tidx = op.get_attr<TF_DataType>("Tidx"),
152+
keep_dims = op.get_attr<bool>("keep_dims")
153+
}
154+
});
167155

168156
public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null)
169157
{
@@ -356,21 +344,10 @@ public static Tensor sigmoid(Tensor x, string name = "Sigmoid")
356344
/// <c>dy</c> is the corresponding input gradient.
357345
/// </remarks>
358346
public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad")
359-
=> tf.Context.RunInAutoMode2(
360-
() => tf.OpDefLib._apply_op_helper("SigmoidGrad", name, new { y, dy }).output,
361-
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
362-
"SigmoidGrad", name,
363-
null,
364-
y, dy).FirstOrDefault(),
365-
(op) =>
366-
{
367-
var attrs = new object[]
368-
{
369-
"T", op.get_attr<TF_DataType>("T")
370-
};
371-
tf.Runner.RecordGradient("SigmoidGrad", op.inputs, attrs, op.outputs);
372-
},
373-
new Tensors(y, dy));
347+
=> tf.Context.RunInAutoMode2("SigmoidGrad", name, new AutoModeArgs
348+
{
349+
OpInputArgs = new { y, dy }
350+
});
374351

375352
public static Tensor sign<T>(T x, string name = "Sign")
376353
{
@@ -806,21 +783,10 @@ public static Tensor sqrt(Tensor x, string name = null)
806783
}
807784

808785
public static Tensor sub(Tensor x, Tensor y, string name = null)
809-
=> tf.Context.RunInAutoMode2(
810-
() => tf.OpDefLib._apply_op_helper("Sub", name, new { x, y }).output,
811-
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
812-
"Sub", name,
813-
null,
814-
x, y).FirstOrDefault(),
815-
(op) =>
816-
{
817-
var attrs = new object[]
818-
{
819-
"T", op.get_attr<TF_DataType>("T")
820-
};
821-
tf.Runner.RecordGradient("Sub", op.inputs, attrs, op.outputs);
822-
},
823-
new Tensors(x, y));
786+
=> tf.Context.RunInAutoMode2("Sub", name, new AutoModeArgs
787+
{
788+
OpInputArgs = new { x, y }
789+
});
824790

825791
public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null)
826792
{

0 commit comments

Comments
 (0)