Skip to content

Commit 8470e29

Browse files
testbotrwightman
authored andcommitted
Add support to load safetensors weights
1 parent f35d6ea commit 8470e29

File tree

5 files changed

+105
-26
lines changed

5 files changed

+105
-26
lines changed

avg_checkpoints.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,26 @@
1717
import glob
1818
import hashlib
1919
from timm.models import load_state_dict
20+
import safetensors.torch
21+
22+
DEFAULT_OUTPUT = "./average.pth"
23+
DEFAULT_SAFE_OUTPUT = "./average.safetensors"
2024

2125
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
2226
parser.add_argument('--input', default='', type=str, metavar='PATH',
2327
help='path to base input folder containing checkpoints')
2428
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
2529
help='checkpoint filter (path wildcard)')
26-
parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
27-
help='output filename')
30+
parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
31+
help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
2832
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
2933
help='Force not using ema version of weights (if present)')
3034
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
3135
help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
3236
parser.add_argument('-n', type=int, default=10, metavar='N',
3337
help='Number of checkpoints to average')
34-
38+
parser.add_argument('--safetensors', action='store_true',
39+
help='Save weights using safetensors instead of the default torch way (pickle).')
3540

3641
def checkpoint_metric(checkpoint_path):
3742
if not checkpoint_path or not os.path.isfile(checkpoint_path):
@@ -55,6 +60,15 @@ def main():
5560
# by default sort by checkpoint metric (if present) and avg top n checkpoints
5661
args.sort = not args.no_sort
5762

63+
if args.safetensors and args.output == DEFAULT_OUTPUT:
64+
# Default path changes if using safetensors
65+
args.output = DEFAULT_SAFE_OUTPUT
66+
if args.safetensors and not args.output.endswith(".safetensors"):
67+
print(
68+
"Warning: saving weights as safetensors but output file extension is not "
69+
f"set to '.safetensors': {args.output}"
70+
)
71+
5872
if os.path.exists(args.output):
5973
print("Error: Output filename ({}) already exists.".format(args.output))
6074
exit(1)
@@ -107,10 +121,13 @@ def main():
107121
v = v.clamp(float32_info.min, float32_info.max)
108122
final_state_dict[k] = v.to(dtype=torch.float32)
109123

110-
try:
111-
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
112-
except:
113-
torch.save(final_state_dict, args.output)
124+
if args.safetensors:
125+
safetensors.torch.save_file(final_state_dict, args.output)
126+
else:
127+
try:
128+
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
129+
except:
130+
torch.save(final_state_dict, args.output)
114131

115132
with open(args.output, 'rb') as f:
116133
sha_hash = hashlib.sha256(f.read()).hexdigest()

clean_checkpoint.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import argparse
1212
import os
1313
import hashlib
14+
import safetensors.torch
1415
import shutil
15-
from collections import OrderedDict
1616
from timm.models import load_state_dict
1717

1818
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
@@ -24,6 +24,8 @@
2424
help='use ema version of weights if present')
2525
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
2626
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
27+
parser.add_argument('--safetensors', action='store_true',
28+
help='Save weights using safetensors instead of the default torch way (pickle).')
2729

2830
_TEMP_NAME = './_checkpoint.pth'
2931

@@ -35,10 +37,10 @@ def main():
3537
print("Error: Output filename ({}) already exists.".format(args.output))
3638
exit(1)
3739

38-
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
40+
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
3941

4042

41-
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
43+
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
4244
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
4345
if checkpoint and os.path.isfile(checkpoint):
4446
print("=> Loading checkpoint '{}'".format(checkpoint))
@@ -53,10 +55,13 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
5355
new_state_dict[name] = v
5456
print("=> Loaded state_dict from '{}'".format(checkpoint))
5557

56-
try:
57-
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
58-
except:
59-
torch.save(new_state_dict, _TEMP_NAME)
58+
if safe_serialization:
59+
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
60+
else:
61+
try:
62+
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
63+
except:
64+
torch.save(new_state_dict, _TEMP_NAME)
6065

6166
with open(_TEMP_NAME, 'rb') as f:
6267
sha_hash = hashlib.sha256(f.read()).hexdigest()
@@ -67,7 +72,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
6772
else:
6873
checkpoint_root = ''
6974
checkpoint_base = os.path.splitext(checkpoint)[0]
70-
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
75+
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
7176
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
7277
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
7378
return final_filename

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ torch>=1.7
22
torchvision
33
pyyaml
44
huggingface_hub
5+
safetensors>=0.2

timm/models/_helpers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections import OrderedDict
88

99
import torch
10+
import safetensors.torch
1011

1112
import timm.models._builder
1213

@@ -26,7 +27,12 @@ def clean_state_dict(state_dict):
2627

