Skip to content

Commit 772ca11

Browse files
lixy9474liutongxuan
authored andcommitted
[Embedding] Change init value of version in EV from 0 to -1.
Signed-off-by: lixy9474 <lxy268263@alibaba-inc.com>
1 parent dfc8b70 commit 772ca11

File tree

4 files changed

+62
-10
lines changed

4 files changed

+62
-10
lines changed

tensorflow/core/framework/embedding/value_ptr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ struct NormalHeader {
114114
memset(this, 0, sizeof(NormalHeader));
115115
meta.SetLayoutType(LayoutType::NORMAL);
116116
meta.SetHeaderSize(sizeof(NormalHeader) / sizeof(int64));
117+
SetGlobalStep(-1);
117118
}
118119

119120
inline int64 GetGlobalStep() {
@@ -157,6 +158,7 @@ struct FixedLengthHeader {
157158

158159
FixedLengthHeader() {
159160
memset(this, 0, sizeof(FixedLengthHeader));
161+
SetGlobalStep(-1);
160162
}
161163

162164
inline int64 GetGlobalStep() {

tensorflow/core/kernels/embedding_variable_ops_test.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,6 +1782,54 @@ TEST(KVInterfaceTest, TestDirectIoFile) {
17821782
TestReadEmbFile();
17831783
}
17841784

1785+
1786+
void InsertKey(EmbeddingVar<int64, float>* variable, int value_size) {
1787+
float *val = (float *)malloc((value_size+1)*sizeof(float));
1788+
for (int64 i = 0; i < 100000000; i++) {
1789+
variable->LookupOrCreate(20, val, nullptr);
1790+
}
1791+
LOG(INFO)<<"Finish Insert";
1792+
}
1793+
1794+
void RemoveKey(EmbeddingVar<int64, float>* variable) {
1795+
for (int64 i = 0; i < 10; i++) {
1796+
sleep(1);
1797+
variable->storage_manager()->Remove(20);
1798+
}
1799+
LOG(INFO)<<"Remove thread finish";
1800+
}
1801+
1802+
TEST(EmbeddingVariableTest, TestLookupRemoveConcurrency) {
1803+
int value_size = 10;
1804+
Tensor value(DT_FLOAT, TensorShape({value_size}));
1805+
test::FillValues<float>(&value, std::vector<float>(value_size, 10.0));
1806+
auto emb_config = EmbeddingConfig(
1807+
/*emb_index = */0, /*primary_emb_index = */0,
1808+
/*block_num = */1, /*slot_num = */0,
1809+
/*name = */"", /*steps_to_live = */0,
1810+
/*filter_freq = */2, /*max_freq = */999999,
1811+
/*l2_weight_threshold = */-1.0, /*layout = */"normal",
1812+
/*max_element_size = */0, /*false_positive_probability = */-1.0,
1813+
/*counter_type = */DT_UINT64);
1814+
auto storage_manager = new embedding::StorageManager<int64, float>(
1815+
"EmbeddingVar", embedding::StorageConfig());
1816+
auto var = new EmbeddingVar<int64, float>("EmbeddingVar",
1817+
storage_manager,
1818+
emb_config,
1819+
cpu_allocator());
1820+
1821+
var->Init(value, 1);
1822+
int thread_num = 5;
1823+
std::vector<std::thread> insert_threads(thread_num);
1824+
for (size_t i = 0 ; i < thread_num - 1; i++) {
1825+
insert_threads[i] = std::thread(InsertKey, var, value_size);
1826+
}
1827+
insert_threads[thread_num - 1] = std::thread(RemoveKey, var);
1828+
for (auto &t : insert_threads) {
1829+
t.join();
1830+
}
1831+
}
1832+
17851833
} // namespace
17861834
} // namespace embedding
17871835
} // namespace tensorflow

tensorflow/python/ops/embedding_variable_ops_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def testEmbeddingVariableForExport(self):
239239
[1., 1., 1.],
240240
[1., 1., 1.],
241241
[1., 1., 1.]], fetches[1])
242-
self.assertAllEqual([0, 0, 0, 0, 0, 0], fetches[2])
242+
self.assertAllEqual([-1, -1, -1, -1, -1, -1], fetches[2])
243243
self.assertAllEqual([1, 1, 1, 1, 1, 1], fetches[3])
244244

