@@ -36,6 +36,7 @@ limitations under the License.
3636#include " tensorflow/core/platform/mem.h"
3737#include " tensorflow/core/platform/mutex.h"
3838#include " tensorflow/core/platform/types.h"
39+ #include " tensorflow/core/util/env_var.h"
3940#include " tensorflow/core/util/util.h"
4041#include " tensorflow/core/util/work_sharder.h"
4142#if GOOGLE_CUDA
@@ -126,36 +127,21 @@ class InitializeKvVariableOp : public OpKernel {
126127 OP_REQUIRES_OK (c, c->GetAttr (" shape" , &shape_));
127128 OP_REQUIRES (c, shape_.dims () == 1 ,
128129 errors::InvalidArgument (" KvVariable dimension must be 1" ));
129-
130- // get ev emb_index
131130 OP_REQUIRES_OK (c, c->GetAttr (" emb_index" , &emb_index_));
132- // get ev block_num
133131 OP_REQUIRES_OK (c, c->GetAttr (" block_num" , &block_num_));
134- // get ev slot_index
135132 OP_REQUIRES_OK (c, c->GetAttr (" slot_index" , &slot_index_));
136-
137133 OP_REQUIRES_OK (c, c->GetAttr (" steps_to_live" , &steps_to_live_));
138-
139134 OP_REQUIRES_OK (c, c->GetAttr (" filter_freq" , &filter_freq_));
140-
141135 OP_REQUIRES_OK (c, c->GetAttr (" max_freq" , &max_freq_));
142-
143136 OP_REQUIRES_OK (c, c->GetAttr (" max_element_size" , &max_element_size_));
144-
145137 OP_REQUIRES_OK (c, c->GetAttr (" false_positive_probability" ,
146138 &false_positive_probability_));
147-
148139 OP_REQUIRES_OK (c, c->GetAttr (" l2_weight_threshold" ,
149140 &l2_weight_threshold_));
150-
151141 OP_REQUIRES_OK (c, c->GetAttr (" layout" , &layout_));
152-
153142 OP_REQUIRES_OK (c, c->GetAttr (" default_value_dim" , &default_value_dim_));
154-
155143 OP_REQUIRES_OK (c, c->GetAttr (" slot_num" , &slot_num_));
156-
157144 OP_REQUIRES_OK (c, c->GetAttr (" record_freq" , &record_freq_));
158-
159145 OP_REQUIRES_OK (c, c->GetAttr (" record_version" , &record_version_));
160146
161147 int64 storage_type = 0 ;
@@ -592,10 +578,34 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS_ALL_INDEX);
592578#undef REGISTER_KERNELS_ALL_INDEX
593579#undef REGISTER_KERNELS
594580*/
581+
582+ constexpr int64 DEFAULT_RESTORE_THREAD_NUM = 4 ;
583+
584+ class KvRestoreThreadPool {
585+ public:
586+ KvRestoreThreadPool () {
587+ TF_CHECK_OK (ReadInt64FromEnvVar (" TF_EV_RESTORE_THREAD_NUM" ,
588+ DEFAULT_RESTORE_THREAD_NUM, &thread_num_));
589+ }
590+
591+ static thread::ThreadPool* GetInstance () {
592+ static thread::ThreadPool tp (Env::Default (),
593+ " restore_ev_threadpool" , thread_num_);
594+ return &tp;
595+ }
596+
597+ private:
598+ static int64 thread_num_;
599+ };
600+
601+ int64 KvRestoreThreadPool::thread_num_ =
602+ DEFAULT_RESTORE_THREAD_NUM;
603+
595604template <typename TKey, typename TValue>
596- class KvResourceImportV2Op : public OpKernel {
605+ class KvResourceImportV2Op : public AsyncOpKernel {
597606 public:
598- explicit KvResourceImportV2Op (OpKernelConstruction* c) : OpKernel(c) {
607+ explicit KvResourceImportV2Op (OpKernelConstruction* c)
608+ : AsyncOpKernel(c) {
599609 OP_REQUIRES_OK (c, c->GetAttr (" dtype" , &dtype_));
600610 OP_REQUIRES_OK (c, c->GetAttr (" counter_type" , &counter_type_));
601611 OP_REQUIRES_OK (c, c->GetAttr (" shape" , &shape_));
@@ -619,15 +629,11 @@ class KvResourceImportV2Op: public OpKernel {
619629 // OP_REQUIRES_OK(c, c->GetAttr("restore_versions", &restore_versions_));
620630 OP_REQUIRES_OK (c, c->GetAttr (" ht_type" , &ht_type_));
621631 OP_REQUIRES_OK (c, c->GetAttr (" ht_partition_num" , &ht_partition_num_));
622- // get ev emb_index
623632 OP_REQUIRES_OK (c, c->GetAttr (" emb_index" , &emb_index_));
624- // get ev slot_index
625633 OP_REQUIRES_OK (c, c->GetAttr (" slot_index" , &slot_index_));
626634 OP_REQUIRES_OK (c, c->GetAttr (" filter_freq" , &filter_freq_));
627635 OP_REQUIRES_OK (c, c->GetAttr (" block_num" , &block_num_));
628-
629636 OP_REQUIRES_OK (c, c->GetAttr (" max_element_size" , &max_element_size_));
630-
631637 OP_REQUIRES_OK (c, c->GetAttr (" false_positive_probability" ,
632638 &false_positive_probability_));
633639 OP_REQUIRES_OK (c, c->GetAttr (" l2_weight_threshold" ,
@@ -636,27 +642,26 @@ class KvResourceImportV2Op: public OpKernel {
636642 OP_REQUIRES_OK (c, c->GetAttr (" max_freq" , &max_freq_));
637643 OP_REQUIRES_OK (c, c->GetAttr (" default_value_dim" ,
638644 &default_value_dim_));
639-
640645 OP_REQUIRES_OK (c, c->GetAttr (" slot_num" , &slot_num_));
641-
642646 int64 storage_type = 0 ;
643647 OP_REQUIRES_OK (c, c->GetAttr (" storage_type" , &storage_type));
644648 storage_type_ = static_cast <embedding::StorageType>(storage_type);
645649
646650 OP_REQUIRES_OK (c, c->GetAttr (" storage_path" , &storage_path_));
647651 OP_REQUIRES_OK (c, c->GetAttr (" storage_size" , &storage_size_));
648-
649652 OP_REQUIRES_OK (c, c->GetAttr (" record_freq" , &record_freq_));
650-
651653 OP_REQUIRES_OK (c, c->GetAttr (" record_version" , &record_version_));
654+
655+ TF_CHECK_OK (ReadBoolFromEnvVar (" TF_ENABLE_EV_ASYNC_RESTORE" , true ,
656+ &ev_async_restore_));
652657 }
653658
654- void Compute (OpKernelContext* context) override {
659+ void ComputeAsync (OpKernelContext* context, DoneCallback done ) override {
655660 const Tensor& file_name = context->input (0 );
656661 const std::string file_name_string = file_name.scalar <string>()();
657662 const Tensor& name = context->input (4 );
658663 const std::string name_string = name.scalar <string>()();
659- const Tensor& default_values = context->input (3 );
664+ const Tensor& default_values = context->input (3 );
660665 OP_REQUIRES (context, dtype_ == default_values.dtype (),
661666 errors::InvalidArgument (
662667 " Variable and ddd value dtypes don't match; respectively, " ,
@@ -744,15 +749,29 @@ class KvResourceImportV2Op: public OpKernel {
744749 }
745750 core::ScopedUnref unref_me (ev);
746751
747- BundleReader reader (Env::Default (), file_name_string);
748- OP_REQUIRES_OK (context, reader.status ());
752+ auto do_compute = [this , context, file_name_string, ev,
753+ name_string, done] () {
754+ BundleReader reader (Env::Default (), file_name_string);
755+ auto s = reader.status ();
756+ if (!s.ok ()) {
757+ LOG (FATAL) << " Restore EV failure, create BundleReader error:"
758+ << s.ToString ();
759+ }
749760
750- EVRestoreDynamically (
751- ev, name_string, partition_id_, partition_num_, context, &reader,
752- " -partition_offset" , " -keys" , " -values" , " -versions" , " -freqs" );
753- ev->SetInitialized ();
754- }
761+ EVRestoreDynamically (
762+ ev, name_string, partition_id_, partition_num_, context, &reader,
763+ " -partition_offset" , " -keys" , " -values" , " -versions" , " -freqs" );
764+ ev->SetInitialized ();
765+ done ();
766+ };
755767
768+ if (ev_async_restore_) {
769+ auto tp = KvRestoreThreadPool::GetInstance ();
770+ tp->Schedule (do_compute);
771+ } else {
772+ do_compute ();
773+ }
774+ }
756775
757776 private:
758777 int64 partition_id_;
@@ -780,6 +799,7 @@ class KvResourceImportV2Op: public OpKernel {
780799 int64 default_value_dim_;
781800 bool record_freq_;
782801 bool record_version_;
802+ bool ev_async_restore_;
783803};
784804
785805#define REGISTER_KERNELS (ktype, vtype ) \
@@ -798,9 +818,10 @@ TF_CALL_double(REGISTER_KERNELS_ALL_INDEX);
798818#undef REGISTER_KERNELS
799819
800820template <typename TKey, typename TValue>
801- class KvResourceIncrImportOp : public OpKernel {
821+ class KvResourceIncrImportOp : public AsyncOpKernel {
802822 public:
803- explicit KvResourceIncrImportOp (OpKernelConstruction* c) : OpKernel(c) {
823+ explicit KvResourceIncrImportOp (OpKernelConstruction* c)
824+ : AsyncOpKernel(c) {
804825 OP_REQUIRES_OK (c, c->GetAttr (" dtype" , &dtype_));
805826
806827 OP_REQUIRES_OK (c, c->GetAttr (" partition_id" , &partition_id_));
@@ -814,7 +835,7 @@ class KvResourceIncrImportOp: public OpKernel {
814835
815836 }
816837
817- void Compute (OpKernelContext* context) override {
838+ void ComputeAsync (OpKernelContext* context, DoneCallback done ) override {
818839 const Tensor& file_name = context->input (0 );
819840 const std::string file_name_string = file_name.scalar <string>()();
820841 const Tensor& name = context->input (2 );
@@ -833,11 +854,13 @@ class KvResourceIncrImportOp: public OpKernel {
833854 << name_string
834855 << " partition_num:"
835856 << partition_num_;
857+
836858 EVRestoreDynamically (
837859 ev, name_string, partition_id_, partition_num_, context, &reader,
838860 " -incr_partition_offset" , " -sparse_incr_keys" , " -sparse_incr_values" ,
839861 " -sparse_incr_versions" , " -sparse_incr_freqs" );
840862 ev->SetInitialized ();
863+ done ();
841864 }
842865
843866 private:
0 commit comments