Skip to content

Commit 717f7c5

Browse files
authored
[Op] Implement of SliceSend/SliceRecv Op. (#947)
Signed-off-by: chenbangduo.cbd <chenbangduo.cbd@alibaba-inc.com>
1 parent a5c014f commit 717f7c5

File tree

11 files changed

+1118
-5
lines changed

11 files changed

+1118
-5
lines changed

tensorflow/core/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,7 @@ tf_gen_op_libs(
12371237
"set_ops",
12381238
"script_ops",
12391239
"sendrecv_ops",
1240+
"slice_sendrecv_ops",
12401241
"sparse_ops",
12411242
"spectral_ops",
12421243
"state_ops",
@@ -1497,6 +1498,7 @@ cc_library(
14971498
":sdca_ops_op_lib",
14981499
":sendrecv_ops_op_lib",
14991500
":set_ops_op_lib",
1501+
":slice_sendrecv_ops_op_lib",
15001502
":sparse_ops_op_lib",
15011503
":star_run_graph_op_op_lib",
15021504
":summary_ops_op_lib",

tensorflow/core/framework/rendezvous.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ class Rendezvous : public core::RefCounted {
8080
friend class SendOp;
8181
friend class RecvOp;
8282
friend class FuseRecvOp;
83+
friend class SliceSendOp;
84+
friend class SliceRecvOp;
8385
friend class RefSendOp;
8486
friend class RefRecvOp;
8587
string buf_;

tensorflow/core/graph/graph.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,13 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
6969
{"_Send", NC_SEND},
7070
{"_HostSend", NC_HOST_SEND},
7171
{"_RefSend", NC_REF_SEND},
72+
{"_SliceSend", NC_SLICE_SEND},
7273
{"_Recv", NC_RECV},
7374
{"_HostRecv", NC_HOST_RECV},
7475
{"_RefRecv", NC_REF_RECV},
7576
{"_FuseRecv", NC_FUSE_RECV},
7677
{"_HostFuseRecv", NC_HOST_FUSE_RECV},
78+
{"_SliceRecv", NC_SLICE_RECV},
7779
{"Const", NC_CONSTANT},
7880
{"HostConst", NC_CONSTANT},
7981
{"Variable", NC_VARIABLE},

tensorflow/core/graph/graph.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,16 @@ class Node {
219219
bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
220220
bool IsSend() const { return class_ == NC_SEND ||
221221
class_ == NC_HOST_SEND ||
222-
class_ == NC_REF_SEND; }
222+
class_ == NC_REF_SEND ||
223+
class_ == NC_SLICE_SEND; }
224+
bool IsSliceSend() const { return class_ == NC_SLICE_SEND; }
223225
bool IsRecv() const { return class_ == NC_RECV ||
224226
class_ == NC_HOST_RECV ||
225-
class_ == NC_REF_RECV; }
227+
class_ == NC_REF_RECV ||
228+
class_ == NC_SLICE_RECV; }
226229
bool IsFuseRecv() const { return class_ == NC_FUSE_RECV ||
227230
class_ == NC_HOST_FUSE_RECV; }
231+
bool IsSliceRecv() const {return class_ == NC_SLICE_RECV; }
228232
bool IsConstant() const { return class_ == NC_CONSTANT; }
229233
bool IsStage() const { return class_ == NC_TENSOR_BUFFER_PUT; }
230234
bool IsUnstage() const { return class_ == NC_TENSOR_BUFFER_TAKE; }
@@ -334,11 +338,13 @@ class Node {
334338
NC_SEND,
335339
NC_HOST_SEND,
336340
NC_REF_SEND,
341+
NC_SLICE_SEND,
337342
NC_RECV,
338343
NC_HOST_RECV,
339344
NC_REF_RECV,
340345
NC_FUSE_RECV,
341346
NC_HOST_FUSE_RECV,
347+
NC_SLICE_RECV,
342348
NC_CONSTANT,
343349
NC_VARIABLE,
344350
NC_KV_VAR_HANDLE,
@@ -844,7 +850,9 @@ inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
844850
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
845851
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
846852
inline bool IsSend(const Node* node) { return node->IsSend(); }
853+
inline bool IsSliceSend(const Node* node) { return node->IsSliceSend(); }
847854
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
855+
inline bool IsSliceRecv(const Node* node) { return node->IsSliceRecv(); }
848856
inline bool IsFuseRecv(const Node* node) { return node->IsFuseRecv(); }
849857
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
850858
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }

tensorflow/core/grappler/op_types.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ bool IsReciprocalGrad(const NodeDef& node) {
454454
}
455455

456456
bool IsRecv(const NodeDef& node) {
457-
return node.op() == "_Recv" || node.op() == "_HostRecv";
457+
return node.op() == "_Recv" || node.op() == "_HostRecv" || IsSliceRecv(node);
458458
}
459459

460460
bool IsFuseRecv(const NodeDef& node) {
@@ -502,7 +502,7 @@ bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
502502
bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }
503503

504504
bool IsSend(const NodeDef& node) {
505-
return node.op() == "_Send" || node.op() == "_HostSend";
505+
return node.op() == "_Send" || node.op() == "_HostSend" || IsSliceSend(node);
506506
}
507507

508508
bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
@@ -517,6 +517,10 @@ bool IsSize(const NodeDef& node) { return node.op() == "Size"; }
517517

518518
bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
519519

520+
bool IsSliceRecv(const NodeDef& node) { return node.op() == "_SliceRecv"; }
521+
522+
bool IsSliceSend(const NodeDef& node) { return node.op() == "_SliceSend"; }
523+
520524
bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }
521525

522526
bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }

tensorflow/core/grappler/op_types.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ bool IsShuffle(const NodeDef& node);
167167
bool IsSigmoidGrad(const NodeDef& node);
168168
bool IsSize(const NodeDef& node);
169169
bool IsSlice(const NodeDef& node);
170+
bool IsSliceRecv(const NodeDef& node);
171+
bool IsSliceSend(const NodeDef& node);
170172
bool IsSnapshot(const NodeDef& node);
171173
bool IsSoftmax(const NodeDef& node);
172174
bool IsSoftplusGrad(const NodeDef& node);

tensorflow/core/kernels/BUILD

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5423,8 +5423,9 @@ cc_library(
54235423
name = "required",
54245424
deps = [
54255425
":no_op",
5426-
":sendrecv_ops",
54275426
":fuserecv_ops",
5427+
":sendrecv_ops",
5428+
":slice_sendrecv_ops",
54285429
],
54295430
)
54305431

@@ -5445,6 +5446,12 @@ tf_kernel_library(
54455446
deps = REQUIRED_DEPS,
54465447
)
54475448

5449+
tf_kernel_library(
5450+
name = "slice_sendrecv_ops",
5451+
prefix = "slice_sendrecv_ops",
5452+
deps = REQUIRED_DEPS,
5453+
)
5454+
54485455
tf_kernel_library(
54495456
name = "group_embedding_ops",
54505457
hdrs = ["group_embedding/group_embedding_lookup_sparse_forward_base_ops.h"],
@@ -5509,6 +5516,24 @@ tf_cc_test(
55095516
],
55105517
)
55115518

5519+
tf_cc_test(
5520+
name = "slice_sendrecv_ops_test",
5521+
srcs = ["slice_sendrecv_ops_test.cc"],
5522+
linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
5523+
deps = [
5524+
":control_flow_ops",
5525+
":cwise_op",
5526+
":logging_ops",
5527+
":ops_testutil",
5528+
":ops_util",
5529+
":slice_sendrecv_ops",
5530+
"//tensorflow/core:framework",
5531+
"//tensorflow/core:test",
5532+
"//tensorflow/core:test_main",
5533+
"//tensorflow/core:testlib",
5534+
],
5535+
)
5536+
55125537
tf_kernel_library(
55135538
name = "fuserecv_ops",
55145539
prefix = "fuserecv_ops",

0 commit comments

Comments
 (0)