Skip to content

Commit c0ae904

Browse files
committed
simplify documentation/error messages, allow state dict flexible load
1 parent 90cc9e9 commit c0ae904

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

models/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Models directory
22

3-
We will provide PyTorch modules of different model architectures here.
3+
We will provide example PyTorch modules of different model architectures here.
44

55
If you have any particular architecture requests, please let us know. Thanks!

selene_sdk/utils/utils.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from collections import OrderedDict
88
import logging
99
import sys
10-
import traceback
1110

1211
import numpy as np
1312

@@ -99,32 +98,30 @@ def load_model_from_state_dict(state_dict, model):
9998
model_keys = model.state_dict().keys()
10099
state_dict_keys = state_dict.keys()
101100

102-
new_state_dict = OrderedDict()
103-
104101
if len(model_keys) != len(state_dict_keys):
105-
raise ValueError("State dict does not have the same "
106-
"number of modules as the specified model "
107-
"architecture. Please check whether you are using "
108-
"the expected model architecture and that your PyTorch "
109-
"version matches the version in which the loaded model "
110-
"was trained.\n\n"
111-
"\tExpected modules:\n\t{0}\n\n"
112-
"\tModules in the loaded model weights:\n\t{1}\n".format(
113-
model_keys, state_dict_keys))
102+
try:
103+
model.load_state_dict(state_dict, strict=False)
104+
return model
105+
except Exception as e:
106+
raise ValueError("Loaded state dict does not match the model "
107+
"architecture specified - please check that you are "
108+
"using the correct architecture file and parameters.\n\n"
109+
"{0}".format(e))
114110

111+
new_state_dict = OrderedDict()
115112
for (k1, k2) in zip(model_keys, state_dict_keys):
116113
value = state_dict[k2]
117114
try:
118115
new_state_dict[k1] = value
119-
except Exception:
116+
except Exception as e:
120117
raise ValueError(
121118
"Failed to load weight from module {0} in model weights "
122-
"into model architecture module {1}. (If module name "
119+
"into model architecture module {1}. (If module name has "
123120
"an additional prefix `model.` it is because the model is "
124121
"wrapped in `selene_sdk.utils.NonStrandSpecific`. This "
125122
"error was raised because the underlying module does "
126123
"not match that expected by the loaded model:\n"
127-
"{2}".format(k2, k1, traceback.print_exc()))
124+
"{2}".format(k2, k1, e))
128125
model.load_state_dict(new_state_dict)
129126
return model
130127

0 commit comments

Comments
 (0)