Skip to content

Commit d806b3f

Browse files
authored
[Embedding] Fix BatchCache coredump in background thread. (#294)
1 parent 4520669 commit d806b3f

File tree

5 files changed

+23 / -10 lines changed

5 files changed

+23 / -10 lines changed

tensorflow/core/framework/embedding/cache.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include <unordered_map>
66
#include <set>
77
#include <list>
8+
#include "tensorflow/core/framework/tensor.h"
89
#include "tensorflow/core/platform/types.h"
910
#include "tensorflow/core/platform/mutex.h"
1011
#include "tensorflow/core/lib/core/status.h"
@@ -17,6 +18,9 @@ template <class K>
1718
class BatchCache {
1819
public:
1920
BatchCache() {}
21+
void add_to_rank(const Tensor& t) {
22+
add_to_rank((K*)t.data(), t.NumElements());
23+
}
2024
virtual size_t get_evic_ids(K* evic_ids, size_t k_size) = 0;
2125
virtual void add_to_rank(const K* batch_ids, size_t batch_size) = 0;
2226
virtual size_t size() = 0;

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,10 @@ class EmbeddingVar : public ResourceBase {
163163
return storage_manager_->Size();
164164
}
165165

166+
int64 CacheSize() const {
167+
return storage_manager_->CacheSize();
168+
}
169+
166170
int64 MinFreq() {
167171
return emb_config_.filter_freq;
168172
}

tensorflow/core/framework/embedding/multilevel_embedding.h

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -267,6 +267,10 @@ class StorageManager {
267267
return total_size;
268268
}
269269

270+
int64 CacheSize() const {
271+
return cache_capacity_;
272+
}
273+
270274
Status GetSnapshot(std::vector<K>* key_list,
271275
std::vector<ValuePtr<V>* >* value_ptr_list) {
272276
for (auto kv : kvs_) {
@@ -375,7 +379,6 @@ class StorageManager {
375379
Status Destroy() {
376380
if (eviction_thread_) {
377381
mutex_lock l(mu_);
378-
shutdown_cv_.notify_all();
379382
shutdown_ = true;
380383
}
381384
delete eviction_thread_;
@@ -432,9 +435,7 @@ class StorageManager {
432435
if (shutdown_) {
433436
break;
434437
}
435-
const int kTimeoutMilliseconds = 1;
436-
WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
437-
438+
// add WaitForMilliseconds() for sleep if necessary
438439
for (int i = 0; i < value_ptr_out_of_date_.size(); i++) {
439440
value_ptr_out_of_date_[i]->Destroy(kvs_[0].second);
440441
delete value_ptr_out_of_date_[i];
@@ -478,10 +479,9 @@ class StorageManager {
478479
BatchCache<K>* cache_;
479480
int64 cache_capacity_;
480481
mutex mu_;
481-
condition_variable shutdown_cv_;
482-
bool shutdown_ GUARDED_BY(mu_) = false;
482+
volatile bool shutdown_ GUARDED_BY(mu_) = false;
483483

484-
bool done_ = false;
484+
volatile bool done_ = false;
485485
std::atomic_flag flag_ = ATOMIC_FLAG_INIT;
486486

487487
};

tensorflow/core/kernels/kv_variable_ops.cc

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -418,6 +418,10 @@ class KvResourceGatherOp : public OpKernel {
418418
errors::InvalidArgument(
419419
"ev's value_len should same with output's dimension(1)",
420420
std::to_string(slice_elems), std::to_string(ev->ValueLen())));
421+
OP_REQUIRES(c, !ev->IsMultiLevel() || (ev->IsMultiLevel() && ev->CacheSize() >= N),
422+
errors::InvalidArgument(
423+
"MultiLevel EV's Cache size ", ev->CacheSize(),
424+
" should large than IDs in batch ", N));
421425
const size_t slice_bytes = slice_elems * sizeof(TValue);
422426
auto do_work = [this, indices_flat,
423427
out_base, slice_elems, c, default_v, ev, counts] (
@@ -436,10 +440,10 @@ class KvResourceGatherOp : public OpKernel {
436440
worker_threads->workers, indices_size,
437441
slice_bytes, do_work);
438442

439-
ev->storage_manager()->Schedule([ev, indices_flat, indices_size]() {
443+
ev->storage_manager()->Schedule([ev, indices]() {
440444
embedding::BatchCache<TKey>* cache = ev->Cache();
441445
if (cache) {
442-
cache->add_to_rank(indices_flat.data(), indices_size);
446+
cache->add_to_rank(indices);
443447
}
444448
});
445449
}

tensorflow/python/ops/embedding_variable_ops_test.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1965,6 +1965,7 @@ def runTestAdagrad(self, var, g):
19651965
sess.run([init])
19661966
for i in xrange(60):
19671967
r, _, _ = sess.run([emb, train_op, loss])
1968+
r = sess.run(emb)
19681969
return r
19691970

19701971
with ops.Graph().as_default() as g:
@@ -1976,7 +1977,7 @@ def runTestAdagrad(self, var, g):
19761977
steps_to_live=5,
19771978
ev_option = variables.EmbeddingVariableOption(storage_option=variables.StorageOption(storage_type=config_pb2.StorageType.DRAM_SSDHASH,
19781979
storage_path="/tmp/ssd_utpy",
1979-
storage_size=[512])))
1980+
storage_size=[5120])))
19801981
emb1 = runTestAdagrad(self, emb_var, g)
19811982

19821983
with ops.Graph().as_default() as g:

0 commit comments

Comments
 (0)