Skip to content

Commit bf9d14e

Browse files
committed
[Embedding] Asynchronous restore EmbeddingVariable from checkpoint.
1 parent d09d989 commit bf9d14e

File tree

2 files changed

+125
-89
lines changed

2 files changed

+125
-89
lines changed

tensorflow/core/kernels/kv_variable_ops.cc

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ limitations under the License.
3636
#include "tensorflow/core/platform/mem.h"
3737
#include "tensorflow/core/platform/mutex.h"
3838
#include "tensorflow/core/platform/types.h"
39+
#include "tensorflow/core/util/env_var.h"
3940
#include "tensorflow/core/util/util.h"
4041
#include "tensorflow/core/util/work_sharder.h"
4142
#if GOOGLE_CUDA
@@ -126,36 +127,21 @@ class InitializeKvVariableOp : public OpKernel {
126127
OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
127128
OP_REQUIRES(c, shape_.dims() == 1,
128129
errors::InvalidArgument("KvVariable dimension must be 1"));
129-
130-
// get ev emb_index
131130
OP_REQUIRES_OK(c, c->GetAttr("emb_index", &emb_index_));
132-
// get ev block_num
133131
OP_REQUIRES_OK(c, c->GetAttr("block_num", &block_num_));
134-
// get ev slot_index
135132
OP_REQUIRES_OK(c, c->GetAttr("slot_index", &slot_index_));
136-
137133
OP_REQUIRES_OK(c, c->GetAttr("steps_to_live", &steps_to_live_));
138-
139134
OP_REQUIRES_OK(c, c->GetAttr("filter_freq", &filter_freq_));
140-
141135
OP_REQUIRES_OK(c, c->GetAttr("max_freq", &max_freq_));
142-
143136
OP_REQUIRES_OK(c, c->GetAttr("max_element_size", &max_element_size_));
144-
145137
OP_REQUIRES_OK(c, c->GetAttr("false_positive_probability",
146138
&false_positive_probability_));
147-
148139
OP_REQUIRES_OK(c, c->GetAttr("l2_weight_threshold",
149140
&l2_weight_threshold_));
150-
151141
OP_REQUIRES_OK(c, c->GetAttr("layout", &layout_));
152-
153142
OP_REQUIRES_OK(c, c->GetAttr("default_value_dim", &default_value_dim_));
154-
155143
OP_REQUIRES_OK(c, c->GetAttr("slot_num", &slot_num_));
156-
157144
OP_REQUIRES_OK(c, c->GetAttr("record_freq", &record_freq_));
158-
159145
OP_REQUIRES_OK(c, c->GetAttr("record_version", &record_version_));
160146

161147
int64 storage_type = 0;
@@ -592,10 +578,34 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS_ALL_INDEX);
592578
#undef REGISTER_KERNELS_ALL_INDEX
593579
#undef REGISTER_KERNELS
594580
*/
581+
582+
constexpr int64 DEFAULT_RESTORE_THREAD_NUM = 4;
583+
584+
class KvRestoreThreadPool {
585+
public:
586+
KvRestoreThreadPool() {
587+
TF_CHECK_OK(ReadInt64FromEnvVar("TF_EV_RESTORE_THREAD_NUM",
588+
DEFAULT_RESTORE_THREAD_NUM, &thread_num_));
589+
}
590+
591+
static thread::ThreadPool* GetInstance() {
592+
static thread::ThreadPool tp(Env::Default(),
593+
"restore_ev_threadpool", thread_num_);
594+
return &tp;
595+
}
596+
597+
private:
598+
static int64 thread_num_;
599+
};
600+
601+
int64 KvRestoreThreadPool::thread_num_ =
602+
DEFAULT_RESTORE_THREAD_NUM;
603+
595604
template <typename TKey, typename TValue>
596-
class KvResourceImportV2Op: public OpKernel {
605+
class KvResourceImportV2Op: public AsyncOpKernel {
597606
public:
598-
explicit KvResourceImportV2Op(OpKernelConstruction* c) : OpKernel(c) {
607+
explicit KvResourceImportV2Op(OpKernelConstruction* c)
608+
: AsyncOpKernel(c) {
599609
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
600610
OP_REQUIRES_OK(c, c->GetAttr("counter_type", &counter_type_));
601611
OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
@@ -619,15 +629,11 @@ class KvResourceImportV2Op: public OpKernel {
619629
//OP_REQUIRES_OK(c, c->GetAttr("restore_versions", &restore_versions_));
620630
OP_REQUIRES_OK(c, c->GetAttr("ht_type", &ht_type_));
621631
OP_REQUIRES_OK(c, c->GetAttr("ht_partition_num", &ht_partition_num_));
622-
// get ev emb_index
623632
OP_REQUIRES_OK(c, c->GetAttr("emb_index", &emb_index_));
624-
// get ev slot_index
625633
OP_REQUIRES_OK(c, c->GetAttr("slot_index", &slot_index_));
626634
OP_REQUIRES_OK(c, c->GetAttr("filter_freq", &filter_freq_));
627635
OP_REQUIRES_OK(c, c->GetAttr("block_num", &block_num_));
628-
629636
OP_REQUIRES_OK(c, c->GetAttr("max_element_size", &max_element_size_));
630-
631637
OP_REQUIRES_OK(c, c->GetAttr("false_positive_probability",
632638
&false_positive_probability_));
633639
OP_REQUIRES_OK(c, c->GetAttr("l2_weight_threshold",
@@ -636,27 +642,26 @@ class KvResourceImportV2Op: public OpKernel {
636642
OP_REQUIRES_OK(c, c->GetAttr("max_freq", &max_freq_));
637643
OP_REQUIRES_OK(c, c->GetAttr("default_value_dim",
638644
&default_value_dim_));
639-
640645
OP_REQUIRES_OK(c, c->GetAttr("slot_num", &slot_num_));
641-
642646
int64 storage_type = 0;
643647
OP_REQUIRES_OK(c, c->GetAttr("storage_type", &storage_type));
644648
storage_type_ = static_cast<embedding::StorageType>(storage_type);
645649

646650
OP_REQUIRES_OK(c, c->GetAttr("storage_path", &storage_path_));
647651
OP_REQUIRES_OK(c, c->GetAttr("storage_size", &storage_size_));
648-
649652
OP_REQUIRES_OK(c, c->GetAttr("record_freq", &record_freq_));
650-
651653
OP_REQUIRES_OK(c, c->GetAttr("record_version", &record_version_));
654+
655+
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EV_ASYNC_RESTORE", true,
656+
&ev_async_restore_));
652657
}
653658

654-
void Compute(OpKernelContext* context) override {
659+
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
655660
const Tensor& file_name = context->input(0);
656661
const std::string file_name_string = file_name.scalar<string>()();
657662
const Tensor& name = context->input(4);
658663
const std::string name_string = name.scalar<string>()();
659-
const Tensor& default_values = context->input(3);
664+
const Tensor& default_values = context->input(3);
660665
OP_REQUIRES(context, dtype_ == default_values.dtype(),
661666
errors::InvalidArgument(
662667
"Variable and ddd value dtypes don't match; respectively, ",
@@ -744,15 +749,29 @@ class KvResourceImportV2Op: public OpKernel {
744749
}
745750
core::ScopedUnref unref_me(ev);
746751

747-
BundleReader reader(Env::Default(), file_name_string);
748-
OP_REQUIRES_OK(context, reader.status());
752+
auto do_compute = [this, context, file_name_string, ev,
753+
name_string, done] () {
754+
BundleReader reader(Env::Default(), file_name_string);
755+
auto s = reader.status();
756+
if (!s.ok()) {
757+
LOG(FATAL) << "Restore EV failure, create BundleReader error:"
758+
<< s.ToString();
759+
}
749760

750-
EVRestoreDynamically(
751-
ev, name_string, partition_id_, partition_num_, context, &reader,
752-
"-partition_offset", "-keys", "-values", "-versions", "-freqs");
753-
ev->SetInitialized();
754-
}
761+
EVRestoreDynamically(
762+
ev, name_string, partition_id_, partition_num_, context, &reader,
763+
"-partition_offset", "-keys", "-values", "-versions", "-freqs");
764+
ev->SetInitialized();
765+
done();
766+
};
755767

768+
if (ev_async_restore_) {
769+
auto tp = KvRestoreThreadPool::GetInstance();
770+
tp->Schedule(do_compute);
771+
} else {
772+
do_compute();
773+
}
774+
}
756775

757776
private:
758777
int64 partition_id_;
@@ -780,6 +799,7 @@ class KvResourceImportV2Op: public OpKernel {
780799
int64 default_value_dim_;
781800
bool record_freq_;
782801
bool record_version_;
802+
bool ev_async_restore_;
783803
};
784804

785805
#define REGISTER_KERNELS(ktype, vtype) \
@@ -798,9 +818,10 @@ TF_CALL_double(REGISTER_KERNELS_ALL_INDEX);
798818
#undef REGISTER_KERNELS
799819

800820
template <typename TKey, typename TValue>
801-
class KvResourceIncrImportOp: public OpKernel {
821+
class KvResourceIncrImportOp: public AsyncOpKernel {
802822
public:
803-
explicit KvResourceIncrImportOp(OpKernelConstruction* c) : OpKernel(c) {
823+
explicit KvResourceIncrImportOp(OpKernelConstruction* c)
824+
: AsyncOpKernel(c) {
804825
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
805826

806827
OP_REQUIRES_OK(c, c->GetAttr("partition_id", &partition_id_));
@@ -814,7 +835,7 @@ class KvResourceIncrImportOp: public OpKernel {
814835

815836
}
816837

817-
void Compute(OpKernelContext* context) override {
838+
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
818839
const Tensor& file_name = context->input(0);
819840
const std::string file_name_string = file_name.scalar<string>()();
820841
const Tensor& name = context->input(2);
@@ -833,11 +854,13 @@ class KvResourceIncrImportOp: public OpKernel {
833854
<< name_string
834855
<< "partition_num:"
835856
<< partition_num_;
857+
836858
EVRestoreDynamically(
837859
ev, name_string, partition_id_, partition_num_, context, &reader,
838860
"-incr_partition_offset", "-sparse_incr_keys", "-sparse_incr_values",
839861
"-sparse_incr_versions", "-sparse_incr_freqs");
840862
ev->SetInitialized();
863+
done();
841864
}
842865

843866
private:

tensorflow/core/kernels/kv_variable_ops.h

Lines changed: 65 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ class EVValueDumpIterator: public DumpIterator<T> {
9898
int64 col_idx_;
9999
};
100100

101-
102101
template<class T>
103102
class EVVersionDumpIterator: public DumpIterator<T> {
104103
public:
@@ -399,11 +398,14 @@ Status DumpEmbeddingValues(EmbeddingVar<K, V>* ev,
399398
return Status::OK();
400399
}
401400

401+
namespace {
402+
const static string part_str = "part_";
403+
}
404+
402405
template<typename K, typename V>
403406
Status DynamicRestoreValue(EmbeddingVar<K, V>* ev, BundleReader* reader,
404407
std::string name_string, int orig_partnum,
405408
int64 partition_id = 0, int64 partition_num = 1) {
406-
string part_str = "part_";
407409
string curr_partid_str = std::to_string(partition_id);
408410
bool filter_flag = true;
409411
bool restore_filter_flag = true;
@@ -522,9 +524,8 @@ Status DynamicRestoreValue(EmbeddingVar<K, V>* ev, BundleReader* reader,
522524
return Status::OK();
523525
}
524526

525-
526527
template<typename K, typename V>
527-
Status RestoreValue(EmbeddingVar<K, V>* ev, BundleReader* reader,
528+
Status EVRestoreNoPartition(EmbeddingVar<K, V>* ev, BundleReader* reader,
528529
std::string tensor_key, std::string tensor_value,
529530
std::string tensor_version, std::string tensor_freq) {
530531
TensorShape key_shape;
@@ -574,7 +575,6 @@ Status RestoreValue(EmbeddingVar<K, V>* ev, BundleReader* reader,
574575
}
575576
}
576577

577-
578578
bool filter_flag = true;
579579
bool restore_filter_flag = true;
580580
st = reader->LookupHeader(tensor_key,
@@ -701,31 +701,11 @@ Status RestoreValue(EmbeddingVar<K, V>* ev, BundleReader* reader,
701701
return Status::OK();
702702
}
703703

704-
template<typename K, typename V>
705-
Status EVRestoreDynamically(EmbeddingVar<K, V>* ev,
706-
std::string name_string, int partition_id, int partition_num,
707-
OpKernelContext* context, BundleReader* reader,
708-
std::string part_offset_tensor_suffix, std::string key_suffix,
709-
std::string value_suffix, std::string version_suffix,
710-
std::string freq_suffix) {
711-
712-
// first check whether there is partition
713-
string part_str = "part_";
714-
715-
if (name_string.find(part_str) == std::string::npos) {
716-
// no partition
717-
Status s = RestoreValue(ev, reader, name_string + key_suffix,
718-
name_string + value_suffix, name_string + version_suffix,
719-
name_string + freq_suffix);
720-
if (!s.ok()) {
721-
LOG(FATAL) << "EV restoring fail:" << s.ToString();
722-
}
723-
return s;
724-
}
725-
704+
inline bool IsOldCheckpoint(const std::string& name_string,
705+
const std::string& curr_partid_str, BundleReader* reader,
706+
const std::string& part_offset_tensor_suffix) {
726707
// then check whether checkpoint is in old form
727708
bool is_oldform = false;
728-
string curr_partid_str = std::to_string(partition_id);
729709

730710
string part_id = std::to_string(0);
731711
string pre_subname =
@@ -742,34 +722,69 @@ Status EVRestoreDynamically(EmbeddingVar<K, V>* ev,
742722
if (!form_st.ok()) {
743723
is_oldform = true;
744724
}
725+
return is_oldform;
726+
}
745727

746-
if (is_oldform) {
747-
// first get original partition number
748-
int orig_partnum = 0;
749-
for (; ; orig_partnum++) {
750-
string part_id = std::to_string(orig_partnum);
751-
string pre_subname = name_string.substr(0, name_string.find(part_str));
752-
string post_subname = name_string.substr(name_string.find(part_str)
753-
+ part_str.size() + curr_partid_str.size());
754-
string tensor_name = pre_subname + part_str + part_id + post_subname;
755-
756-
string tensor_key = tensor_name + key_suffix;
757-
TensorShape key_shape;
758-
Status st = reader->LookupTensorShape(tensor_key, &key_shape);
759-
if (!st.ok()) {
760-
break;
761-
}
728+
template<typename K, typename V>
729+
Status EVRestoreOldFromCheckpoint(EmbeddingVar<K, V>* ev,
730+
const std::string& name_string, const std::string& curr_partid_str,
731+
const std::string& key_suffix, int partition_id,
732+
BundleReader* reader, int partition_num) {
733+
// first get original partition number
734+
int orig_partnum = 0;
735+
for (; ; orig_partnum++) {
736+
string part_id = std::to_string(orig_partnum);
737+
string pre_subname = name_string.substr(0, name_string.find(part_str));
738+
string post_subname = name_string.substr(name_string.find(part_str)
739+
+ part_str.size() + curr_partid_str.size());
740+
string tensor_name = pre_subname + part_str + part_id + post_subname;
741+
742+
string tensor_key = tensor_name + key_suffix;
743+
TensorShape key_shape;
744+
Status st = reader->LookupTensorShape(tensor_key, &key_shape);
745+
if (!st.ok()) {
746+
break;
762747
}
748+
}
763749

764-
VLOG(1) << "old form, EV name:" << name_string
765-
<< ", partition_id:" << partition_id
766-
<< ", old partition_num:" << orig_partnum
767-
<< ", new partition num:" << partition_num;
768-
Status s = DynamicRestoreValue(ev, reader, name_string,
769-
orig_partnum, partition_id, partition_num);
750+
VLOG(1) << "old form, EV name:" << name_string
751+
<< ", partition_id:" << curr_partid_str
752+
<< ", old partition_num:" << orig_partnum
753+
<< ", new partition num:" << partition_num;
754+
Status s = DynamicRestoreValue(ev, reader, name_string,
755+
orig_partnum, partition_id, partition_num);
756+
if (!s.ok()) {
757+
LOG(FATAL) << "EV restoring fail:" << s.ToString();
758+
}
759+
}
760+
761+
template<typename K, typename V>
762+
Status EVRestoreDynamically(EmbeddingVar<K, V>* ev,
763+
const std::string& name_string, int partition_id,
764+
int partition_num, OpKernelContext* context,
765+
BundleReader* reader, const std::string& part_offset_tensor_suffix,
766+
const std::string& key_suffix, const std::string& value_suffix,
767+
const std::string& version_suffix, const std::string& freq_suffix) {
768+
769+
// first check whether there is partition
770+
if (name_string.find(part_str) == std::string::npos) {
771+
Status s = EVRestoreNoPartition(
772+
ev, reader, name_string + key_suffix,
773+
name_string + value_suffix, name_string + version_suffix,
774+
name_string + freq_suffix);
770775
if (!s.ok()) {
771776
LOG(FATAL) << "EV restoring fail:" << s.ToString();
772777
}
778+
return s;
779+
}
780+
781+
const string& curr_partid_str = std::to_string(partition_id);
782+
auto is_oldform = IsOldCheckpoint(name_string, curr_partid_str,
783+
reader, part_offset_tensor_suffix);
784+
785+
if (is_oldform) {
786+
EVRestoreOldFromCheckpoint(ev, name_string, curr_partid_str, key_suffix,
787+
partition_id, reader, partition_num);
773788
} else {
774789
// first find out which sub parts we should load
775790
bool filter_flag = true;
@@ -917,7 +932,6 @@ Status EVRestoreDynamically(EmbeddingVar<K, V>* ev,
917932
if (!st.ok()) {
918933
LOG(FATAL) << "EV restoring fail:" << st.ToString();
919934
}
920-
921935
Tensor part_offset_tensor;
922936
st = context->allocate_temp(part_offset_type,
923937
part_offset_shape, &part_offset_tensor);
@@ -930,7 +944,6 @@ Status EVRestoreDynamically(EmbeddingVar<K, V>* ev,
930944
if (!st.ok()) {
931945
LOG(FATAL) << "EV restoring fail:" << st.ToString();
932946
}
933-
934947
st = reader->Lookup(offset_tensor_name, &part_offset_tensor);
935948
if (!st.ok()) {
936949
LOG(FATAL) << "EV restoring fail:" << st.ToString();

0 commit comments

Comments
 (0)