Skip to content

Commit 404d0fd

Browse files
committed
train: load only valid weights
1 parent 9455501 commit 404d0fd

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

train.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -823,8 +823,21 @@ def restore_parts(path, model):
823823
state = torch.load(path)["state_dict"]
824824
model_dict = model.state_dict()
825825
valid_state_dict = {k: v for k, v in state.items() if k in model_dict}
826-
model_dict.update(valid_state_dict)
827-
model.load_state_dict(model_dict)
826+
827+
try:
828+
model_dict.update(valid_state_dict)
829+
model.load_state_dict(model_dict)
830+
except RuntimeError as e:
831+
# there should be invalid size of weight(s), so load them per parameter
832+
print(str(e))
833+
model_dict = model.state_dict()
834+
for k, v in valid_state_dict.items():
835+
model_dict[k] = v
836+
try:
837+
model.load_state_dict(model_dict)
838+
except RuntimeError as e:
839+
print(str(e))
840+
warn("{}: may contain invalid size of weight. skipping...".format(k))
828841

829842

830843
if __name__ == "__main__":

0 commit comments

Comments
 (0)