245245
def testEmbeddingVariableForGetShape(self):
@@ -2081,7 +2081,7 @@ def testEmbeddingVariableForGetFrequencyAndVersion(self):
20812081
s, f, v = sess.run([shape, frequency, version])
20822082
self.assertAllEqual(np.array([5,3]), s)
20832083
self.assertAllEqual(np.array([3,1,2,0,2,0,1]), f)
2084-
self.assertAllEqual(np.array([2,0,1,0,2,0,2]), v)
2084+
self.assertAllEqual(np.array([2,0,1,-1,2,-1,2]), v)
20852085

20862086
def testEmbeddingVariableForInference(self):
20872087
print("testEmbeddingVariableForInference")

third_party/sparsehash_c11.patch

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
From dcd5910ddff22b2896dc5221eab766a6c6d2307d Mon Sep 17 00:00:00 2001
1+
From 2d87035732d9ae347515b8497ca7a75d3b36f8cf Mon Sep 17 00:00:00 2001
22
From: Tongxuan Liu <tongxuan.ltx@alibaba-inc.com>
33
Date: Mon, 13 Mar 2023 08:58:56 +0800
44
Subject: [PATCH] Avoid fetching nullptr when use feature filter.
@@ -8,14 +8,14 @@ Subject: [PATCH] Avoid fetching nullptr when use feature filter.
88
sparsehash/dense_hash_map_lockless | 447 ++++
99
sparsehash/dense_hash_set_lockless | 381 +++
1010
sparsehash/internal/densehashtable.h | 16 +-
11-
sparsehash/internal/densehashtable_lockless.h | 2033 +++++++++++++++++
11+
sparsehash/internal/densehashtable_lockless.h | 2035 +++++++++++++++++
1212
sparsehash/internal/hashtable-common.h | 4 +
1313
sparsehash/internal/sparsehashtable.h | 18 +-
1414
sparsehash/traits | 10 +-
1515
tests/bench_lockless.cc | 1466 ++++++++++++
1616
tests/dense_hash_map_unittests.cc | 137 +-
1717
tests/rwlock.h | 224 ++
18-
11 files changed, 4726 insertions(+), 23 deletions(-)
18+
11 files changed, 4728 insertions(+), 23 deletions(-)
1919
create mode 100644 sparsehash/dense_hash_map_lockless
2020
create mode 100644 sparsehash/dense_hash_set_lockless
2121
create mode 100644 sparsehash/internal/densehashtable_lockless.h
@@ -958,10 +958,10 @@ index e254126..3bc3c16 100644
958958
for (; dist > 0; --dist, ++f) {
959959
diff --git a/sparsehash/internal/densehashtable_lockless.h b/sparsehash/internal/densehashtable_lockless.h
960960
new file mode 100644
961-
index 0000000..64f677f
961+
index 0000000..2f8a80b
962962
--- /dev/null
963963
+++ b/sparsehash/internal/densehashtable_lockless.h
964-
@@ -0,0 +1,2033 @@
964+
@@ -0,0 +1,2035 @@
965965
+// Copyright (c) 2005, Google Inc.
966966
+// All rights reserved.
967967
+//
@@ -2204,9 +2204,11 @@ index 0000000..64f677f
22042204
+ }else if(test_deleted(bucknum, tmp_pointer)) {
22052205
+ if(insert_pos == ILLEGAL_BUCKET) insert_pos = bucknum;
22062206
+ }else if (equals(key, get_key(tmp_pointer->table_[bucknum]))) {
2207-
+ std::pair<K, T> tmp(tmp_pointer->table_[bucknum].first,
2208-
+ tmp_pointer->table_[bucknum].second);
2209-
+ if(tmp.first == key){
2207+
+ //Force to read from volatile memory
2208+
+ volatile K* key_ptr = const_cast<K*>(&(tmp_pointer->table_[bucknum].first));
2209+
+ volatile T* value_ptr = const_cast<T*>(&(tmp_pointer->table_[bucknum].second));
2210+
+ std::pair<K, T> tmp(*key_ptr, *value_ptr);
2211+
+ if(*key_ptr == key){
22102212
+ return tmp;
22112213
+ }else{
22122214
+ return std::pair<K,T>(tmp_pointer->key_info_.empty_key, empty_value);

0 commit comments

Comments
 (0)