Skip to content

Commit 3bdf1e2

Browse files
committed
more robust loading model
1 parent 94aca7e commit 3bdf1e2

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

train.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,15 +847,29 @@ def _load_embedding(path, model):
847847
key = "seq2seq.encoder.embed_tokens.weight"
848848
model.seq2seq.encoder.embed_tokens.weight.data = state[key]
849849

850-
851850
# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
851+
852+
852853
def restore_parts(path, model):
853854
print("Restore part of the model from: {}".format(path))
854855
state = _load(path)["state_dict"]
855856
model_dict = model.state_dict()
856857
valid_state_dict = {k: v for k, v in state.items() if k in model_dict}
857-
model_dict.update(valid_state_dict)
858-
model.load_state_dict(model_dict)
858+
859+
try:
860+
model_dict.update(valid_state_dict)
861+
model.load_state_dict(model_dict)
862+
except RuntimeError as e:
863+
# there should be invalid size of weight(s), so load them per parameter
864+
print(str(e))
865+
model_dict = model.state_dict()
866+
for k, v in valid_state_dict.items():
867+
model_dict[k] = v
868+
try:
869+
model.load_state_dict(model_dict)
870+
except RuntimeError as e:
871+
print(str(e))
872+
warn("{}: may contain invalid size of weight. skipping...".format(k))
859873

860874

861875
if __name__ == "__main__":

0 commit comments

Comments
 (0)