Skip to content

Commit 822e6d7

Browse files
committed
Consolidate TFE_FastPathExecute and _apply_op_helper.
1 parent 0471e28 commit 822e6d7

30 files changed

+532
-2403
lines changed

src/TensorFlowNET.Console/MemoryBasicTest.cs

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,18 @@ public Action<int, int> Conv2DWithTensor
112112
var strides = new[] { 1, 1, 1, 1 };
113113
var dilations = new[] { 1, 1, 1, 1 };
114114

115-
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
116-
"Conv2D", null,
117-
null,
118-
input, filter,
119-
"strides", strides,
120-
"use_cudnn_on_gpu", true,
121-
"padding", "VALID",
122-
"explicit_paddings", new int[0],
123-
"data_format", "NHWC",
124-
"dilations", dilations);
115+
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Conv2D", null, input, filter)
116+
{
117+
attrs = ConvertToDict(new
118+
{
119+
strides,
120+
use_cudnn_on_gpu = true,
121+
padding = "VALID",
122+
explicit_paddings = new int[0],
123+
data_format = "NHWC",
124+
dilations
125+
})
126+
});
125127
};
126128

127129
public Action<int, int> Conv2DWithVariable
@@ -132,16 +134,18 @@ public Action<int, int> Conv2DWithVariable
132134
var strides = new[] { 1, 1, 1, 1 };
133135
var dilations = new[] { 1, 1, 1, 1 };
134136

135-
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
136-
"Conv2D", null,
137-
null,
138-
input, filter,
139-
"strides", strides,
140-
"use_cudnn_on_gpu", true,
141-
"padding", "VALID",
142-
"explicit_paddings", new int[0],
143-
"data_format", "NHWC",
144-
"dilations", dilations);
137+
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Conv2D", null, input, filter)
138+
{
139+
attrs = ConvertToDict(new
140+
{
141+
strides,
142+
use_cudnn_on_gpu = true,
143+
padding = "VALID",
144+
explicit_paddings = new int[0],
145+
data_format = "NHWC",
146+
dilations
147+
})
148+
});
145149
};
146150

147151
public Action<int, int> Dataset

src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ You may obtain a copy of the License at
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
******************************************************************************/
16+
using static Tensorflow.Binding;
1617

1718
namespace Tensorflow
1819
{
@@ -37,8 +38,8 @@ public Tensor diag(Tensor diagonal, string name = null)
3738
public Tensor matmul(Tensor a, Tensor b)
3839
=> math_ops.matmul(a, b);
3940

40-
public Tensor batch_matmul(Tensor x, Tensor y)
41-
=> gen_math_ops.batch_mat_mul(x, y);
41+
public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
42+
=> tf.Context.ExecuteOp("BatchMatMul", name, new ExecuteOpArgs(x, y).SetAttributes(new { adj_x, adj_y }));
4243
}
4344

4445
public Tensor diag(Tensor diagonal, string name = null)
@@ -47,7 +48,32 @@ public Tensor diag(Tensor diagonal, string name = null)
4748
public Tensor matmul(Tensor a, Tensor b)
4849
=> math_ops.matmul(a, b);
4950

50-
public Tensor batch_matmul(Tensor x, Tensor y)
51-
=> gen_math_ops.batch_mat_mul(x, y);
51+
/// <summary>
52+
/// Multiply slices of the two matrices "x" and "y".
53+
/// </summary>
54+
/// <remarks>
55+
/// The `BatchMatMul` operation is embedded into the
56+
/// `MatMul` operation on the DLL side. However the expected
57+
/// attributes are not the same, hence we need to expose this
58+
/// method to have the right args list on the `_apply_op_helper`
59+
/// function.
60+
///
61+
/// For each rank > 2 the first rank - 2 dimensions are considered
62+
/// as fixed, and have to be consistent across the two matrices. A
63+
/// common matrix multiplication is then applied over the residual
64+
/// 2 dimensions.
65+
///
66+
/// e.g.
67+
/// x is (3, 6, 12); y is (3, 12, 6)
68+
/// batch_matmul(x, y) ==> (3, 6, 6)
69+
/// </remarks>
70+
/// <param name="x"></param>
71+
/// <param name="y"></param>
72+
/// <param name="adj_x"></param>
73+
/// <param name="adj_y"></param>
74+
/// <param name="name"></param>
75+
/// <returns></returns>
76+
public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
77+
=> tf.Context.ExecuteOp("BatchMatMul", name, new ExecuteOpArgs(x, y).SetAttributes(new { adj_x, adj_y }));
5278
}
5379
}

