88from collections import defaultdict , deque
99from copy import deepcopy
1010from dataclasses import replace
11- from typing import List , Optional , Union , Tuple
11+ from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Sequence , Union , Tuple
1212
1313from ._pretrained import PretrainedCfg , DefaultCfg , split_model_name_tag
1414
1515__all__ = [
1616 'list_models' , 'list_pretrained' , 'is_model' , 'model_entrypoint' , 'list_modules' , 'is_model_in_modules' ,
1717 'get_pretrained_cfg_value' , 'is_model_pretrained' , 'get_arch_name' ]
1818
19- _module_to_models = defaultdict (set ) # dict of sets to check membership of model in module
20- _model_to_module = {} # mapping of model names to module names
21- _model_entrypoints = {} # mapping of model names to architecture entrypoint fns
22- _model_has_pretrained = set () # set of model names that have pretrained weight url present
23- _model_default_cfgs = dict () # central repo for model arch -> default cfg objects
24- _model_pretrained_cfgs = dict () # central repo for model arch.tag -> pretrained cfgs
25- _model_with_tags = defaultdict (list ) # shortcut to map each model arch to all model + tag names
19+ _module_to_models : Dict [ str , Set [ str ]] = defaultdict (set ) # dict of sets to check membership of model in module
20+ _model_to_module : Dict [ str , str ] = {} # mapping of model names to module names
21+ _model_entrypoints : Dict [ str , Callable [..., Any ]] = {} # mapping of model names to architecture entrypoint fns
22+ _model_has_pretrained : Set [ str ] = set () # set of model names that have pretrained weight url present
23+ _model_default_cfgs : Dict [ str , PretrainedCfg ] = {} # central repo for model arch -> default cfg objects
24+ _model_pretrained_cfgs : Dict [ str , PretrainedCfg ] = {} # central repo for model arch.tag -> pretrained cfgs
25+ _model_with_tags : Dict [ str , List [ str ]] = defaultdict (list ) # shortcut to map each model arch to all model + tag names
2626
2727
28- def get_arch_name (model_name : str ) -> Tuple [ str , Optional [ str ]] :
28+ def get_arch_name (model_name : str ) -> str :
2929 return split_model_name_tag (model_name )[0 ]
3030
3131
32- def register_model (fn ) :
32+ def register_model (fn : Callable [..., Any ]) -> Callable [..., Any ] :
3333 # lookup containing module
3434 mod = sys .modules [fn .__module__ ]
3535 module_name_split = fn .__module__ .split ('.' )
@@ -40,7 +40,7 @@ def register_model(fn):
4040 if hasattr (mod , '__all__' ):
4141 mod .__all__ .append (model_name )
4242 else :
43- mod .__all__ = [model_name ]
43+ mod .__all__ = [model_name ] # type: ignore
4444
4545 # add entries to registry dict/sets
4646 _model_entrypoints [model_name ] = fn
@@ -87,28 +87,33 @@ def register_model(fn):
8787 return fn
8888
8989
90- def _natural_key (string_ ):
90+ def _natural_key (string_ : str ) -> List [Union [int , str ]]:
91+ """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
9192 return [int (s ) if s .isdigit () else s for s in re .split (r'(\d+)' , string_ .lower ())]
9293
9394
9495def list_models (
9596 filter : Union [str , List [str ]] = '' ,
9697 module : str = '' ,
97- pretrained = False ,
98- exclude_filters : str = '' ,
98+ pretrained : bool = False ,
99+ exclude_filters : Union [ str , List [ str ]] = '' ,
99100 name_matches_cfg : bool = False ,
100101 include_tags : Optional [bool ] = None ,
101- ):
102+ ) -> List [ str ] :
102103 """ Return list of available model names, sorted alphabetically
103104
104105 Args:
105- filter (str) - Wildcard filter string that works with fnmatch
106- module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
107- pretrained (bool) - Include only models with valid pretrained weights if True
108- exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
109- name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
110- include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
106+ filter - Wildcard filter string that works with fnmatch
107+ module - Limit model selection to a specific submodule (ie 'vision_transformer')
108+ pretrained - Include only models with valid pretrained weights if True
109+ exclude_filters - Wildcard filters to exclude models after including them with filter
110+ name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
111+ include_tags - Include pretrained tags in model names (model.tag). If None, defaults
111112 set to True when pretrained=True else False (default: None)
113+
114+ Returns:
115+ models - The sorted list of models
116+
112117 Example:
113118 model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
114119 model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
@@ -118,7 +123,7 @@ def list_models(
118123 include_tags = pretrained
119124
120125 if module :
121- all_models = list (_module_to_models [module ])
126+ all_models : Iterable [ str ] = list (_module_to_models [module ])
122127 else :
123128 all_models = _model_entrypoints .keys ()
124129
@@ -130,36 +135,36 @@ def list_models(
130135 all_models = models_with_tags
131136
132137 if filter :
133- models = []
138+ models : Set [ str ] = set ()
134139 include_filters = filter if isinstance (filter , (tuple , list )) else [filter ]
135140 for f in include_filters :
136141 include_models = fnmatch .filter (all_models , f ) # include these models
137142 if len (include_models ):
138- models = set ( models ) .union (include_models )
143+ models = models .union (include_models )
139144 else :
140- models = all_models
145+ models = set ( all_models )
141146
142147 if exclude_filters :
143148 if not isinstance (exclude_filters , (tuple , list )):
144149 exclude_filters = [exclude_filters ]
145150 for xf in exclude_filters :
146151 exclude_models = fnmatch .filter (models , xf ) # exclude these models
147152 if len (exclude_models ):
148- models = set ( models ) .difference (exclude_models )
153+ models = models .difference (exclude_models )
149154
150155 if pretrained :
151156 models = _model_has_pretrained .intersection (models )
152157
153158 if name_matches_cfg :
154159 models = set (_model_pretrained_cfgs ).intersection (models )
155160
156- return list ( sorted (models , key = _natural_key ) )
161+ return sorted (models , key = _natural_key )
157162
158163
159164def list_pretrained (
160165 filter : Union [str , List [str ]] = '' ,
161166 exclude_filters : str = '' ,
162- ):
167+ ) -> List [ str ] :
163168 return list_models (
164169 filter = filter ,
165170 pretrained = True ,
@@ -168,14 +173,14 @@ def list_pretrained(
168173 )
169174
170175
171- def is_model (model_name ) :
176+ def is_model (model_name : str ) -> bool :
172177 """ Check if a model name exists
173178 """
174179 arch_name = get_arch_name (model_name )
175180 return arch_name in _model_entrypoints
176181
177182
178- def model_entrypoint (model_name , module_filter : Optional [str ] = None ):
183+ def model_entrypoint (model_name : str , module_filter : Optional [str ] = None ) -> Callable [..., Any ] :
179184 """Fetch a model entrypoint for specified model name
180185 """
181186 arch_name = get_arch_name (model_name )
@@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
184189 return _model_entrypoints [arch_name ]
185190
186191
187- def list_modules ():
192+ def list_modules () -> List [ str ] :
188193 """ Return list of module names that contain models / model entrypoints
189194 """
190195 modules = _module_to_models .keys ()
191- return list ( sorted (modules ) )
196+ return sorted (modules )
192197
193198
194- def is_model_in_modules (model_name , module_names ):
199+ def is_model_in_modules (
200+ model_name : str , module_names : Union [Tuple [str , ...], List [str ], Set [str ]]
201+ ) -> bool :
195202 """Check if a model exists within a subset of modules
203+
196204 Args:
197- model_name (str) - name of model to check
198- module_names (tuple, list, set) - names of modules to search in
205+ model_name - name of model to check
206+ module_names - names of modules to search in
199207 """
200208 arch_name = get_arch_name (model_name )
201209 assert isinstance (module_names , (tuple , list , set ))
202210 return any (arch_name in _module_to_models [n ] for n in module_names )
203211
204212
205- def is_model_pretrained (model_name ) :
213+ def is_model_pretrained (model_name : str ) -> bool :
206214 return model_name in _model_has_pretrained
207215
208216
209- def get_pretrained_cfg (model_name , allow_unregistered = True ):
217+ def get_pretrained_cfg (model_name : str , allow_unregistered : bool = True ) -> Optional [ PretrainedCfg ] :
210218 if model_name in _model_pretrained_cfgs :
211219 return deepcopy (_model_pretrained_cfgs [model_name ])
212220 arch_name , tag = split_model_name_tag (model_name )
@@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
219227 raise RuntimeError (f'Model architecture ({ arch_name } ) has no pretrained cfg registered.' )
220228
221229
222- def get_pretrained_cfg_value (model_name , cfg_key ) :
230+ def get_pretrained_cfg_value (model_name : str , cfg_key : str ) -> Optional [ Any ] :
223231 """ Get a specific model default_cfg value by key. None if key doesn't exist.
224232 """
225233 cfg = get_pretrained_cfg (model_name , allow_unregistered = False )
0 commit comments