@@ -42,12 +42,14 @@ def _natural_key(string_):
4242 return [int (s ) if s .isdigit () else s for s in re .split (r'(\d+)' , string_ .lower ())]
4343
4444
45- def list_models (filter = '' , module = '' , pretrained = False ):
45+ def list_models (filter = '' , module = '' , pretrained = False , exclude_filters = '' ):
4646 """ Return list of available model names, sorted alphabetically
4747
4848 Args:
4949 filter (str) - Wildcard filter string that works with fnmatch
5050 module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
51+ pretrained (bool) - Include only models with pretrained weights if True
52+ exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
5153
5254 Example:
5355 model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
@@ -58,7 +60,14 @@ def list_models(filter='', module='', pretrained=False):
5860 else :
5961 models = _model_entrypoints .keys ()
6062 if filter :
61- models = fnmatch .filter (models , filter )
63+ models = fnmatch .filter (models , filter ) # include these models
64+ if exclude_filters :
65+ if not isinstance (exclude_filters , list ):
66+ exclude_filters = [exclude_filters ]
67+ for xf in exclude_filters :
68+ exclude_models = fnmatch .filter (models , xf ) # exclude these models
69+ if len (exclude_models ):
70+ models = set (models ).difference (exclude_models )
6271 if pretrained :
6372 models = _model_has_pretrained .intersection (models )
6473 return list (sorted (models , key = _natural_key ))
0 commit comments