Skip to content

Commit 0e2a2d8

Browse files
author
rbodo
committed
Pytorch model loading now warns about mismatch in statedicts and falls back on strict=False instead of breaking.
1 parent 6cb8a18 commit 0e2a2d8

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

snntoolbox/parsing/model_libs/pytorch_input_lib.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,14 @@ def load(path, filename):
6464
if os.path.exists(model_path):
6565
break
6666
assert model_path, "Pytorch state_dict not found at {}".format(model_path)
67-
model_pytorch.load_state_dict(torch.load(model_path,
68-
map_location=map_location))
69-
70-
# state_dict = torch.load(model_path, map_location=map_location)['state_dict']
71-
# new_state_dict = {}
72-
# for k, v in state_dict.items():
73-
# k = str(k).replace('module.', '')
74-
# new_state_dict[k] = v
75-
# model_pytorch.load_state_dict(new_state_dict, strict=False)
67+
try:
68+
model_pytorch.load_state_dict(
69+
torch.load(model_path, map_location=map_location))
70+
except RuntimeError as e:
71+
print("WARNING: Ignored mismatch when loading pytorch state_dict.")
72+
print(e)
73+
model_pytorch.load_state_dict(
74+
torch.load(model_path, map_location=map_location), strict=False)
7675

7776
# Switch from train to eval mode to ensure Dropout / BatchNorm is handled
7877
# correctly.

0 commit comments

Comments
 (0)