55from functools import partial
66from pathlib import Path
77from tempfile import TemporaryDirectory
8- from typing import Iterable , Optional , Union
8+ from typing import Iterable , List , Optional , Tuple , Union
99
1010import torch
1111from torch .hub import HASH_REGEX , download_url_to_file , urlparse
5353HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
5454
5555
56- def get_cache_dir (child_dir = '' ):
56+ def get_cache_dir (child_dir : str = '' ):
5757 """
5858 Returns the location of the directory where models are cached (and creates it if necessary).
5959 """
@@ -68,13 +68,22 @@ def get_cache_dir(child_dir=''):
6868 return model_dir
6969
7070
71- def download_cached_file (url , check_hash = True , progress = False ):
71+ def download_cached_file (
72+ url : Union [str , List [str ], Tuple [str , str ]],
73+ check_hash : bool = True ,
74+ progress : bool = False ,
75+ cache_dir : Optional [Union [str , Path ]] = None ,
76+ ):
7277 if isinstance (url , (list , tuple )):
7378 url , filename = url
7479 else :
7580 parts = urlparse (url )
7681 filename = os .path .basename (parts .path )
77- cached_file = os .path .join (get_cache_dir (), filename )
82+ if cache_dir :
83+ os .makedirs (cache_dir , exist_ok = True )
84+ else :
85+ cache_dir = get_cache_dir ()
86+ cached_file = os .path .join (cache_dir , filename )
7887 if not os .path .exists (cached_file ):
7988 _logger .info ('Downloading: "{}" to {}\n ' .format (url , cached_file ))
8089 hash_prefix = None
@@ -85,13 +94,19 @@ def download_cached_file(url, check_hash=True, progress=False):
8594 return cached_file
8695
8796
88- def check_cached_file (url , check_hash = True ):
97+ def check_cached_file (
98+ url : Union [str , List [str ], Tuple [str , str ]],
99+ check_hash : bool = True ,
100+ cache_dir : Optional [Union [str , Path ]] = None ,
101+ ):
89102 if isinstance (url , (list , tuple )):
90103 url , filename = url
91104 else :
92105 parts = urlparse (url )
93106 filename = os .path .basename (parts .path )
94- cached_file = os .path .join (get_cache_dir (), filename )
107+ if not cache_dir :
108+ cache_dir = get_cache_dir ()
109+ cached_file = os .path .join (cache_dir , filename )
95110 if os .path .exists (cached_file ):
96111 if check_hash :
97112 r = HASH_REGEX .search (filename ) # r is Optional[Match[str]]
@@ -105,7 +120,7 @@ def check_cached_file(url, check_hash=True):
105120 return False
106121
107122
108- def has_hf_hub (necessary = False ):
123+ def has_hf_hub (necessary : bool = False ):
109124 if not _has_hf_hub and necessary :
110125 # if no HF Hub module installed, and it is necessary to continue, raise error
111126 raise RuntimeError (
@@ -122,20 +137,32 @@ def hf_split(hf_id: str):
122137 return hf_model_id , hf_revision
123138
124139
def load_cfg_from_json(json_file: Union[str, Path]):
    """Load a model configuration from a JSON file.

    Args:
        json_file: Path to the JSON file to parse.

    Returns:
        The parsed JSON content (typically a dict of config values).
    """
    contents = Path(json_file).read_text(encoding="utf-8")
    return json.loads(contents)
129144
130145
def download_from_hf(
        model_id: str,
        filename: str,
        cache_dir: Optional[Union[str, Path]] = None,
):
    """Download a single file from a Hugging Face Hub repository.

    Args:
        model_id: Hub model id, optionally carrying a revision suffix
            (split out by ``hf_split``).
        filename: Name of the file to fetch from the repository.
        cache_dir: Optional override for the HF Hub cache directory;
            ``None`` uses the hub's default cache location.

    Returns:
        Local filesystem path of the downloaded (or already cached) file.
    """
    repo_id, revision = hf_split(model_id)
    return hf_hub_download(repo_id, filename, revision=revision, cache_dir=cache_dir)
134158
135159
136- def load_model_config_from_hf (model_id : str ):
160+ def load_model_config_from_hf (
161+ model_id : str ,
162+ cache_dir : Optional [Union [str , Path ]] = None ,
163+ ):
137164 assert has_hf_hub (True )
138- cached_file = download_from_hf (model_id , 'config.json' )
165+ cached_file = download_from_hf (model_id , 'config.json' , cache_dir = cache_dir )
139166
140167 hf_config = load_cfg_from_json (cached_file )
141168 if 'pretrained_cfg' not in hf_config :
@@ -172,6 +199,7 @@ def load_state_dict_from_hf(
172199 model_id : str ,
173200 filename : str = HF_WEIGHTS_NAME ,
174201 weights_only : bool = False ,
202+ cache_dir : Optional [Union [str , Path ]] = None ,
175203):
176204 assert has_hf_hub (True )
177205 hf_model_id , hf_revision = hf_split (model_id )
@@ -180,7 +208,12 @@ def load_state_dict_from_hf(
180208 if _has_safetensors :
181209 for safe_filename in _get_safe_alternatives (filename ):
182210 try :
183- cached_safe_file = hf_hub_download (repo_id = hf_model_id , filename = safe_filename , revision = hf_revision )
211+ cached_safe_file = hf_hub_download (
212+ repo_id = hf_model_id ,
213+ filename = safe_filename ,
214+ revision = hf_revision ,
215+ cache_dir = cache_dir ,
216+ )
184217 _logger .info (
185218 f"[{ model_id } ] Safe alternative available for '{ filename } ' "
186219 f"(as '{ safe_filename } '). Loading weights using safetensors." )
@@ -189,7 +222,12 @@ def load_state_dict_from_hf(
189222 pass
190223
191224 # Otherwise, load using pytorch.load
192- cached_file = hf_hub_download (hf_model_id , filename = filename , revision = hf_revision )
225+ cached_file = hf_hub_download (
226+ hf_model_id ,
227+ filename = filename ,
228+ revision = hf_revision ,
229+ cache_dir = cache_dir ,
230+ )
193231 _logger .debug (f"[{ model_id } ] Safe alternative not found for '{ filename } '. Loading weights using default pytorch." )
194232 try :
195233 state_dict = torch .load (cached_file , map_location = 'cpu' , weights_only = weights_only )
@@ -198,15 +236,25 @@ def load_state_dict_from_hf(
198236 return state_dict
199237
200238
def load_custom_from_hf(
        model_id: str,
        filename: str,
        model: torch.nn.Module,
        cache_dir: Optional[Union[str, Path]] = None,
):
    """Download a custom checkpoint from the HF Hub and load it into ``model``.

    Args:
        model_id: Hub model id, optionally carrying a revision suffix
            (split out by ``hf_split``).
        filename: Name of the checkpoint file within the repository.
        model: Model instance exposing a ``load_pretrained`` method that
            accepts a local checkpoint path.
        cache_dir: Optional override for the HF Hub cache directory;
            ``None`` uses the hub's default cache location.

    Returns:
        Whatever ``model.load_pretrained`` returns for the downloaded file.
    """
    # Hard requirement: huggingface_hub must be installed to proceed.
    assert has_hf_hub(True)
    repo_id, revision = hf_split(model_id)
    checkpoint_path = hf_hub_download(
        repo_id,
        filename=filename,
        revision=revision,
        cache_dir=cache_dir,
    )
    return model.load_pretrained(checkpoint_path)
206254
207255
208256def save_config_for_hf (
209- model ,
257+ model : torch . nn . Module ,
210258 config_path : str ,
211259 model_config : Optional [dict ] = None ,
212260 model_args : Optional [dict ] = None
@@ -255,7 +303,7 @@ def save_config_for_hf(
255303
256304
257305def save_for_hf (
258- model ,
306+ model : torch . nn . Module ,
259307 save_directory : str ,
260308 model_config : Optional [dict ] = None ,
261309 model_args : Optional [dict ] = None ,
0 commit comments