Skip to content

Commit 9d9595a

Browse files
committed
train: load only valid weights
1 parent a499a1a commit 9d9595a

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

train.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,8 +832,21 @@ def restore_parts(path, model):
832832
state = torch.load(path)["state_dict"]
833833
model_dict = model.state_dict()
834834
valid_state_dict = {k: v for k, v in state.items() if k in model_dict}
835-
model_dict.update(valid_state_dict)
836-
model.load_state_dict(model_dict)
835+
836+
try:
837+
model_dict.update(valid_state_dict)
838+
model.load_state_dict(model_dict)
839+
except RuntimeError as e:
840+
# there should be invalid size of weight(s), so load them per parameter
841+
print(str(e))
842+
model_dict = model.state_dict()
843+
for k, v in valid_state_dict.items():
844+
model_dict[k] = v
845+
try:
846+
model.load_state_dict(model_dict)
847+
except RuntimeError as e:
848+
print(str(e))
849+
warn("{}: may contain invalid size of weight. skipping...".format(k))
837850

838851

839852
if __name__ == "__main__":

0 commit comments

Comments
 (0)