Skip to content

Commit cf16856

Browse files
authored
[Incremental Checkpoint] Fix import incremental embedding variable. (#983)
Signed-off-by: chenbangduo.cbd <chenbangduo.cbd@alibaba-inc.com>
1 parent 6dae552 commit cf16856

File tree

2 files changed

+82
-22
lines changed

2 files changed

+82
-22
lines changed

tensorflow/core/framework/embedding/embedding_var_restore.cc

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -102,45 +102,48 @@ void CheckpointLoader<K, V>::RestoreInternal(
102102
Tensor part_filter_offset_tensor;
103103
if (!restore_args_.m_is_oldform) {
104104
/****** InitPartOffsetTensor ******/
105-
TensorShape part_offset_shape, part_filter_offset_shape;
106-
DataType part_offset_type, part_filter_offset_type;
105+
TensorShape part_offset_shape;
106+
DataType part_offset_type;
107107
string offset_tensor_name;
108108
if (!restore_args_.m_is_incr) {
109109
offset_tensor_name = name_string + kPartOffsetTensorSuffsix;
110110
} else {
111111
offset_tensor_name = name_string + kIncrPartOffsetTensorSuffsix;
112112
}
113-
114-
string offset_filter_tensor_name =
115-
name_string + kPartFilterOffsetTensorSuffsix;
113+
116114
Status s = reader_->LookupDtypeAndShape(
117115
offset_tensor_name, &part_offset_type, &part_offset_shape);
118116
if (!s.ok()) {
119117
LOG(ERROR) << "EV restoring fail:" << s.error_message();
120118
}
121-
s = reader_->LookupDtypeAndShape(offset_filter_tensor_name,
122-
&part_filter_offset_type,
123-
&part_filter_offset_shape);
124-
if (!s.ok()) {
125-
LOG(ERROR) << "EV restoring fail: " << s.error_message();
126-
}
127119
part_offset_tensor =
128120
Tensor(cpu_allocator(), part_offset_type, part_offset_shape);
129-
part_filter_offset_tensor = Tensor(
130-
cpu_allocator(), part_filter_offset_type, part_filter_offset_shape);
131121
s = reader_->Lookup(offset_tensor_name, &part_offset_tensor);
132122
if (!s.ok()) {
133123
LOG(ERROR) << "EV restoring fail:" << s.error_message();
134124
}
135125

136-
s = reader_->Lookup(offset_filter_tensor_name,
137-
&part_filter_offset_tensor);
138-
if (!s.ok()) {
139-
LOG(ERROR) << "EV restoring fail: " << s.error_message();
126+
if (restore_args_.m_has_filter) {
127+
TensorShape part_filter_offset_shape;
128+
DataType part_filter_offset_type;
129+
string offset_filter_tensor_name =
130+
name_string + kPartFilterOffsetTensorSuffsix;
131+
s = reader_->LookupDtypeAndShape(offset_filter_tensor_name,
132+
&part_filter_offset_type,
133+
&part_filter_offset_shape);
134+
if (!s.ok()) {
135+
LOG(ERROR) << "EV restoring fail: " << s.error_message();
136+
}
137+
part_filter_offset_tensor = \
138+
Tensor(cpu_allocator(), part_filter_offset_type,
139+
part_filter_offset_shape);
140+
s = reader_->Lookup(offset_filter_tensor_name,
141+
&part_filter_offset_tensor);
142+
if (!s.ok()) {
143+
LOG(ERROR) << "EV restoring fail: " << s.error_message();
144+
}
140145
}
141146
}
142-
auto part_offset_flat = part_offset_tensor.flat<int32>();
143-
auto part_filter_offset_flat = part_filter_offset_tensor.flat<int32>();
144147

145148
if (restore_args_.m_is_oldform) {
146149
VLOG(1) << "old form, EV name:" << name_string
@@ -164,6 +167,7 @@ void CheckpointLoader<K, V>::RestoreInternal(
164167
VLOG(1) << "new form checkpoint... :" << name_string
165168
<< " , partition_id:" << restore_args_.m_partition_id
166169
<< " , partition_num:" << restore_args_.m_partition_num;
170+
auto part_offset_flat = part_offset_tensor.flat<int32>();
167171
for (size_t i = 0; i < restore_args_.m_loaded_parts.size(); i++) {
168172
int subpart_id = restore_args_.m_loaded_parts[i];
169173
size_t value_unit_bytes = sizeof(V) * restore_args_.m_old_dim;
@@ -183,6 +187,7 @@ void CheckpointLoader<K, V>::RestoreInternal(
183187
new_dim, emb_config, device);
184188

185189
if (restore_args_.m_has_filter) {
190+
auto part_filter_offset_flat = part_filter_offset_tensor.flat<int32>();
186191
Status s = EVRestoreFilteredFeatures(
187192
subpart_id, new_dim, restore_buff, part_filter_offset_flat,
188193
emb_config, device);
@@ -444,7 +449,7 @@ Status CheckpointLoader<K, V>::EVInitTensorNameAndShape(
444449
}
445450
st = reader_->LookupHeader(restore_args_.m_tensor_version + "_filtered",
446451
sizeof(K) * version_filter_shape.dim_size(0));
447-
if (!st.ok()) {
452+
if (!st.ok() && st.code() != error::NOT_FOUND) {
448453
return st;
449454
}
450455
st = reader_->LookupTensorShape(restore_args_.m_tensor_freq + "_filtered",
@@ -463,7 +468,8 @@ Status CheckpointLoader<K, V>::EVInitTensorNameAndShape(
463468
return st;
464469
}
465470
}
466-
return st;
471+
472+
return Status::OK();
467473
}
468474
#define REGISTER_KERNELS(ktype, vtype) \
469475
template Status CheckpointLoader<ktype, vtype>::EVInitTensorNameAndShape(\
@@ -644,4 +650,4 @@ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX)
644650
#undef REGISTER_KERNELS_ALL_INDEX
645651
#undef REGISTER_KERNELS
646652

647-
}// namespace tensorflow
653+
}// namespace tensorflow

tensorflow/python/training/incr_ckpt_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,5 +451,59 @@ def testIncrementalSaverForResourceVariable(self):
451451
saver.build()
452452
incr_saver = incr_saver_module._get_incremental_saver(True, saver)
453453

454+
def testIncrementalSaverSaveAndRestore(self):
455+
tmp_path = self.get_temp_dir()
456+
full_ckpt_dir = os.path.join(tmp_path, "model.ckpt")
457+
incr_ckpt_dir = os.path.join(tmp_path, "incr.ckpt")
458+
full_ckpt_path = None
459+
incr_ckpt_path = None
460+
461+
# construct graph
462+
emb_var = variable_scope.get_embedding_variable("emb", embedding_dim=3,
463+
initializer = init_ops.ones_initializer(dtypes.float32))
464+
emb = embedding_ops.embedding_lookup(emb_var,
465+
math_ops.cast([0, 1, 2, 3, 4], dtypes.int64))
466+
loss = math_ops.reduce_sum(emb, name = 'reduce_sum')
467+
opt = adagrad.AdagradOptimizer(0.1)
468+
g_v = opt.compute_gradients(loss)
469+
train_op = opt.apply_gradients(g_v)
470+
init = variables.global_variables_initializer()
471+
saver = saver_module.Saver(sharded=True, incremental_save_restore=True)
472+
incr_saver = \
473+
incr_saver_module.IncrementalSaver(sharded=True,
474+
saver_def=saver.saver_def, defer_build=True)
475+
incr_saver.build(saver._builder.filename_tensor)
476+
477+
# generate full ckpt and incr ckpt.
478+
full_ckpt_value=None
479+
incr_ckpt_value=None
480+
with self.test_session() as sess:
481+
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
482+
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS))
483+
sess.run([init])
484+
sess.run([train_op])
485+
full_ckpt_path = saver.save(sess, full_ckpt_dir, global_step = 10)
486+
full_ckpt_value = sess.run([emb])
487+
print("full_ckpt: {}".format(full_ckpt_value))
488+
sess.run([train_op])
489+
incr_ckpt_path = \
490+
incr_saver.incremental_save(sess, incr_ckpt_dir, global_step=20)
491+
incr_ckpt_value = sess.run([emb])
492+
print("incr_ckpt: {}".format(incr_ckpt_value))
493+
494+
# check the value after restoring parameter.
495+
with self.test_session() as sess:
496+
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
497+
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS))
498+
sess.run([init])
499+
saver.restore(sess, full_ckpt_path)
500+
restore_full_ckpt_value = sess.run([emb])
501+
print("restore_full_ckpt: {}".format(restore_full_ckpt_value))
502+
incr_saver.incremental_restore(sess, full_ckpt_path, incr_ckpt_path)
503+
restore_incr_ckpt_value = sess.run([emb])
504+
print("restore_incr_ckpt: {}".format(restore_incr_ckpt_value))
505+
self.assertAllClose(full_ckpt_value, restore_full_ckpt_value)
506+
self.assertAllClose(incr_ckpt_value, restore_incr_ckpt_value)
507+
454508
if __name__ == "__main__":
455509
googletest.main()

0 commit comments

Comments
 (0)