Skip to content

Commit 6bf5621

Browse files
authored
[Embedding] undefine EV GPU interface in CPU compile. (#956)
Signed-off-by: candy.dc <candy.dc@alibaba-inc.com>
1 parent 717f7c5 commit 6bf5621

File tree

1 file changed

+45
-46
lines changed

1 file changed

+45
-46
lines changed

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,6 @@ class EmbeddingVar : public ResourceBase {
140140
return storage_->Get(key, value_ptr);
141141
}
142142

143-
void BatchLookupKey(const EmbeddingVarContext<GPUDevice>& ctx,
144-
const K* keys,
145-
void** value_ptr_list,
146-
int64 num_of_keys) {
147-
storage_->BatchGet(ctx, keys, value_ptr_list, num_of_keys);
148-
}
149-
150143
Status LookupOrCreateKey(K key, void** value_ptr,
151144
bool* is_filter, bool indices_as_pointer,
152145
int64 count = 1) {
@@ -167,45 +160,6 @@ class EmbeddingVar : public ResourceBase {
167160
return Status::OK();
168161
}
169162

170-
// Batched lookup-or-create for the GPU path.
//
// Produces one entry of `value_ptrs` for each of the `num_of_keys` keys:
//   * indices_as_pointer == true : the key value itself is cast to void*
//     and stored (no storage/filter access);
//   * indices_as_pointer == false: the whole batch is delegated to
//     filter_->BatchLookupOrCreateKey().
// If `indices_counts` is non-null, feat_desc_->AddFreq() is then called for
// every entry with its matching count.
//
// Both per-key loops are sharded over context.worker_threads.
// Always returns Status::OK().
Status LookupOrCreateKey(const EmbeddingVarContext<GPUDevice>& context,
                         const K* keys,
                         void** value_ptrs,
                         int64 num_of_keys,
                         int64* indices_counts,
                         bool indices_as_pointer = false) {
  if (indices_as_pointer) {
    auto lookup_key_and_set_version_fn = [keys, value_ptrs]
        (int64 start, int64 limit) {
      // The index must be int64: the shard bounds are int64, and an `int`
      // index would be narrowed and could overflow on very large batches.
      for (int64 i = start; i < limit; ++i) {
        value_ptrs[i] = (void*)keys[i];
      }
    };
    const int64 unit_cost = 1000; //very unreliable estimate for cost per step.
    auto worker_threads = context.worker_threads;
    Shard(worker_threads->num_threads,
          worker_threads->workers, num_of_keys, unit_cost,
          lookup_key_and_set_version_fn);
  } else {
    filter_->BatchLookupOrCreateKey(context, keys, value_ptrs, num_of_keys);
  }

  if (indices_counts != nullptr) {
    auto add_freq_fn = [this, value_ptrs, indices_counts]
        (int64 start, int64 limit) {
      // int64 index for the same reason as above.
      for (int64 i = start; i < limit; ++i) {
        feat_desc_->AddFreq(value_ptrs[i], indices_counts[i]);
      }
    };
    const int64 unit_cost = 1000; //very unreliable estimate for cost per step.
    auto worker_threads = context.worker_threads;
    Shard(worker_threads->num_threads,
          worker_threads->workers, num_of_keys, unit_cost,
          add_freq_fn);
  }
  return Status::OK();
}
207-
208-
209163
Status LookupOrCreateKey(K key, void** value_ptr) {
210164
Status s = storage_->GetOrCreate(key, value_ptr);
211165
TF_CHECK_OK(s);
@@ -402,6 +356,51 @@ class EmbeddingVar : public ResourceBase {
402356

403357
storage_->AddToCache(keys_tensor);
404358
}
359+
360+
void BatchLookupKey(const EmbeddingVarContext<GPUDevice>& ctx,
361+
const K* keys,
362+
void** value_ptr_list,
363+
int64 num_of_keys) {
364+
storage_->BatchGet(ctx, keys, value_ptr_list, num_of_keys);
365+
}
366+
367+
// Batched lookup-or-create for the GPU path.
//
// Produces one entry of `value_ptrs` for each of the `num_of_keys` keys:
//   * indices_as_pointer == true : the key value itself is cast to void*
//     and stored (no storage/filter access);
//   * indices_as_pointer == false: the whole batch is delegated to
//     filter_->BatchLookupOrCreateKey().
// If `indices_counts` is non-null, feat_desc_->AddFreq() is then called for
// every entry with its matching count.
//
// Both per-key loops are sharded over context.worker_threads.
// Always returns Status::OK().
Status LookupOrCreateKey(const EmbeddingVarContext<GPUDevice>& context,
                         const K* keys,
                         void** value_ptrs,
                         int64 num_of_keys,
                         int64* indices_counts,
                         bool indices_as_pointer = false) {
  if (indices_as_pointer) {
    auto lookup_key_and_set_version_fn = [keys, value_ptrs]
        (int64 start, int64 limit) {
      // The index must be int64: the shard bounds are int64, and an `int`
      // index would be narrowed and could overflow on very large batches.
      for (int64 i = start; i < limit; ++i) {
        value_ptrs[i] = (void*)keys[i];
      }
    };
    const int64 unit_cost = 1000; //very unreliable estimate for cost per step.
    auto worker_threads = context.worker_threads;
    Shard(worker_threads->num_threads,
          worker_threads->workers, num_of_keys, unit_cost,
          lookup_key_and_set_version_fn);
  } else {
    filter_->BatchLookupOrCreateKey(context, keys, value_ptrs, num_of_keys);
  }

  if (indices_counts != nullptr) {
    auto add_freq_fn = [this, value_ptrs, indices_counts]
        (int64 start, int64 limit) {
      // int64 index for the same reason as above.
      for (int64 i = start; i < limit; ++i) {
        feat_desc_->AddFreq(value_ptrs[i], indices_counts[i]);
      }
    };
    const int64 unit_cost = 1000; //very unreliable estimate for cost per step.
    auto worker_threads = context.worker_threads;
    Shard(worker_threads->num_threads,
          worker_threads->workers, num_of_keys, unit_cost,
          add_freq_fn);
  }
  return Status::OK();
}
405404
#endif
406405

407406
#if GOOGLE_CUDA

0 commit comments

Comments
 (0)