Skip to content

Commit 9cc289f

Browse files
committed
Exclude EfficientNet-L2 models from test
1 parent e545bb9 commit 9cc289f

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

tests/test_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
@pytest.mark.timeout(300)
8-
@pytest.mark.parametrize('model_name', list_models())
8+
@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*'))
99
@pytest.mark.parametrize('batch_size', [1])
1010
def test_model_forward(model_name, batch_size):
1111
"""Run a single forward pass with each model"""

timm/models/registry.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@ def _natural_key(string_):
4242
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
4343

4444

45-
def list_models(filter='', module='', pretrained=False):
45+
def list_models(filter='', module='', pretrained=False, exclude_filters=''):
4646
""" Return list of available model names, sorted alphabetically
4747
4848
Args:
4949
filter (str) - Wildcard filter string that works with fnmatch
5050
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
51+
pretrained (bool) - Include only models with pretrained weights if True
52+
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
5153
5254
Example:
5355
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
@@ -58,7 +60,14 @@ def list_models(filter='', module='', pretrained=False):
5860
else:
5961
models = _model_entrypoints.keys()
6062
if filter:
61-
models = fnmatch.filter(models, filter)
63+
models = fnmatch.filter(models, filter) # include these models
64+
if exclude_filters:
65+
if not isinstance(exclude_filters, list):
66+
exclude_filters = [exclude_filters]
67+
for xf in exclude_filters:
68+
exclude_models = fnmatch.filter(models, xf) # exclude these models
69+
if len(exclude_models):
70+
models = set(models).difference(exclude_models)
6271
if pretrained:
6372
models = _model_has_pretrained.intersection(models)
6473
return list(sorted(models, key=_natural_key))

0 commit comments

Comments
 (0)