@@ -3,7 +3,7 @@
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
@@ -26,11 +26,21 @@
 _CHECK_HASH = False
 _USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0
 
-__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
-           'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
+__all__ = [
+    'set_pretrained_download_progress',
+    'set_pretrained_check_hash',
+    'load_custom_pretrained',
+    'load_pretrained',
+    'pretrained_cfg_for_features',
+    'resolve_pretrained_cfg',
+    'build_model_with_cfg',
+]
 
 
-def _resolve_pretrained_source(pretrained_cfg):
+ModelT = TypeVar("ModelT", bound=nn.Module)  # any subclass of nn.Module
+
+
+def _resolve_pretrained_source(pretrained_cfg: Dict[str, Any]) -> Tuple[str, str]:
     cfg_source = pretrained_cfg.get('source', '')
     pretrained_url = pretrained_cfg.get('url', None)
     pretrained_file = pretrained_cfg.get('file', None)
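The `ModelT` TypeVar introduced here is the backbone of the typing changes in this commit: because it is bound to `nn.Module`, a parameter annotated `Union[Type[ModelT], Callable[..., ModelT]]` lets a static checker carry the concrete model class through to the return value. A minimal sketch of the pattern outside timm (the `Classifier` class and `build` helper are hypothetical, not part of this diff):

```python
from typing import Callable, Type, TypeVar, Union

from torch import nn

ModelT = TypeVar("ModelT", bound=nn.Module)  # any nn.Module subclass


def build(model_cls: Union[Type[ModelT], Callable[..., ModelT]], **kwargs) -> ModelT:
    # Whatever class (or factory returning that class) is passed in,
    # the checker binds ModelT to it and types the result accordingly.
    return model_cls(**kwargs)


class Classifier(nn.Module):  # hypothetical stand-in for a timm model class
    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.fc = nn.Linear(512, num_classes)


model = build(Classifier, num_classes=10)  # inferred as Classifier, not nn.Module
print(model.fc)  # attribute access now type-checks
```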
@@ -78,25 +88,25 @@ def _resolve_pretrained_source(pretrained_cfg):
     return load_from, pretrained_loc
 
 
-def set_pretrained_download_progress(enable=True):
+def set_pretrained_download_progress(enable: bool = True) -> None:
     """ Set download progress for pretrained weights on/off (globally). """
     global _DOWNLOAD_PROGRESS
     _DOWNLOAD_PROGRESS = enable
 
 
-def set_pretrained_check_hash(enable=True):
+def set_pretrained_check_hash(enable: bool = True) -> None:
     """ Set hash checking for pretrained weights on/off (globally). """
     global _CHECK_HASH
     _CHECK_HASH = enable
 
 
 def load_custom_pretrained(
         model: nn.Module,
-        pretrained_cfg: Optional[Dict] = None,
+        pretrained_cfg: Optional[Dict[str, Any]] = None,
         load_fn: Optional[Callable] = None,
         cache_dir: Optional[Union[str, Path]] = None,
-):
-    r"""Loads a custom (read non .pth) weight file
+) -> None:
+    """Loads a custom (read non .pth) weight file
 
     Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
     a passed in custom load fun, or the `load_pretrained` model member fn.
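For orientation, here is how the two global toggles and `load_custom_pretrained` fit together from the caller's side. A hedged sketch: the loader body is illustrative, and it presumes the model's `pretrained_cfg` points at a downloadable non-.pth artifact:

```python
import timm
from timm.models import (
    load_custom_pretrained,
    set_pretrained_check_hash,
    set_pretrained_download_progress,
)

# Module-level flags consulted by the download helpers.
set_pretrained_download_progress(False)  # quiet downloads
set_pretrained_check_hash(True)          # verify digests where the cfg provides them


def my_load_fn(model, checkpoint_path):
    # Custom handler for a non-.pth format (e.g. JAX .npz); body is illustrative.
    model.load_pretrained(checkpoint_path)


model = timm.create_model('vit_base_patch16_224', pretrained=False)
load_custom_pretrained(model, load_fn=my_load_fn)  # falls back to model.pretrained_cfg
```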
@@ -141,13 +151,13 @@ def load_custom_pretrained(
 
 def load_pretrained(
         model: nn.Module,
-        pretrained_cfg: Optional[Dict] = None,
+        pretrained_cfg: Optional[Dict[str, Any]] = None,
         num_classes: int = 1000,
         in_chans: int = 3,
         filter_fn: Optional[Callable] = None,
         strict: bool = True,
         cache_dir: Optional[Union[str, Path]] = None,
-):
+) -> None:
     """ Load pretrained checkpoint
 
     Args:
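The `filter_fn` hook in the signature above is the usual way to adapt a checkpoint before it is loaded. A sketch under the assumption that `filter_fn` receives `(state_dict, model)` and that the resolved cfg points at a real checkpoint:

```python
import timm
from timm.models import load_pretrained


def drop_head(state_dict, model):
    # Remap or drop keys before load_state_dict; here the classifier weights
    # are discarded so a differently-sized head can be trained from scratch.
    return {k: v for k, v in state_dict.items() if not k.startswith('fc.')}


model = timm.create_model('resnet50', pretrained=False, num_classes=10)
load_pretrained(
    model,
    num_classes=10,      # head size differs from the 1000-class checkpoint
    in_chans=3,
    filter_fn=drop_head,
    strict=False,        # tolerate the keys removed above
)
```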
@@ -278,7 +288,7 @@ def load_pretrained(
             f' This may be expected if model is being adapted.')
 
 
-def pretrained_cfg_for_features(pretrained_cfg):
+def pretrained_cfg_for_features(pretrained_cfg: Dict[str, Any]) -> Dict[str, Any]:
     pretrained_cfg = deepcopy(pretrained_cfg)
     # remove default pretrained cfg fields that don't have much relevance for feature backbone
     to_remove = ('num_classes', 'classifier', 'global_pool')  # add default final pool size?
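The effect is easiest to see on a concrete dict; the field values below are illustrative:

```python
from timm.models import pretrained_cfg_for_features

cfg = {
    'url': 'https://example.com/resnet50.pth',  # illustrative
    'num_classes': 1000,
    'classifier': 'fc',
    'global_pool': 'avg',
    'crop_pct': 0.875,
}
feature_cfg = pretrained_cfg_for_features(cfg)
# Classifier-oriented fields are stripped, everything else survives,
# and the input dict is left untouched (deepcopy).
assert 'classifier' not in feature_cfg
assert feature_cfg['crop_pct'] == 0.875 and cfg['classifier'] == 'fc'
```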
@@ -287,14 +297,14 @@ def pretrained_cfg_for_features(pretrained_cfg):
     return pretrained_cfg
 
 
-def _filter_kwargs(kwargs, names):
+def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None:
     if not kwargs or not names:
         return
     for n in names:
         kwargs.pop(n, None)
 
 
-def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
+def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None:
     """ Update the default_cfg and kwargs before passing to model
 
     Args:
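Both helpers mutate their arguments in place rather than returning new objects, which is what the `-> None` annotations record. A tiny illustration of `_filter_kwargs` (a private helper, shown only to make the contract concrete):

```python
kwargs = {'num_classes': 10, 'global_pool': 'avg', 'features_only': True}
_filter_kwargs(kwargs, names=['features_only'])
assert kwargs == {'num_classes': 10, 'global_pool': 'avg'}  # popped in place
```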
@@ -340,6 +350,7 @@ def resolve_pretrained_cfg(
         pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
         pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
 ) -> PretrainedCfg:
+    """Resolve pretrained configuration from various sources."""
     model_with_tag = variant
     pretrained_tag = None
     if pretrained_cfg:
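From the caller's side, the "various sources" are the registered variant, an optional tag string, and field-level overrides; the tag value and overlay field below are illustrative:

```python
from timm.models import resolve_pretrained_cfg

cfg = resolve_pretrained_cfg(
    'resnet50',
    pretrained_cfg='a1_in1k',                  # a tag string selects a registered cfg
    pretrained_cfg_overlay={'crop_pct': 1.0},  # field overrides applied on top
)
# Returns a PretrainedCfg dataclass rather than a raw dict.
print(cfg.crop_pct)
```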
@@ -371,7 +382,7 @@ def resolve_pretrained_cfg(
 
 
 def build_model_with_cfg(
-        model_cls: Callable,
+        model_cls: Union[Type[ModelT], Callable[..., ModelT]],
         variant: str,
         pretrained: bool,
         pretrained_cfg: Optional[Dict] = None,
@@ -383,7 +394,7 @@ def build_model_with_cfg(
         cache_dir: Optional[Union[str, Path]] = None,
         kwargs_filter: Optional[Tuple[str]] = None,
         **kwargs,
-):
+) -> ModelT:
     """ Build model with specified default_cfg and optional model_cfg
 
     This helper fn aids in the construction of a model including:
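With `ModelT` threaded through `model_cls`, entrypoints built on this helper now type-check end to end. A simplified sketch in the shape of timm's resnet entrypoints (argument values mirror `resnet50`; the wrapper is condensed):

```python
from timm.models import build_model_with_cfg
from timm.models.resnet import Bottleneck, ResNet


def _create_resnet(variant: str, pretrained: bool = False, **kwargs) -> ResNet:
    # model_cls is Type[ResNet], so ModelT binds to ResNet and callers get
    # a typed ResNet back instead of an untyped value.
    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)


model = _create_resnet('resnet50', block=Bottleneck, layers=[3, 4, 6, 3])
model.reset_classifier(10)  # ResNet-specific API, visible to the checker
```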