@@ -102,45 +102,48 @@ void CheckpointLoader<K, V>::RestoreInternal(
102102 Tensor part_filter_offset_tensor;
103103 if (!restore_args_.m_is_oldform ) {
104104 /* ***** InitPartOffsetTensor ******/
105- TensorShape part_offset_shape, part_filter_offset_shape ;
106- DataType part_offset_type, part_filter_offset_type ;
105+ TensorShape part_offset_shape;
106+ DataType part_offset_type;
107107 string offset_tensor_name;
108108 if (!restore_args_.m_is_incr ) {
109109 offset_tensor_name = name_string + kPartOffsetTensorSuffsix ;
110110 } else {
111111 offset_tensor_name = name_string + kIncrPartOffsetTensorSuffsix ;
112112 }
113-
114- string offset_filter_tensor_name =
115- name_string + kPartFilterOffsetTensorSuffsix ;
113+
116114 Status s = reader_->LookupDtypeAndShape (
117115 offset_tensor_name, &part_offset_type, &part_offset_shape);
118116 if (!s.ok ()) {
119117 LOG (ERROR) << " EV restoring fail:" << s.error_message ();
120118 }
121- s = reader_->LookupDtypeAndShape (offset_filter_tensor_name,
122- &part_filter_offset_type,
123- &part_filter_offset_shape);
124- if (!s.ok ()) {
125- LOG (ERROR) << " EV restoring fail: " << s.error_message ();
126- }
127119 part_offset_tensor =
128120 Tensor (cpu_allocator (), part_offset_type, part_offset_shape);
129- part_filter_offset_tensor = Tensor (
130- cpu_allocator (), part_filter_offset_type, part_filter_offset_shape);
131121 s = reader_->Lookup (offset_tensor_name, &part_offset_tensor);
132122 if (!s.ok ()) {
133123 LOG (ERROR) << " EV restoring fail:" << s.error_message ();
134124 }
135125
136- s = reader_->Lookup (offset_filter_tensor_name,
137- &part_filter_offset_tensor);
138- if (!s.ok ()) {
139- LOG (ERROR) << " EV restoring fail: " << s.error_message ();
126+ if (restore_args_.m_has_filter ) {
127+ TensorShape part_filter_offset_shape;
128+ DataType part_filter_offset_type;
129+ string offset_filter_tensor_name =
130+ name_string + kPartFilterOffsetTensorSuffsix ;
131+ s = reader_->LookupDtypeAndShape (offset_filter_tensor_name,
132+ &part_filter_offset_type,
133+ &part_filter_offset_shape);
134+ if (!s.ok ()) {
135+ LOG (ERROR) << " EV restoring fail: " << s.error_message ();
136+ }
137+ part_filter_offset_tensor = \
138+ Tensor (cpu_allocator (), part_filter_offset_type,
139+ part_filter_offset_shape);
140+ s = reader_->Lookup (offset_filter_tensor_name,
141+ &part_filter_offset_tensor);
142+ if (!s.ok ()) {
143+ LOG (ERROR) << " EV restoring fail: " << s.error_message ();
144+ }
140145 }
141146 }
142- auto part_offset_flat = part_offset_tensor.flat <int32>();
143- auto part_filter_offset_flat = part_filter_offset_tensor.flat <int32>();
144147
145148 if (restore_args_.m_is_oldform ) {
146149 VLOG (1 ) << " old form, EV name:" << name_string
@@ -164,6 +167,7 @@ void CheckpointLoader<K, V>::RestoreInternal(
164167 VLOG (1 ) << " new form checkpoint... :" << name_string
165168 << " , partition_id:" << restore_args_.m_partition_id
166169 << " , partition_num:" << restore_args_.m_partition_num ;
170+ auto part_offset_flat = part_offset_tensor.flat <int32>();
167171 for (size_t i = 0 ; i < restore_args_.m_loaded_parts .size (); i++) {
168172 int subpart_id = restore_args_.m_loaded_parts [i];
169173 size_t value_unit_bytes = sizeof (V) * restore_args_.m_old_dim ;
@@ -183,6 +187,7 @@ void CheckpointLoader<K, V>::RestoreInternal(
183187 new_dim, emb_config, device);
184188
185189 if (restore_args_.m_has_filter ) {
190+ auto part_filter_offset_flat = part_filter_offset_tensor.flat <int32>();
186191 Status s = EVRestoreFilteredFeatures (
187192 subpart_id, new_dim, restore_buff, part_filter_offset_flat,
188193 emb_config, device);
@@ -444,7 +449,7 @@ Status CheckpointLoader<K, V>::EVInitTensorNameAndShape(
444449 }
445450 st = reader_->LookupHeader (restore_args_.m_tensor_version + " _filtered" ,
446451 sizeof (K) * version_filter_shape.dim_size (0 ));
447- if (!st.ok ()) {
452+ if (!st.ok () && st. code () != error::NOT_FOUND ) {
448453 return st;
449454 }
450455 st = reader_->LookupTensorShape (restore_args_.m_tensor_freq + " _filtered" ,
@@ -463,7 +468,8 @@ Status CheckpointLoader<K, V>::EVInitTensorNameAndShape(
463468 return st;
464469 }
465470 }
466- return st;
471+
472+ return Status::OK ();
467473}
468474#define REGISTER_KERNELS (ktype, vtype ) \
469475 template Status CheckpointLoader<ktype, vtype>::EVInitTensorNameAndShape(\
@@ -644,4 +650,4 @@ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX)
644650#undef REGISTER_KERNELS_ALL_INDEX
645651#undef REGISTER_KERNELS
646652
647- }// namespace tensorflow
653+ }// namespace tensorflow
0 commit comments