Skip to content

Commit 0f536a2

Browse files
authored
[Op] Implement FileSliceSend/FileSliceRecvOp. (#960)
FileSliceSend/FileSliceRecv Op transfer scalar string Tensor to/from SliceRecv/SliceSend Op. Signed-off-by: chenbangduo.cbd <chenbangduo.cbd@alibaba-inc.com>
1 parent 6bf5621 commit 0f536a2

15 files changed

+1388
-103
lines changed

tensorflow/core/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,7 @@ tf_gen_op_libs(
12031203
"encode_proto_ops",
12041204
"experimental_dataset_ops",
12051205
"feature_column_ops",
1206+
"file_slice_sendrecv_ops",
12061207
"function_ops",
12071208
"functional_ops",
12081209
"fused_embedding_ops",
@@ -1465,6 +1466,7 @@ cc_library(
14651466
":encode_proto_ops_op_lib",
14661467
":experimental_dataset_ops_op_lib",
14671468
":feature_column_ops_op_lib",
1469+
":file_slice_sendrecv_ops_op_lib",
14681470
":function_ops_op_lib",
14691471
":functional_ops_op_lib",
14701472
":fused_embedding_ops_op_lib",

tensorflow/core/framework/rendezvous.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class Rendezvous : public core::RefCounted {
8282
friend class FuseRecvOp;
8383
friend class SliceSendOp;
8484
friend class SliceRecvOp;
85+
friend class FileSliceSendOp;
86+
friend class FileSliceRecvOp;
8587
friend class RefSendOp;
8688
friend class RefRecvOp;
8789
string buf_;

tensorflow/core/graph/graph.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,14 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
7070
{"_HostSend", NC_HOST_SEND},
7171
{"_RefSend", NC_REF_SEND},
7272
{"_SliceSend", NC_SLICE_SEND},
73+
{"_FileSliceSend", NC_FILE_SLICE_SEND},
7374
{"_Recv", NC_RECV},
7475
{"_HostRecv", NC_HOST_RECV},
7576
{"_RefRecv", NC_REF_RECV},
7677
{"_FuseRecv", NC_FUSE_RECV},
7778
{"_HostFuseRecv", NC_HOST_FUSE_RECV},
7879
{"_SliceRecv", NC_SLICE_RECV},
80+
{"_FileSliceRecv", NC_FILE_SLICE_RECV},
7981
{"Const", NC_CONSTANT},
8082
{"HostConst", NC_CONSTANT},
8183
{"Variable", NC_VARIABLE},

tensorflow/core/graph/graph.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,15 +220,19 @@ class Node {
220220
bool IsSend() const { return class_ == NC_SEND ||
221221
class_ == NC_HOST_SEND ||
222222
class_ == NC_REF_SEND ||
223-
class_ == NC_SLICE_SEND; }
223+
class_ == NC_SLICE_SEND ||
224+
class_ == NC_FILE_SLICE_SEND; }
224225
bool IsSliceSend() const { return class_ == NC_SLICE_SEND; }
226+
bool IsFileSliceSend() const { return class_ == NC_FILE_SLICE_SEND; }
225227
bool IsRecv() const { return class_ == NC_RECV ||
226228
class_ == NC_HOST_RECV ||
227229
class_ == NC_REF_RECV ||
228-
class_ == NC_SLICE_RECV; }
230+
class_ == NC_SLICE_RECV ||
231+
class_ == NC_FILE_SLICE_RECV; }
229232
bool IsFuseRecv() const { return class_ == NC_FUSE_RECV ||
230233
class_ == NC_HOST_FUSE_RECV; }
231234
bool IsSliceRecv() const {return class_ == NC_SLICE_RECV; }
235+
bool IsFileSliceRecv() const { return class_ == NC_FILE_SLICE_RECV; }
232236
bool IsConstant() const { return class_ == NC_CONSTANT; }
233237
bool IsStage() const { return class_ == NC_TENSOR_BUFFER_PUT; }
234238
bool IsUnstage() const { return class_ == NC_TENSOR_BUFFER_TAKE; }
@@ -339,12 +343,14 @@ class Node {
339343
NC_HOST_SEND,
340344
NC_REF_SEND,
341345
NC_SLICE_SEND,
346+
NC_FILE_SLICE_SEND,
342347
NC_RECV,
343348
NC_HOST_RECV,
344349
NC_REF_RECV,
345350
NC_FUSE_RECV,
346351
NC_HOST_FUSE_RECV,
347352
NC_SLICE_RECV,
353+
NC_FILE_SLICE_RECV,
348354
NC_CONSTANT,
349355
NC_VARIABLE,
350356
NC_KV_VAR_HANDLE,
@@ -851,8 +857,10 @@ inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
851857
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
852858
inline bool IsSend(const Node* node) { return node->IsSend(); }
853859
inline bool IsSliceSend(const Node* node) { return node->IsSliceSend(); }
860+
inline bool IsFileSliceSend(const Node* node) { return node->IsFileSliceSend(); }
854861
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
855862
inline bool IsSliceRecv(const Node* node) { return node->IsSliceRecv(); }
863+
inline bool IsFileSliceRecv(const Node* node) { return node->IsFileSliceRecv(); }
856864
inline bool IsFuseRecv(const Node* node) { return node->IsFuseRecv(); }
857865
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
858866
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }

tensorflow/core/grappler/op_types.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,10 @@ bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
265265

266266
bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }
267267

268+
bool IsFileSliceRecv(const NodeDef& node) { return node.op() == "_FileSliceRecv"; }
269+
270+
bool IsFileSliceSend(const NodeDef& node) { return node.op() == "_FileSliceSend"; }
271+
268272
bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
269273

270274
bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
@@ -454,7 +458,8 @@ bool IsReciprocalGrad(const NodeDef& node) {
454458
}
455459

456460
bool IsRecv(const NodeDef& node) {
457-
return node.op() == "_Recv" || node.op() == "_HostRecv" || IsSliceRecv(node);
461+
return node.op() == "_Recv" || node.op() == "_HostRecv" ||
462+
IsSliceRecv(node) || IsFileSliceRecv(node);
458463
}
459464

460465
bool IsFuseRecv(const NodeDef& node) {
@@ -502,7 +507,8 @@ bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
502507
bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }
503508

504509
bool IsSend(const NodeDef& node) {
505-
return node.op() == "_Send" || node.op() == "_HostSend" || IsSliceSend(node);
510+
return node.op() == "_Send" || node.op() == "_HostSend" ||
511+
IsSliceSend(node) || IsFileSliceSend(node);
506512
}
507513

508514
bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }

tensorflow/core/grappler/op_types.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ bool IsExit(const NodeDef& node);
8080
bool IsExp(const NodeDef& node);
8181
bool IsFakeParam(const NodeDef& node);
8282
bool IsFill(const NodeDef& node);
83+
bool IsFileSliceRecv(const NodeDef& node);
84+
bool IsFileSliceSend(const NodeDef& node);
8385
bool IsFloorDiv(const NodeDef& node);
8486
bool IsFloorMod(const NodeDef& node);
8587
bool IsFusedBatchNorm(const NodeDef& node);

tensorflow/core/kernels/BUILD

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5423,6 +5423,7 @@ cc_library(
54235423
name = "required",
54245424
deps = [
54255425
":no_op",
5426+
":file_slice_sendrecv_ops",
54265427
":fuserecv_ops",
54275428
":sendrecv_ops",
54285429
":slice_sendrecv_ops",
@@ -5446,10 +5447,33 @@ tf_kernel_library(
54465447
deps = REQUIRED_DEPS,
54475448
)
54485449

5450+
cc_library(
5451+
name = "slice_sendrecv_utils",
5452+
hdrs = [
5453+
"slice_sendrecv_utils.h"
5454+
],
5455+
srcs = [
5456+
"slice_sendrecv_utils.cc",
5457+
],
5458+
deps = [
5459+
"//tensorflow/core:framework",
5460+
]
5461+
)
5462+
54495463
tf_kernel_library(
54505464
name = "slice_sendrecv_ops",
54515465
prefix = "slice_sendrecv_ops",
5452-
deps = REQUIRED_DEPS,
5466+
deps = REQUIRED_DEPS + [
5467+
":slice_sendrecv_utils",
5468+
],
5469+
)
5470+
5471+
tf_kernel_library(
5472+
name = "file_slice_sendrecv_ops",
5473+
prefix = "file_slice_sendrecv_ops",
5474+
deps = REQUIRED_DEPS + [
5475+
":slice_sendrecv_utils",
5476+
],
54535477
)
54545478

54555479
tf_kernel_library(
@@ -5534,6 +5558,26 @@ tf_cc_test(
55345558
],
55355559
)
55365560

5561+
tf_cc_test(
5562+
name = "file_slice_sendrecv_ops_test",
5563+
srcs = ["file_slice_sendrecv_ops_test.cc"],
5564+
linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
5565+
deps = [
5566+
":control_flow_ops",
5567+
":cwise_op",
5568+
":file_slice_sendrecv_ops",
5569+
":logging_ops",
5570+
":ops_testutil",
5571+
":ops_util",
5572+
":slice_sendrecv_ops",
5573+
":whole_file_read_ops",
5574+
"//tensorflow/core:framework",
5575+
"//tensorflow/core:test",
5576+
"//tensorflow/core:test_main",
5577+
"//tensorflow/core:testlib",
5578+
],
5579+
)
5580+
55375581
tf_kernel_library(
55385582
name = "fuserecv_ops",
55395583
prefix = "fuserecv_ops",

0 commit comments

Comments
 (0)