22import json
33import logging
44import os
5+ import sys
56from functools import partial
67from pathlib import Path
78from tempfile import TemporaryDirectory
8- from typing import Optional , Union
9-
9+ from typing import Iterable , Optional , Union
1010import torch
1111from torch .hub import HASH_REGEX , download_url_to_file , urlparse
12+ import safetensors .torch
1213
1314try :
1415 from torch .hub import get_dir
1516except ImportError :
1617 from torch .hub import _get_torch_home as get_dir
1718
19+ if sys .version_info >= (3 , 8 ):
20+ from typing import Literal
21+ else :
22+ from typing_extensions import Literal
23+
1824from timm import __version__
1925from timm .models ._pretrained import filter_pretrained_cfg
2026
# Public API of this module.
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
           'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']

# Default names for weights files hosted on the Hugging Face Hub.
HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl (pickle-based torch.save format)
HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version (no pickle execution on load)
3847
3948def get_cache_dir (child_dir = '' ):
4049 """
@@ -150,11 +159,23 @@ def load_model_config_from_hf(model_id: str):
150159 return pretrained_cfg , model_name
151160
152161
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
    """Download and load a checkpoint state dict from the Hugging Face Hub.

    Prefers a `.safetensors` alternative of `filename` when the repo hosts one
    (loaded via `safetensors`, which does not execute pickled code); otherwise
    falls back to `torch.load` on the requested file.

    Args:
        model_id: Hub model id, possibly carrying a revision (split by `hf_split`).
        filename: weights filename inside the repo; defaults to `HF_WEIGHTS_NAME`.

    Returns:
        The loaded state dict with tensors mapped to CPU.
    """
    assert has_hf_hub(True)
    hf_model_id, hf_revision = hf_split(model_id)

    # Look for .safetensors alternatives and load from it if it exists
    for safe_filename in _get_safe_alternatives(filename):
        try:
            cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
            # Logged at INFO (not WARNING): picking the safetensors file is the
            # expected happy path, not a condition the user needs to act on.
            _logger.info(
                f"[{model_id}] Safe alternative available for '{filename}' "
                f"(as '{safe_filename}'). Loading weights using safetensors.")
            return safetensors.torch.load_file(cached_safe_file, device="cpu")
        except EntryNotFoundError:
            # This candidate is not hosted in the repo; try the next one.
            pass

    # Otherwise, load using pytorch.load
    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
    _logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
    return torch.load(cached_file, map_location='cpu')
158179
159180
160181def save_config_for_hf (model , config_path : str , model_config : Optional [dict ] = None ):
@@ -195,13 +216,22 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = N
195216 json .dump (hf_config , f , indent = 2 )
196217
197218
def save_for_hf(
        model,
        save_directory: str,
        model_config: Optional[dict] = None,
        safe_serialization: Union[bool, Literal["both"]] = False,
):
    """Save model weights and config into `save_directory` for upload to the HF Hub.

    Args:
        model: model whose `state_dict()` is serialized.
        save_directory: target directory; created (with parents) if missing.
        model_config: optional dict merged into the saved config.json.
        safe_serialization: False (default) writes the legacy pickle file
            `HF_WEIGHTS_NAME` via `torch.save`; True writes `HF_SAFE_WEIGHTS_NAME`
            via safetensors; "both" writes both files.

    Raises:
        ValueError: if `safe_serialization` is not True, False, or "both"
            (previously an invalid value silently saved no weights at all).
    """
    assert has_hf_hub(True)
    # Guard against truthy-but-invalid values (e.g. "true", 1): without this,
    # neither branch below would fire and no weights file would be written.
    if not (safe_serialization is True or safe_serialization is False or safe_serialization == "both"):
        raise ValueError(f'safe_serialization must be True, False, or "both", got {safe_serialization!r}')
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
    tensors = model.state_dict()
    if safe_serialization is True or safe_serialization == "both":
        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
    if safe_serialization is False or safe_serialization == "both":
        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

    config_path = save_directory / 'config.json'
    save_config_for_hf(model, config_path, model_config=model_config)
@@ -217,7 +247,15 @@ def push_to_hf_hub(
217247 create_pr : bool = False ,
218248 model_config : Optional [dict ] = None ,
219249 model_card : Optional [dict ] = None ,
250+ safe_serialization : Union [bool , Literal ["both" ]] = False
220251):
252+ """
253+ Arguments:
254+ (...)
255+ safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
256+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
257+ Can be set to `"both"` in order to push both safe and unsafe weights.
258+ """
221259 # Create repo if it doesn't exist yet
222260 repo_url = create_repo (repo_id , token = token , private = private , exist_ok = True )
223261
@@ -236,7 +274,7 @@ def push_to_hf_hub(
236274 # Dump model and push to Hub
237275 with TemporaryDirectory () as tmpdir :
238276 # Save model weights and config.
239- save_for_hf (model , tmpdir , model_config = model_config )
277+ save_for_hf (model , tmpdir , model_config = model_config , safe_serialization = safe_serialization )
240278
241279 # Add readme if it does not exist
242280 if not has_readme :
@@ -302,3 +340,15 @@ def generate_readme(model_card: dict, model_name: str):
302340 for c in citations :
303341 readme_text += f"```bibtex\n { c } \n ```\n "
304342 return readme_text
343+
344+ def _get_safe_alternatives (filename : str ) -> Iterable [str ]:
345+ """Returns potential safetensors alternatives for a given filename.
346+
347+ Use case:
348+ When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
349+ Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
350+ """
351+ if filename == HF_WEIGHTS_NAME :
352+ yield HF_SAFE_WEIGHTS_NAME
353+ if filename .endswith (".bin" ):
354+ yield filename [:- 4 ] + ".safetensors"
0 commit comments