@@ -107,6 +107,7 @@ def load_custom_pretrained(
         pretrained_cfg: Default pretrained model cfg
         load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
             'load_pretrained' on the model will be called if it exists
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
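
For context, a hedged usage sketch (not part of this commit) of the cache_dir override documented above; the variant name and cache path are placeholders, and the import path assumes timm's _builder module:

# Hedged sketch: per-call checkpoint cache override for load_custom_pretrained.
# 'my_npz_model' stands in for a variant whose weights go through this path
# (e.g. a model exposing its own load_pretrained method for non .pth files).
import timm
from timm.models._builder import load_custom_pretrained

model = timm.create_model('my_npz_model', pretrained=False)  # placeholder variant name
load_custom_pretrained(
    model,
    pretrained_cfg=model.pretrained_cfg,  # same default used when this arg is omitted
    cache_dir='/data/timm_cache',         # checkpoint cached here instead of the default hub/torch dir
)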
@@ -148,12 +149,12 @@ def load_pretrained(
 
     Args:
         model: PyTorch module
-        pretrained_cfg: configuration for pretrained weights / target dataset
-        num_classes: number of classes for target model
-        in_chans: number of input chans for target model
+        pretrained_cfg: Configuration for pretrained weights / target dataset
+        num_classes: Number of classes for target model. Will adapt pretrained if different.
+        in_chans: Number of input chans for target model. Will adapt pretrained if different.
         filter_fn: state_dict filter fn for load (takes state_dict, model as args)
-        strict: strict load of checkpoint
-        cache_dir: override path to cache dir for this load
+        strict: Strict load of checkpoint
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
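
A hedged sketch (not part of this commit) illustrating the adaptation behavior now called out in the docstring; the model name, class count, and cache path are placeholders:

# Hedged sketch: load_pretrained adapts the classifier and first conv when
# num_classes / in_chans differ from the pretrained weights.
import timm
from timm.models._builder import load_pretrained

model = timm.create_model('resnet50', pretrained=False, num_classes=10, in_chans=1)
load_pretrained(
    model,
    pretrained_cfg=model.pretrained_cfg,
    num_classes=10,                # head reset/adapted because it differs from the 1000-class weights
    in_chans=1,                    # stem conv adapted from the 3-channel pretrained weights
    strict=True,
    cache_dir='/data/timm_cache',  # per-call override of the checkpoint cache dir
)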
@@ -326,8 +327,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
 
 def resolve_pretrained_cfg(
         variant: str,
-        pretrained_cfg=None,
-        pretrained_cfg_overlay=None,
+        pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
 ) -> PretrainedCfg:
     model_with_tag = variant
     pretrained_tag = None
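
A hedged sketch (not part of this commit) of the newly annotated inputs: pretrained_cfg may be a tag string or a dict, and overlay entries take precedence over the resolved cfg. The tag and file path below are illustrative:

# Hedged sketch: resolving a pretrained cfg from a tag string plus an overlay.
from timm.models._builder import resolve_pretrained_cfg

cfg = resolve_pretrained_cfg(
    'resnet50',
    pretrained_cfg='a1_in1k',                              # tag string selecting one of the variant's cfgs
    pretrained_cfg_overlay=dict(file='/weights/r50.pth'),  # placeholder path; overrides the resolved entry
)
# cfg is a PretrainedCfg dataclass; build_model_with_cfg converts it to a dict downstream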
@@ -382,17 +383,18 @@ def build_model_with_cfg(
       * pruning config / model adaptation
 
     Args:
-        model_cls: model class
-        variant: model variant name
-        pretrained: load pretrained weights
-        pretrained_cfg: model's pretrained weight/task config
-        model_cfg: model's architecture config
-        feature_cfg: feature extraction adapter config
-        pretrained_strict: load pretrained weights strictly
-        pretrained_filter_fn: filter callable for pretrained weights
-        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
-        kwargs_filter: kwargs to filter before passing to model
-        **kwargs: model args passed through to model __init__
+        model_cls: Model class
+        variant: Model variant name
+        pretrained: Load the pretrained weights
+        pretrained_cfg: Model's pretrained weight/task config
+        pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
+        model_cfg: Model's architecture config
+        feature_cfg: Feature extraction adapter config
+        pretrained_strict: Load pretrained weights strictly
+        pretrained_filter_fn: Filter callable for pretrained weights
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
+        kwargs_filter: Kwargs keys to filter (remove) before passing to model
+        **kwargs: Model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)
     features = False
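
A hedged sketch (not part of this commit) of the entrypoint pattern that supplies these arguments; ResNet is used only as a stand-in for a model class, and the cache path is a placeholder:

# Hedged sketch: a model entrypoint helper feeding build_model_with_cfg.
from timm.models._builder import build_model_with_cfg
from timm.models.resnet import ResNet, Bottleneck

def _create_my_resnet(variant: str, pretrained: bool = False, **kwargs):
    return build_model_with_cfg(
        ResNet,                        # model_cls
        variant,                       # e.g. 'resnet50'
        pretrained,                    # triggers pretrained weight loading when True
        cache_dir='/data/timm_cache',  # override HF Hub / torch checkpoint cache location
        **kwargs,                      # remaining kwargs forwarded to ResNet.__init__
    )

model = _create_my_resnet('resnet50', block=Bottleneck, layers=[3, 4, 6, 3], num_classes=10)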
@@ -404,8 +406,6 @@ def build_model_with_cfg(
         pretrained_cfg=pretrained_cfg,
         pretrained_cfg_overlay=pretrained_cfg_overlay
     )
-
-    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
     pretrained_cfg = pretrained_cfg.to_dict()
 
     _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)