|
7 | 7 | from collections import OrderedDict |
8 | 8 | import logging |
9 | 9 | import sys |
10 | | -import traceback |
11 | 10 |
|
12 | 11 | import numpy as np |
13 | 12 |
|
@@ -99,32 +98,30 @@ def load_model_from_state_dict(state_dict, model): |
99 | 98 | model_keys = model.state_dict().keys() |
100 | 99 | state_dict_keys = state_dict.keys() |
101 | 100 |
|
102 | | - new_state_dict = OrderedDict() |
103 | | - |
104 | 101 | if len(model_keys) != len(state_dict_keys): |
105 | | - raise ValueError("State dict does not have the same " |
106 | | - "number of modules as the specified model " |
107 | | - "architecture. Please check whether you are using " |
108 | | - "the expected model architecture and that your PyTorch " |
109 | | - "version matches the version in which the loaded model " |
110 | | - "was trained.\n\n" |
111 | | - "\tExpected modules:\n\t{0}\n\n" |
112 | | - "\tModules in the loaded model weights:\n\t{1}\n".format( |
113 | | - model_keys, state_dict_keys)) |
| 102 | + try: |
| 103 | + model.load_state_dict(state_dict, strict=False) |
| 104 | + return model |
| 105 | + except Exception as e: |
| 106 | + raise ValueError("Loaded state dict does not match the model " |
| 107 | + "architecture specified - please check that you are " |
| 108 | + "using the correct architecture file and parameters.\n\n" |
| 109 | + "{0}".format(e)) |
114 | 110 |
|
| 111 | + new_state_dict = OrderedDict() |
115 | 112 | for (k1, k2) in zip(model_keys, state_dict_keys): |
116 | 113 | value = state_dict[k2] |
117 | 114 | try: |
118 | 115 | new_state_dict[k1] = value |
119 | | - except Exception: |
| 116 | + except Exception as e: |
120 | 117 | raise ValueError( |
121 | 118 | "Failed to load weight from module {0} in model weights " |
122 | | - "into model architecture module {1}. (If module name " |
| 119 | + "into model architecture module {1}. (If module name has " |
123 | 120 | "an additional prefix `model.` it is because the model is " |
124 | 121 | "wrapped in `selene_sdk.utils.NonStrandSpecific`. This " |
125 | 122 | "error was raised because the underlying module does " |
126 | 123 | "not match that expected by the loaded model:\n" |
127 | | - "{2}".format(k2, k1, traceback.print_exc())) |
| 124 | + "{2}".format(k2, k1, e)) |
128 | 125 | model.load_state_dict(new_state_dict) |
129 | 126 | return model |
130 | 127 |
|
|
0 commit comments