@@ -34,11 +34,13 @@ limitations under the License.
3434#include " tensorflow/core/platform/logging.h"
3535#include " tensorflow/core/platform/mutex.h"
3636#include " tensorflow/core/platform/types.h"
37+ #include " tensorflow/core/util/env_var.h"
3738
3839namespace tensorflow {
3940
4041namespace {
4142 uint64 kGlobalStepId = 0x100000000000000uLL;
43+ int64 kFlowControlMaxSize = 16 ;
4244} // namespace anonymous
4345
4446static void StartAbortRendevous (Rendezvous* rendez, const Status& s) {
@@ -127,6 +129,23 @@ void BaseRendezvousMgr::FuseRecvLocalAsync(
127129 rendez->FuseRecvLocalAsync (parsed_keys, std::move (done_cb));
128130}
129131
132+ void BaseRendezvousMgr::FlowControlRecvLocalAsync (int64 step_id,
133+ const StringPiece& tag, const Rendezvous::ParsedKey& parsed,
134+ Rendezvous::DoneCallback done) {
135+ auto rendez = FindOrCreate (step_id);
136+ using namespace std ::placeholders;
137+ Rendezvous::DoneCallback done_cb = std::bind (
138+ [rendez](Rendezvous::DoneCallback done,
139+ // Begin unbound arguments.
140+ const Status& s, const Rendezvous::Args& send_args,
141+ const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
142+ rendez->Unref ();
143+ done (s, send_args, recv_args, v, dead);
144+ },
145+ std::move (done), _1, _2, _3, _4, _5);
146+ rendez->FlowControlRecvLocalAsync (tag, parsed, std::move (done_cb));
147+ }
148+
130149void BaseRendezvousMgr::Cleanup (int64 step_id) {
131150 Rendezvous* rendez = nullptr ;
132151 {
@@ -174,7 +193,17 @@ BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
174193 : env_(env),
175194 step_id_ (step_id),
176195 local_(NewLocalRendezvous()),
177- session_(nullptr ) {}
196+ session_(nullptr ),
197+ flow_control_num_(0 ) {
198+ Status s = ReadInt64FromEnvVar (" REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE" ,
199+ kFlowControlMaxSize , &flow_control_max_size_);
200+ if (!s.ok ()) {
201+ LOG (ERROR) << " Read REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE env error: "
202+ << s.error_message ();
203+ }
204+ VLOG (2 ) << " BaseRemoteRendezvous set flow control max size: "
205+ << flow_control_max_size_;
206+ }
178207
179208BaseRemoteRendezvous::~BaseRemoteRendezvous () {
180209 CHECK (active_.empty ());
@@ -221,6 +250,16 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
221250 std::move (fuse_call.done ));
222251 }
223252
253+ std::vector<DeferredFlowControlCall> deferred_flow_control_calls;
254+ {
255+ mutex_lock l (mu_);
256+ std::swap (deferred_flow_control_calls, deferred_flow_control_calls_);
257+ }
258+ for (auto & fc_call : deferred_flow_control_calls) {
259+ FlowControlRecvLocalAsyncInternal (fc_call.tag , fc_call.parsed ,
260+ std::move (fc_call.done ));
261+ }
262+
224263 return Status::OK ();
225264}
226265
@@ -271,6 +310,43 @@ Status BaseRemoteRendezvous::Send(const ParsedKey& parsed,
271310 return local_->Send (parsed, args, val, mu, is_dead);
272311}
273312
313+ Status BaseRemoteRendezvous::FlowControlSend (const StringPiece& tag,
314+ const ParsedKey& parsed,
315+ const Args& args,
316+ const Tensor& val,
317+ const bool is_dead,
318+ const int64 timeout_millis) {
319+ VLOG (1 ) << " BaseRemoteRendezvous FlowControlSend " << this << " "
320+ << parsed.FullKey ();
321+ const std::string tag_string (tag.data (), tag.size ());
322+ {
323+ mutex_lock l (mu_);
324+ while (status_.ok () && flow_control_num_ >= flow_control_max_size_) {
325+ if (flow_control_cv_.wait_for (
326+ l, std::chrono::milliseconds (timeout_millis)) == \
327+ std::cv_status::timeout) {
328+ return errors::DeadlineExceeded (" FlowControlSend has timed out." );
329+ }
330+ }
331+
332+ if (!status_.ok ()) return status_;
333+ DCHECK (is_initialized_locked ());
334+ if (!IsLocalDevice (session_->worker_name , parsed.src_device )) {
335+ return errors::InvalidArgument (
336+ " Invalid rendezvous key (src): " , parsed.FullKey (), " @ " ,
337+ session_->worker_name );
338+ }
339+
340+ flow_control_num_++;
341+ if (flow_control_counters_.count (tag_string) == 0 ) {
342+ flow_control_counters_[tag_string] = 0 ;
343+ }
344+ flow_control_counters_[tag_string]++;
345+ }
346+ // Buffers "val" and "device_context" in local_.
347+ return local_->Send (parsed, args, val, is_dead);
348+ }
349+
274350Status BaseRemoteRendezvous::ValidateDevices (const ParsedKey& parsed,
275351 bool is_src) {
276352 // Cache session pointer to avoid repeatedly taking & releasing the lock
@@ -413,6 +489,63 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
413489 }
414490}
415491
492+ void BaseRemoteRendezvous::FlowControlRecvAsync (const StringPiece& tag,
493+ const ParsedKey& parsed,
494+ const Args& recv_args,
495+ DoneCallback done) {
496+ VLOG (1 ) << " RemoteRendezvous FlowControlRecvAsync " << this
497+ << " " << tag << " " << parsed.FullKey ();
498+
499+ Status s = ValidateDevices (parsed, false /* !is_src*/ );
500+ if (s.ok () && !is_initialized ()) {
501+ s.Update (errors::Internal (
502+ " FlowControlRecvAsync called when uninitialized (key:" ,
503+ parsed.FullKey (), " )." ));
504+ }
505+ if (!s.ok ()) {
506+ done (s, Args (), recv_args, Tensor (), false );
507+ return ;
508+ }
509+
510+ // Are src and dst in the same worker?
511+ if (IsSameWorker (parsed.src , parsed.dst )) {
512+ // Recv the tensor from local_.
513+ local_->RecvAsync (
514+ parsed, recv_args,
515+ [this , tag, parsed, done](
516+ const Status& status, const Rendezvous::Args& send_args,
517+ const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
518+ VLOG (2 ) << " RemoteRendezvous Finished Recv " << this << " "
519+ << parsed.FullKey ();
520+ Tensor* out = new Tensor;
521+ StatusCallback final_callback = [done, send_args, recv_args, out,
522+ is_dead](const Status& s) {
523+ done (s, send_args, recv_args, *out, is_dead);
524+ delete out;
525+ };
526+
527+ if (status.ok ()) {
528+ SameWorkerRecvDone (parsed, send_args, recv_args, in, out,
529+ std::move (final_callback));
530+ const std::string tag_string (tag.data (), tag.size ());
531+ {
532+ mutex_lock l (mu_);
533+ flow_control_num_--;
534+ DCHECK (flow_control_counters_.count (tag_string) != 0 );
535+ flow_control_counters_[tag_string]--;
536+ }
537+ flow_control_cv_.notify_one ();
538+ } else {
539+ final_callback (status);
540+ }
541+ });
542+ return ;
543+ } else {
544+ FlowControlRecvFromRemoteAsync (tag, parsed, recv_args, std::move (done));
545+ }
546+
547+ }
548+
416549void BaseRemoteRendezvous::RecvLocalAsync (const ParsedKey& parsed,
417550 DoneCallback done) {
418551 {
@@ -600,13 +733,71 @@ void BaseRemoteRendezvous::FuseRecvLocalAsyncInternal(
600733 }
601734}
602735
736+ void BaseRemoteRendezvous::FlowControlRecvLocalAsync (const StringPiece& tag,
737+ const ParsedKey& parsed,
738+ DoneCallback done) {
739+ {
740+ mutex_lock l (mu_);
741+ if (!is_initialized_locked ()) {
742+ // FlowControlRecvLocalAsync can be called (due to an incoming RecvTensor
743+ // RPC from a remote worker) before the RunStep (or PartialRunStep) RPC
744+ // from the master arrives. RecvLocalAsync thus buffers the arguments
745+ // until after the RemoteRendezvous is Initialize()'d, when it completes
746+ // the rendezvous logic. At some point after Initialize() is called, a
747+ // Tensor is produced locally that will then be sent in response to the
748+ // incoming RPC.
749+ DeferredFlowControlCall call (tag, parsed, std::move (done));
750+ deferred_flow_control_calls_.push_back (call);
751+ return ;
752+ }
753+ }
754+ FlowControlRecvLocalAsyncInternal (tag, parsed, std::move (done));
755+ }
756+
757+ void BaseRemoteRendezvous::FlowControlRecvLocalAsyncInternal (
758+ const StringPiece& tag, const ParsedKey& parsed, DoneCallback done) {
759+ Status s = ValidateDevices (parsed, true /* is_src */ );
760+ if (!s.ok ()) {
761+ done (s, Args (), Args (), Tensor (), false );
762+ return ;
763+ }
764+
765+ using namespace std ::placeholders;
766+ Rendezvous::DoneCallback done_cb = std::bind (
767+ [this , tag](Rendezvous::DoneCallback done,
768+ // Begin unbound arguments.
769+ const Status& s, const Rendezvous::Args& send_args,
770+ const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
771+ done (s, send_args, recv_args, v, dead);
772+ if (s.ok ()) {
773+ const std::string tag_string (tag.data (), tag.size ());
774+ {
775+ mutex_lock l (mu_);
776+ flow_control_num_--;
777+ DCHECK (flow_control_counters_.count (tag_string) != 0 );
778+ flow_control_counters_[tag_string]--;
779+ }
780+ flow_control_cv_.notify_one ();
781+ }
782+ },
783+ std::move (done), _1, _2, _3, _4, _5);
784+
785+ local_->RecvAsync (parsed, Args (), std::move (done_cb));
786+ }
787+
603788void BaseRemoteRendezvous::FuseRecvFromRemoteAsync (
604789 const std::vector<Rendezvous::ParsedKey>& parsed_keys,
605790 const Rendezvous::Args& args,
606791 FuseDoneCallback done) {
607792 CHECK (false ) << " FuseRecvFromRemoteAsync Unimplemented" ;
608793}
609794
795+ void BaseRemoteRendezvous::FlowControlRecvFromRemoteAsync (
796+ const StringPiece& tag, const Rendezvous::ParsedKey& parsed,
797+ const Rendezvous::Args& args, DoneCallback done) {
798+ CHECK (false ) << " FlowControlRecvFromRemoteAsync Unimplemented." ;
799+ }
800+
610801void BaseRemoteRendezvous::RecvAsync (const ParsedKey& parsed,
611802 const Rendezvous::Args& recv_args,
612803 RefDoneCallback done) {
@@ -636,6 +827,19 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
636827 }
637828}
638829
830+ int64 BaseRemoteRendezvous::GetAllFlowControlItemNum () {
831+ mutex_lock l (mu_);
832+ return flow_control_num_;
833+ }
834+
835+ int64 BaseRemoteRendezvous::GetFlowControlItemNum (StringPiece tag) {
836+ const std::string tag_string (tag.data (), tag.size ());
837+ mutex_lock l (mu_);
838+ if (flow_control_counters_.count (tag_string) == 0 )
839+ return 0 ;
840+ return flow_control_counters_[tag_string];
841+ }
842+
639843void BaseRemoteRendezvous::StartAbort (const Status& s) {
640844 CHECK (!s.ok ());
641845 // Use a "derived" status as the status for the rendezvous. Derived
@@ -656,7 +860,10 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) {
656860 }
657861 active_.clear ();
658862 }
863+ flow_control_num_ = 0 ;
864+ flow_control_counters_.clear ();
659865 }
866+ flow_control_cv_.notify_all ();
660867}
661868
662869void BaseRemoteRendezvous::RegisterCall (BaseRecvTensorCall* call,
@@ -707,4 +914,8 @@ BaseRemoteRendezvous::DeferredFuseCall::DeferredFuseCall(
707914 const std::vector<ParsedKey>& parsed_keys, FuseDoneCallback done)
708915 : parsed_keys(parsed_keys), done(std::move(done)) {}
709916
917+ BaseRemoteRendezvous::DeferredFlowControlCall::DeferredFlowControlCall (
918+ const StringPiece& tag, const ParsedKey& parsed, DoneCallback done)
919+ : tag(tag), parsed(parsed), done(std::move(done)) {}
920+
710921} // end namespace tensorflow
0 commit comments