@@ -52,6 +52,7 @@ using GPUDevice = Eigen::GpuDevice;
5252namespace {
5353const int64 kEmbeddingVarUseDB = -214 ;
5454const int64 kInitializableEmbeddingVarUseDB = -215 ;
55+ const char * kInferenceMode = " INFERENCE_MODE" ;
5556}
5657
5758#define REGISTER_KV_VAR_HANDLE (ktype, vtype ) \
@@ -370,6 +371,10 @@ template <typename TKey, typename TValue>
370371class KvResourceGatherOp : public OpKernel {
371372 public:
372373 explicit KvResourceGatherOp (OpKernelConstruction* c) : OpKernel(c) {
374+ OP_REQUIRES_OK (c, c->GetAttr (" is_inference" , &is_inference_));
375+ bool is_inference;
376+ TF_CHECK_OK (ReadBoolFromEnvVar (kInferenceMode , false , &is_inference));
377+ is_inference_ |= is_inference;
373378 OP_REQUIRES_OK (c,
374379 c->GetAttr (" is_use_default_value_tensor" ,
375380 &is_use_default_value_tensor_));
@@ -393,6 +398,17 @@ class KvResourceGatherOp : public OpKernel {
393398 return 1 ;
394399 };
395400 }
401+ if (!is_inference_) {
402+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
403+ TValue* val, TValue* default_v, int count) {
404+ ev->LookupOrCreate (key, val, default_v, count);
405+ };
406+ } else {
407+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
408+ TValue* val, TValue* default_v, int count) {
409+ ev->Lookup (key, val, default_v);
410+ };
411+ }
396412 }
397413
398414 void Compute (OpKernelContext* c) override {
@@ -443,7 +459,7 @@ class KvResourceGatherOp : public OpKernel {
443459 default_v, indices_flat (i), i, ev->GetDefaultValueDim (),
444460 ev->ValueLen ());
445461 int32 count = get_count_fn_ (counts, i);
446- ev-> LookupOrCreate ( indices_flat (i),
462+ lookup_fn_ (ev, indices_flat (i),
447463 out_base + i * slice_elems, default_v_ptr, count);
448464 }
449465 };
@@ -463,9 +479,12 @@ class KvResourceGatherOp : public OpKernel {
463479
464480 private:
465481 bool is_use_default_value_tensor_;
482+ bool is_inference_;
466483 std::function<
467484 TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
468485 std::function<int32(int32*, int64)> get_count_fn_;
486+ std::function<void (EmbeddingVar<TKey, TValue>* ev,
487+ TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
469488};
470489
471490#define REGISTER_GATHER_FULL (dev, ktype, vtype ) \
0 commit comments