Skip to content

Commit e96c8e4

Browse files
committed
Fix for CPU
1 parent 30b04e9 commit e96c8e4

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

train.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -815,12 +815,21 @@ def build_model():
815815
return model
816816

817817

818+
def _load(checkpoint_path):
819+
if use_cuda:
820+
checkpoint = torch.load(checkpoint_path)
821+
else:
822+
checkpoint = torch.load(checkpoint_path,
823+
map_location=lambda storage, loc: storage)
824+
return checkpoint
825+
826+
818827
def load_checkpoint(path, model, optimizer, reset_optimizer):
819828
global global_step
820829
global global_epoch
821830

822831
print("Load checkpoint from: {}".format(path))
823-
checkpoint = torch.load(path)
832+
checkpoint = _load(path)
824833
model.load_state_dict(checkpoint["state_dict"])
825834
if not reset_optimizer:
826835
optimizer_state = checkpoint["optimizer"]
@@ -834,15 +843,15 @@ def load_checkpoint(path, model, optimizer, reset_optimizer):
834843

835844

836845
def _load_embedding(path, model):
837-
state = torch.load(path)["state_dict"]
846+
state = _load(path)["state_dict"]
838847
key = "seq2seq.encoder.embed_tokens.weight"
839848
model.seq2seq.encoder.embed_tokens.weight.data = state[key]
840849

841850

842851
# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
843852
def restore_parts(path, model):
844853
print("Restore part of the model from: {}".format(path))
845-
state = torch.load(path)["state_dict"]
854+
state = _load(path)["state_dict"]
846855
model_dict = model.state_dict()
847856
valid_state_dict = {k: v for k, v in state.items() if k in model_dict}
848857
model_dict.update(valid_state_dict)
@@ -951,7 +960,8 @@ def restore_parts(path, model):
951960
# Setup summary writer for tensorboard
952961
if log_event_path is None:
953962
if platform.system() == "Windows":
954-
log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_").replace(":","_")
963+
log_event_path = "log/run-test" + \
964+
str(datetime.now()).replace(" ", "_").replace(":", "_")
955965
else:
956966
log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_")
957967
print("Los event path: {}".format(log_event_path))

0 commit comments

Comments
 (0)