22import logging
33import os
44from copy import deepcopy
5+ from pathlib import Path
56from typing import Any , Callable , Dict , Optional , Tuple , Union
67from contextlib import nullcontext
78
@@ -92,6 +93,7 @@ def load_custom_pretrained(
9293 model : nn .Module ,
9394 pretrained_cfg : Optional [Dict ] = None ,
9495 load_fn : Optional [Callable ] = None ,
96+ cache_dir : Optional [Union [str , Path ]] = None ,
9597):
9698 r"""Loads a custom (read non .pth) weight file
9799
@@ -104,9 +106,10 @@ def load_custom_pretrained(
104106
105107 Args:
106108 model: The instantiated model to load weights into
107- pretrained_cfg (dict) : Default pretrained model cfg
109+ pretrained_cfg: Default pretrained model cfg
108110 load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
109- 'laod_pretrained' on the model will be called if it exists
111+ 'load_pretrained' on the model will be called if it exists
112+ cache_dir: Override model checkpoint cache dir for this load
110113 """
111114 pretrained_cfg = pretrained_cfg or getattr (model , 'pretrained_cfg' , None )
112115 if not pretrained_cfg :
@@ -124,6 +127,7 @@ def load_custom_pretrained(
124127 pretrained_loc ,
125128 check_hash = _CHECK_HASH ,
126129 progress = _DOWNLOAD_PROGRESS ,
130+ cache_dir = cache_dir ,
127131 )
128132
129133 if load_fn is not None :
@@ -141,17 +145,18 @@ def load_pretrained(
141145 in_chans : int = 3 ,
142146 filter_fn : Optional [Callable ] = None ,
143147 strict : bool = True ,
148+ cache_dir : Optional [Union [str , Path ]] = None ,
144149):
145150 """ Load pretrained checkpoint
146151
147152 Args:
148- model (nn.Module) : PyTorch model module
149- pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
150- num_classes (int): num_classes for target model
151- in_chans (int): in_chans for target model
152- filter_fn (Optional[Callable]) : state_dict filter fn for load (takes state_dict, model as args)
153- strict (bool): strict load of checkpoint
154-
153+ model: PyTorch module
154+ pretrained_cfg: Configuration for pretrained weights / target dataset
155+ num_classes: Number of classes for target model. Will adapt pretrained if different.
156+ in_chans: Number of input chans for target model. Will adapt pretrained if different.
157+ filter_fn: state_dict filter fn for load (takes state_dict, model as args)
158+ strict: Strict load of checkpoint
159+ cache_dir: Override model checkpoint cache dir for this load
155160 """
156161 pretrained_cfg = pretrained_cfg or getattr (model , 'pretrained_cfg' , None )
157162 if not pretrained_cfg :
@@ -175,6 +180,7 @@ def load_pretrained(
175180 pretrained_loc ,
176181 progress = _DOWNLOAD_PROGRESS ,
177182 check_hash = _CHECK_HASH ,
183+ cache_dir = cache_dir ,
178184 )
179185 model .load_pretrained (pretrained_loc )
180186 return
@@ -186,25 +192,27 @@ def load_pretrained(
186192 progress = _DOWNLOAD_PROGRESS ,
187193 check_hash = _CHECK_HASH ,
188194 weights_only = True ,
195+ model_dir = cache_dir ,
189196 )
190197 except TypeError :
191198 state_dict = load_state_dict_from_url (
192199 pretrained_loc ,
193200 map_location = 'cpu' ,
194201 progress = _DOWNLOAD_PROGRESS ,
195202 check_hash = _CHECK_HASH ,
203+ model_dir = cache_dir ,
196204 )
197205 elif load_from == 'hf-hub' :
198206 _logger .info (f'Loading pretrained weights from Hugging Face hub ({ pretrained_loc } )' )
199207 if isinstance (pretrained_loc , (list , tuple )):
200208 custom_load = pretrained_cfg .get ('custom_load' , False )
201209 if isinstance (custom_load , str ) and custom_load == 'hf' :
202- load_custom_from_hf (* pretrained_loc , model )
210+ load_custom_from_hf (* pretrained_loc , model , cache_dir = cache_dir )
203211 return
204212 else :
205- state_dict = load_state_dict_from_hf (* pretrained_loc )
213+ state_dict = load_state_dict_from_hf (* pretrained_loc , cache_dir = cache_dir )
206214 else :
207- state_dict = load_state_dict_from_hf (pretrained_loc , weights_only = True )
215+ state_dict = load_state_dict_from_hf (pretrained_loc , weights_only = True , cache_dir = cache_dir )
208216 else :
209217 model_name = pretrained_cfg .get ('architecture' , 'this model' )
210218 raise RuntimeError (f"No pretrained weights exist for { model_name } . Use `pretrained=False` for random init." )
@@ -321,8 +329,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
321329
322330def resolve_pretrained_cfg (
323331 variant : str ,
324- pretrained_cfg = None ,
325- pretrained_cfg_overlay = None ,
332+ pretrained_cfg : Optional [ Union [ str , Dict [ str , Any ]]] = None ,
333+ pretrained_cfg_overlay : Optional [ Dict [ str , Any ]] = None ,
326334) -> PretrainedCfg :
327335 model_with_tag = variant
328336 pretrained_tag = None
@@ -364,6 +372,7 @@ def build_model_with_cfg(
364372 feature_cfg : Optional [Dict ] = None ,
365373 pretrained_strict : bool = True ,
366374 pretrained_filter_fn : Optional [Callable ] = None ,
375+ cache_dir : Optional [Union [str , Path ]] = None ,
367376 kwargs_filter : Optional [Tuple [str ]] = None ,
368377 ** kwargs ,
369378):
@@ -376,16 +385,18 @@ def build_model_with_cfg(
376385 * pruning config / model adaptation
377386
378387 Args:
379- model_cls: model class
380- variant: model variant name
381- pretrained: load pretrained weights
382- pretrained_cfg: model's pretrained weight/task config
383- model_cfg: model's architecture config
384- feature_cfg: feature extraction adapter config
385- pretrained_strict: load pretrained weights strictly
386- pretrained_filter_fn: filter callable for pretrained weights
387- kwargs_filter: kwargs to filter before passing to model
388- **kwargs: model args passed through to model __init__
388+ model_cls: Model class
389+ variant: Model variant name
390+ pretrained: Load the pretrained weights
391+ pretrained_cfg: Model's pretrained weight/task config
392+ pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
393+ model_cfg: Model's architecture config
394+ feature_cfg: Feature extraction adapter config
395+ pretrained_strict: Load pretrained weights strictly
396+ pretrained_filter_fn: Filter callable for pretrained weights
397+ cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
398+ kwargs_filter: Kwargs keys to filter (remove) before passing to model
399+ **kwargs: Model args passed through to model __init__
389400 """
390401 pruned = kwargs .pop ('pruned' , False )
391402 features = False
@@ -397,8 +408,6 @@ def build_model_with_cfg(
397408 pretrained_cfg = pretrained_cfg ,
398409 pretrained_cfg_overlay = pretrained_cfg_overlay
399410 )
400-
401- # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
402411 pretrained_cfg = pretrained_cfg .to_dict ()
403412
404413 _update_default_model_kwargs (pretrained_cfg , kwargs , kwargs_filter )
@@ -437,6 +446,7 @@ def build_model_with_cfg(
437446 in_chans = kwargs .get ('in_chans' , 3 ),
438447 filter_fn = pretrained_filter_fn ,
439448 strict = pretrained_strict ,
449+ cache_dir = cache_dir ,
440450 )
441451
442452 # Wrap the model in a feature extraction module if enabled
0 commit comments