@@ -140,13 +140,6 @@ class EmbeddingVar : public ResourceBase {
140140 return storage_->Get (key, value_ptr);
141141 }
142142
143- void BatchLookupKey (const EmbeddingVarContext<GPUDevice>& ctx,
144- const K* keys,
145- void ** value_ptr_list,
146- int64 num_of_keys) {
147- storage_->BatchGet (ctx, keys, value_ptr_list, num_of_keys);
148- }
149-
150143 Status LookupOrCreateKey (K key, void ** value_ptr,
151144 bool * is_filter, bool indices_as_pointer,
152145 int64 count = 1 ) {
@@ -167,45 +160,6 @@ class EmbeddingVar : public ResourceBase {
167160 return Status::OK ();
168161 }
169162
170- Status LookupOrCreateKey (const EmbeddingVarContext<GPUDevice>& context,
171- const K* keys,
172- void ** value_ptrs,
173- int64 num_of_keys,
174- int64* indices_counts,
175- bool indices_as_pointer = false ) {
176- if (indices_as_pointer) {
177- auto lookup_key_and_set_version_fn = [keys, value_ptrs]
178- (int64 start, int64 limit) {
179- for (int i = start; i < limit; i++) {
180- value_ptrs[i] = (void *)keys[i];
181- }
182- };
183- const int64 unit_cost = 1000 ; // very unreliable estimate for cost per step.
184- auto worker_threads = context.worker_threads ;
185- Shard (worker_threads->num_threads ,
186- worker_threads->workers , num_of_keys, unit_cost,
187- lookup_key_and_set_version_fn);
188- } else {
189- filter_->BatchLookupOrCreateKey (context, keys, value_ptrs, num_of_keys);
190- }
191-
192- if (indices_counts != nullptr ) {
193- auto add_freq_fn = [this , value_ptrs, indices_counts]
194- (int64 start, int64 limit) {
195- for (int i = start; i < limit; i++) {
196- feat_desc_->AddFreq (value_ptrs[i], indices_counts[i]);
197- }
198- };
199- const int64 unit_cost = 1000 ; // very unreliable estimate for cost per step.
200- auto worker_threads = context.worker_threads ;
201- Shard (worker_threads->num_threads ,
202- worker_threads->workers , num_of_keys, unit_cost,
203- add_freq_fn);
204- }
205- return Status::OK ();
206- }
207-
208-
209163 Status LookupOrCreateKey (K key, void ** value_ptr) {
210164 Status s = storage_->GetOrCreate (key, value_ptr);
211165 TF_CHECK_OK (s);
@@ -402,6 +356,51 @@ class EmbeddingVar : public ResourceBase {
402356
403357 storage_->AddToCache (keys_tensor);
404358 }
359+
360+ void BatchLookupKey (const EmbeddingVarContext<GPUDevice>& ctx,
361+ const K* keys,
362+ void ** value_ptr_list,
363+ int64 num_of_keys) {
364+ storage_->BatchGet (ctx, keys, value_ptr_list, num_of_keys);
365+ }
366+
367+ Status LookupOrCreateKey (const EmbeddingVarContext<GPUDevice>& context,
368+ const K* keys,
369+ void ** value_ptrs,
370+ int64 num_of_keys,
371+ int64* indices_counts,
372+ bool indices_as_pointer = false ) {
373+ if (indices_as_pointer) {
374+ auto lookup_key_and_set_version_fn = [keys, value_ptrs]
375+ (int64 start, int64 limit) {
376+ for (int i = start; i < limit; i++) {
377+ value_ptrs[i] = (void *)keys[i];
378+ }
379+ };
380+ const int64 unit_cost = 1000 ; // very unreliable estimate for cost per step.
381+ auto worker_threads = context.worker_threads ;
382+ Shard (worker_threads->num_threads ,
383+ worker_threads->workers , num_of_keys, unit_cost,
384+ lookup_key_and_set_version_fn);
385+ } else {
386+ filter_->BatchLookupOrCreateKey (context, keys, value_ptrs, num_of_keys);
387+ }
388+
389+ if (indices_counts != nullptr ) {
390+ auto add_freq_fn = [this , value_ptrs, indices_counts]
391+ (int64 start, int64 limit) {
392+ for (int i = start; i < limit; i++) {
393+ feat_desc_->AddFreq (value_ptrs[i], indices_counts[i]);
394+ }
395+ };
396+ const int64 unit_cost = 1000 ; // very unreliable estimate for cost per step.
397+ auto worker_threads = context.worker_threads ;
398+ Shard (worker_threads->num_threads ,
399+ worker_threads->workers , num_of_keys, unit_cost,
400+ add_freq_fn);
401+ }
402+ return Status::OK ();
403+ }
405404#endif
406405
407406#if GOOGLE_CUDA
0 commit comments