src/TensorFlowNET.Core/Contexts/AutoModeArgs.cs

Lines changed: 0 additions & 13 deletions
This file was deleted.

src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs renamed to src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,37 +30,35 @@ namespace Tensorflow.Contexts
3030
public sealed partial class Context
3131
{
3232
// [DebuggerStepThrough]
33-
public Tensors ExecuteOp(string OpType, string Name, AutoModeArgs args)
33+
public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args)
3434
{
35-
var inputArgs = ConvertToDict(args.OpInputArgs);
36-
var attrDict = ConvertToDict(args.OpAttrs);
37-
3835
Func<Tensors> graphAction = () =>
3936
{
40-
foreach (var attr in attrDict)
41-
inputArgs[attr.Key] = attr.Value;
42-
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).outputs;
37+
var keywords = new Dictionary<string, object>();
38+
if(args.OpInputArgs != null)
39+
{
40+
foreach (var (i, input) in enumerate(args.OpInputArgs))
41+
keywords[$"input_{i}"] = input;
42+
}
43+
44+
if(args.OpAttrs != null)
45+
{
46+
foreach (var attr in args.OpAttrs)
47+
keywords[attr.Key] = attr.Value;
48+
}
49+
50+
return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs;
4351
};
4452

4553
Func<Tensors> eagerAction = () =>
4654
{
47-
var attrs = new object[attrDict.Count() * 2];
48-
int i = 0;
49-
foreach(var arg in attrDict)
55+
return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
5056
{
51-
attrs[i]= arg.Key;
52-
attrs[i + 1] = arg.Value;
53-
i += 2;
54-
}
55-
56-
return tf.Runner.TFE_FastPathExecute2(tf.Context, tf.Context.DeviceName,
57-
OpType, Name,
58-
null,
59-
inputArgs.Values.ToArray(),
60-
attrs);
57+
attrs = args.OpAttrs
58+
});
6159
};
6260

63-
if (tf.Context.has_graph_arg(inputArgs.Values))
61+
if (tf.Context.has_graph_arg(args.OpInputArgs))
6462
{
6563
if (executing_eagerly())
6664
{

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ public bool switched_to_graph(params object[] args)
115115
public bool has_graph_arg(params object[] args)
116116
{
117117
var flatten_args = nest.flatten<object>(args);
118-
bool has_graph_arg = false;
118+
/*if (flatten_args.Count(x => x.GetType().IsValueType) == flatten_args.Count())
119+
return tf.Context.executing_eagerly() == false*/
120+
121+
bool has_graph_arg = !tf.Context.executing_eagerly();
119122
foreach (var el in flatten_args)
120123
{
121124
if (el is Tensor tensor && !tensor.IsEagerTensor)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow
7+
{
8+
public class ExecuteOpArgs
9+
{
10+
public Func<Operation, object> GetGradientAttrs { get; set; }
11+
public object[] OpInputArgs { get; set; }
12+
public Dictionary<string, object> OpAttrs { get; set; }
13+
14+
public ExecuteOpArgs(params object[] inputArgs)
15+
{
16+
OpInputArgs = inputArgs;
17+
}
18+
19+
public ExecuteOpArgs SetAttributes(object attrs)
20+
{
21+
OpAttrs = ConvertToDict(attrs);
22+
return this;
23+
}
24+
}
25+
}

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,7 @@ public IDatasetV2 apply_options()
105105
}
106106

107107
public Tensor dataset_cardinality(string name = null)
108-
{
109-
if (tf.Context.executing_eagerly())
110-
{
111-
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
112-
"DatasetCardinality", name,
113-
null,
114-
variant_tensor);
115-
return results[0];
116-
}
117-
118-
throw new NotImplementedException("");
119-
}
108+
=> tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor));
120109

