Skip to content

Commit 7ce8477

Browse files
authored
[Embedding] Refine KVInterface::GetShardedSnapshot API. (#953)
Signed-off-by: 泊霆 <hujunqi.hjq@alibaba-inc.com>
1 parent d814969 commit 7ce8477

File tree

10 files changed

+37
-24
lines changed

10 files changed

+37
-24
lines changed

tensorflow/core/framework/embedding/cpu_hash_map_kv.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ class LocklessHashMap : public KVInterface<K, V> {
138138
}
139139

140140
Status GetShardedSnapshot(
141-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
141+
std::vector<std::vector<K>>& key_list,
142+
std::vector<std::vector<void*>>& value_ptr_list,
142143
int partition_id, int partition_nums) override {
143144
std::pair<const K, void*> *hash_map_dump;
144145
int64 bucket_count;
@@ -147,11 +148,12 @@ class LocklessHashMap : public KVInterface<K, V> {
147148
bucket_count = it.second;
148149
for (int64 j = 0; j < bucket_count; j++) {
149150
if (hash_map_dump[j].first != LocklessHashMap<K, V>::EMPTY_KEY_
150-
&& hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_
151-
&& hash_map_dump[j].first % kSavedPartitionNum
152-
% partition_nums != partition_id) {
153-
key_list->emplace_back(hash_map_dump[j].first);
154-
value_ptr_list->emplace_back(hash_map_dump[j].second);
151+
&& hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_) {
152+
int part_id = hash_map_dump[j].first % kSavedPartitionNum % partition_nums;
153+
if (part_id != partition_id) {
154+
key_list[part_id].emplace_back(hash_map_dump[j].first);
155+
value_ptr_list[part_id].emplace_back(hash_map_dump[j].second);
156+
}
155157
}
156158
}
157159

tensorflow/core/framework/embedding/dense_hash_map_kv.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ class DenseHashMap : public KVInterface<K, V> {
122122
}
123123

124124
Status GetShardedSnapshot(
125-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
125+
std::vector<std::vector<K>>& key_list,
126+
std::vector<std::vector<void*>>& value_ptr_list,
126127
int partition_id, int partition_nums) override {
127128
dense_hash_map hash_map_dump[partition_num_];
128129
for (int i = 0; i< partition_num_; i++) {
@@ -131,9 +132,10 @@ class DenseHashMap : public KVInterface<K, V> {
131132
}
132133
for (int i = 0; i< partition_num_; i++) {
133134
for (const auto it : hash_map_dump[i].hash_map) {
134-
if (it.first % kSavedPartitionNum % partition_nums != partition_id) {
135-
key_list->push_back(it.first);
136-
value_ptr_list->push_back(it.second);
135+
int part_id = it.first % kSavedPartitionNum % partition_nums;
136+
if (part_id != partition_id) {
137+
key_list[part_id].emplace_back(it.first);
138+
value_ptr_list[part_id].emplace_back(it.second);
137139
}
138140
}
139141
}

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,8 +520,8 @@ class EmbeddingVar : public ResourceBase {
520520
}
521521
}
522522

523-
Status GetShardedSnapshot(std::vector<K>* key_list,
524-
std::vector<void*>* value_ptr_list,
523+
Status GetShardedSnapshot(std::vector<std::vector<K>>& key_list,
524+
std::vector<std::vector<void*>>& value_ptr_list,
525525
int partition_id, int partition_num) {
526526
return storage_->GetShardedSnapshot(key_list, value_ptr_list,
527527
partition_id, partition_num);
@@ -546,7 +546,7 @@ class EmbeddingVar : public ResourceBase {
546546
bool is_admit = feat_desc_->IsAdmit(value_ptr);
547547
bool is_in_dram = ((int64)value_ptr >> kDramFlagOffset == 0);
548548

549-
if (!is_admit) {
549+
if (is_admit) {
550550
key_list[i] = tot_keys_list[i];
551551

552552
if (!is_in_dram) {
@@ -571,7 +571,7 @@ class EmbeddingVar : public ResourceBase {
571571
}
572572
} else {
573573
if (!save_unfiltered_features)
574-
return;
574+
continue;
575575
//TODO(JUNQI) : currently not export filtered keys
576576
}
577577

@@ -584,6 +584,7 @@ class EmbeddingVar : public ResourceBase {
584584
feat_desc_->Deallocate(value_ptr);
585585
}
586586
}
587+
return;
587588
}
588589

589590
Status RestoreFromKeysAndValues(int64 key_num, int partition_id,

tensorflow/core/framework/embedding/gpu_hash_map_kv.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,8 @@ class GPUHashMapKV : public KVInterface<K, V> {
253253
}
254254

255255
Status GetShardedSnapshot(
256-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
256+
std::vector<std::vector<K>>& key_list,
257+
std::vector<std::vector<void*>>& value_ptr_list,
257258
int partition_id, int partition_nums) override {
258259
LOG(INFO) << "GPUHashMapKV do not support GetShardedSnapshot";
259260
return Status::OK();

tensorflow/core/framework/embedding/kv_interface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ class KVInterface {
9191
std::vector<void*>* value_ptr_list) = 0;
9292

9393
virtual Status GetShardedSnapshot(
94-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
94+
std::vector<std::vector<K>>& key_list,
95+
std::vector<std::vector<void*>>& value_ptr_list,
9596
int partition_id, int partition_nums) = 0;
9697

9798
virtual std::string DebugString() const = 0;

tensorflow/core/framework/embedding/leveldb_kv.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ class LevelDBKV : public KVInterface<K, V> {
194194
}
195195

196196
Status GetShardedSnapshot(
197-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
197+
std::vector<std::vector<K>>& key_list,
198+
std::vector<std::vector<void*>>& value_ptr_list,
198199
int partition_id, int partition_nums) override {
199200
ReadOptions options;
200201
options.snapshot = db_->GetSnapshot();
@@ -203,8 +204,9 @@ class LevelDBKV : public KVInterface<K, V> {
203204
for (it->SeekToFirst(); it->Valid(); it->Next()) {
204205
K key;
205206
memcpy((char*)&key, it->key().ToString().data(), sizeof(K));
206-
if (key % kSavedPartitionNum % partition_nums == partition_id) continue;
207-
key_list->emplace_back(key);
207+
int part_id = key % kSavedPartitionNum % partition_nums;
208+
if (part_id == partition_id) continue;
209+
key_list[part_id].emplace_back(key);
208210
FeatureDescriptor<V> hbm_feat_desc(
209211
1, 1, ev_allocator()/*useless*/,
210212
StorageType::HBM_DRAM, true, true,
@@ -218,7 +220,7 @@ class LevelDBKV : public KVInterface<K, V> {
218220
value_ptr, feat_desc_->GetFreq(dram_value_ptr));
219221
hbm_feat_desc.UpdateVersion(
220222
value_ptr, feat_desc_->GetVersion(dram_value_ptr));
221-
value_ptr_list->emplace_back(value_ptr);
223+
value_ptr_list[part_id].emplace_back(value_ptr);
222224
}
223225
delete it;
224226
feat_desc_->Deallocate(dram_value_ptr);

tensorflow/core/framework/embedding/multi_tier_storage.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ class MultiTierStorage : public Storage<K, V> {
9191
}
9292

9393
Status GetShardedSnapshot(
94-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
94+
std::vector<std::vector<K>>& key_list,
95+
std::vector<std::vector<void*>>& value_ptr_list,
9596
int partition_id, int partition_nums) override {
9697
LOG(FATAL)<<"Can't get sharded snapshot of MultiTierStorage.";
9798
return Status::OK();

tensorflow/core/framework/embedding/single_tier_storage.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ class SingleTierStorage : public Storage<K, V> {
224224
}
225225

226226
Status GetShardedSnapshot(
227-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
227+
std::vector<std::vector<K>>& key_list,
228+
std::vector<std::vector<void*>>& value_ptr_list,
228229
int partition_id, int partition_nums) override {
229230
mutex_lock l(Storage<K, V>::mu_);
230231
return kv_->GetShardedSnapshot(key_list, value_ptr_list,

tensorflow/core/framework/embedding/ssd_hash_kv.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ class SSDHashKV : public KVInterface<K, V> {
350350
}
351351

352352
Status GetShardedSnapshot(
353-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
353+
std::vector<std::vector<K>>& key_list,
354+
std::vector<std::vector<void*>>& value_ptr_list,
354355
int partition_id, int partition_nums) override {
355356
return Status::OK();
356357
}

tensorflow/core/framework/embedding/storage.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ class Storage {
9696
virtual Status GetSnapshot(std::vector<K>* key_list,
9797
std::vector<void*>* value_ptr_list) = 0;
9898
virtual Status GetShardedSnapshot(
99-
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
99+
std::vector<std::vector<K>>& key_list,
100+
std::vector<std::vector<void*>>& value_ptr_list,
100101
int partition_id, int partition_nums) = 0;
101102
virtual Status Save(
102103
const string& tensor_name,

0 commit comments

Comments
 (0)