2728
def load_state_dict(checkpoint_path, use_ema=True):
2829
if checkpoint_path and os.path.isfile(checkpoint_path):
29-
checkpoint = torch.load(checkpoint_path, map_location='cpu')
30+
# Check if safetensors or not and load weights accordingly
31+
if str(checkpoint_path).endswith(".safetensors"):
32+
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
33+
else:
34+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
35+
3036
state_dict_key = ''
3137
if isinstance(checkpoint, dict):
3238
if use_ema and checkpoint.get('state_dict_ema', None) is not None:

timm/models/_hub.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,25 @@
22
import json
33
import logging
44
import os
5+
import sys
56
from functools import partial
67
from pathlib import Path
78
from tempfile import TemporaryDirectory
8-
from typing import Optional, Union
9-
9+
from typing import Iterable, Optional, Union
1010
import torch
1111
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
12+
import safetensors.torch
1213

1314
try:
1415
from torch.hub import get_dir
1516
except ImportError:
1617
from torch.hub import _get_torch_home as get_dir
1718

19+
if sys.version_info >= (3, 8):
20+
from typing import Literal
21+
else:
22+
from typing_extensions import Literal
23+
1824
from timm import __version__
1925
from timm.models._pretrained import filter_pretrained_cfg
2026

@@ -35,6 +41,9 @@
3541
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
3642
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
3743

44+
# Default name for a weights file hosted on the Huggingface Hub.
45+
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
46+
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
3847

3948
def get_cache_dir(child_dir=''):
4049
"""
@@ -150,11 +159,23 @@ def load_model_config_from_hf(model_id: str):
150159
return pretrained_cfg, model_name
151160

152161

153-
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
162+
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
154163
assert has_hf_hub(True)
155-
cached_file = download_from_hf(model_id, filename)
156-
state_dict = torch.load(cached_file, map_location='cpu')
157-
return state_dict
164+
hf_model_id, hf_revision = hf_split(model_id)
165+
166+
# Look for .safetensors alternatives and load from it if it exists
167+
for safe_filename in _get_safe_alternatives(filename):
168+
try:
169+
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
170+
_logger.warning(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
171+
return safetensors.torch.load_file(cached_safe_file, device="cpu")
172+
except EntryNotFoundError:
173+
pass
174+
175+
# Otherwise, load using pytorch.load
176+
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
177+
_logger.warning(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
178+
return torch.load(cached_file, map_location='cpu')
158179

159180

160181
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
@@ -195,13 +216,22 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = N
195216
json.dump(hf_config, f, indent=2)
196217

197218

198-
def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
219+
def save_for_hf(
220+
model,
221+
save_directory: str,
222+
model_config: Optional[dict] = None,
223+
safe_serialization: Union[bool, Literal["both"]] = False
224+
):
199225
assert has_hf_hub(True)
200226
save_directory = Path(save_directory)
201227
save_directory.mkdir(exist_ok=True, parents=True)
202228

203-
weights_path = save_directory / 'pytorch_model.bin'
204-
torch.save(model.state_dict(), weights_path)
229+
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
230+
tensors = model.state_dict()
231+
if safe_serialization is True or safe_serialization == "both":
232+
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
233+
if safe_serialization is False or safe_serialization == "both":
234+
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
205235

206236
config_path = save_directory / 'config.json'
207237
save_config_for_hf(model, config_path, model_config=model_config)
@@ -217,7 +247,15 @@ def push_to_hf_hub(
217247
create_pr: bool = False,
218248
model_config: Optional[dict] = None,
219249
model_card: Optional[dict] = None,
250+
safe_serialization: Union[bool, Literal["both"]] = False
220251
):
252+
"""
253+
Arguments:
254+
(...)
255+
safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
256+
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
257+
Can be set to `"both"` in order to push both safe and unsafe weights.
258+
"""
221259
# Create repo if it doesn't exist yet
222260
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
223261

@@ -236,7 +274,7 @@ def push_to_hf_hub(
236274
# Dump model and push to Hub
237275
with TemporaryDirectory() as tmpdir:
238276
# Save model weights and config.
239-
save_for_hf(model, tmpdir, model_config=model_config)
277+
save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
240278

241279
# Add readme if it does not exist
242280
if not has_readme:
@@ -302,3 +340,15 @@ def generate_readme(model_card: dict, model_name: str):
302340
for c in citations:
303341
readme_text += f"```bibtex\n{c}\n```\n"
304342
return readme_text
343+
344+
def _get_safe_alternatives(filename: str) -> Iterable[str]:
345+
"""Returns potential safetensors alternatives for a given filename.
346+
347+
Use case:
348+
When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
349+
Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
350+
"""
351+
if filename == HF_WEIGHTS_NAME:
352+
yield HF_SAFE_WEIGHTS_NAME
353+
if filename.endswith(".bin"):
354+
yield filename[:-4] + ".safetensors"

0 commit comments

Comments
 (0)