121110
public override string ToString()
122111
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}";

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,84 +15,54 @@ namespace Tensorflow.Eager
1515
/// </summary>
1616
public partial class EagerRunner
1717
{
18-
int kFastPathExecuteInputStartIndex = 0;
1918
UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>();
2019

21-
public Tensor[] TFE_FastPathExecute2(Context ctx,
22-
string device_name,
23-
string opName,
24-
string name,
25-
Action callbacks,
26-
object[] inputArgs,
27-
object[] attrs)
20+
public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
2821
{
29-
var args = new List<object>();
30-
args.AddRange(inputArgs);
31-
if (attrs != null)
32-
args.AddRange(attrs);
33-
return TFE_FastPathExecute(ctx, device_name, opName, name, callbacks, args.ToArray());
34-
}
35-
36-
public Tensor[] TFE_FastPathExecute(Context ctx,
37-
string device_name,
38-
string opName,
39-
string name,
40-
Action callbacks,
41-
params object[] args)
42-
{
43-
if (ctx == null)
44-
throw new ValueError("This function does not handle the case of the path where " +
45-
"all inputs are not already EagerTensors.");
22+
if (op_exec_info.ctx == null)
23+
op_exec_info.ctx = tf.Context;
24+
if (string.IsNullOrEmpty(op_exec_info.device_name))
25+
op_exec_info.device_name = tf.Context.DeviceName;
4626

47-
int args_size = args.Length;
4827
var attr_list_sizes = new Dictionary<string, long>();
4928

50-
FastPathOpExecInfo op_exec_info = new FastPathOpExecInfo()
51-
{
52-
ctx = ctx,
53-
args = args,
54-
device_name = device_name,
55-
op_name = opName,
56-
name = name,
57-
};
58-
5929
op_exec_info.run_gradient_callback = HasAccumulatorOrTape();
60-
op_exec_info.run_post_exec_callbacks = callbacks != null;
30+
op_exec_info.run_post_exec_callbacks = op_exec_info.callbacks != null;
6131
op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks;
6232

6333
var status = tf.Status;
64-
using var op = GetOp(ctx, opName, status);
34+
using var op = GetOp(op_exec_info.ctx, op_exec_info.op_name, status);
6535

66-
var op_def = tf.get_default_graph().GetOpDef(opName);
36+
var op_def = tf.get_default_graph().GetOpDef(op_exec_info.op_name);
6737

6838
var flattened_attrs = new List<object>(op_def.Attr.Count * 2);
6939
var flattened_inputs = new List<Tensor>(op_def.InputArg.Count);
7040

7141
// Set non-inferred attrs, including setting defaults if the attr is passed in
7242
// as None.
73-
for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
43+
if(op_exec_info.attrs != null)
7444
{
75-
var attr_name = args[i].ToString();
76-
var attr_value = args[i + 1];
77-
78-
var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr_name);
79-
if (attr != null)
45+
foreach (var attr1 in op_exec_info.attrs)
8046
{
81-
flattened_attrs.Add(attr_name);
82-
flattened_attrs.Add(attr_value);
47+
var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr1.Key);
48+
if (attr != null)
49+
{
50+
flattened_attrs.Add(attr.Name);
51+
flattened_attrs.Add(attr1.Value);
8352

84-
SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
85-
status.Check(true);
53+
SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr.Name, attr1.Value, attr_list_sizes, status);
54+
status.Check(true);
55+
}
8656
}
8757
}
8858

89-
c_api.TFE_OpSetDevice(op, device_name, status.Handle);
59+
c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle);
9060
status.Check(true);
9161

9262
// Add inferred attrs and inputs.
9363
for (int i = 0; i < op_def.InputArg.Count; i++)
9464
{
95-
var input = args[kFastPathExecuteInputStartIndex + i];
65+
var input = op_exec_info.args[i];
9666
var input_arg = op_def.InputArg[i];
9767
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
9868
{
@@ -107,7 +77,7 @@ public Tensor[] TFE_FastPathExecute(Context ctx,
10777

10878
if (len > 0)
10979
{
110-
var fast_input_array = (object[])args[i];
80+
var fast_input_array = (object[])op_exec_info.args[i];
11181
// First item adds the type attr.
11282
if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status))
11383
return null;
@@ -151,7 +121,7 @@ public Tensor[] TFE_FastPathExecute(Context ctx,
151121
else
152122
{
153123
// The item is a single item.
154-
AddInputToOp(args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status);
124+
AddInputToOp(op_exec_info.args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status);
155125
}
156126
}
157127

@@ -179,7 +149,7 @@ public Tensor[] TFE_FastPathExecute(Context ctx,
179149
if (op_exec_info.run_callbacks)
180150
{
181151
RunCallbacks(op_exec_info,
182-
kFastPathExecuteInputStartIndex + op_def.InputArg.Count(),
152+
op_def.InputArg.Count(),
183153
flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result);
184154
}
185155

0 commit comments

Comments
 (0)