Skip to content

Commit ccfe201

Browse files
committed
tests parameters refactor
1 parent 246a185 commit ccfe201

File tree

9 files changed

+32
-86
lines changed

9 files changed

+32
-86
lines changed

tests/parameters.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from torch import nn
2+
3+
4+
def value_name(value) -> str: # pragma: no cover
5+
name = getattr(value, "__name__", None)
6+
if name is not None:
7+
return name
8+
if isinstance(value, nn.Module):
9+
return value._get_name() # pylint: disable=W0212
10+
return value
11+
12+
13+
def ids_fn(key, value):
14+
return [f"{key[:2]}_{value_name(v)}" for v in value]

tests/test_Net.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import torch
2-
import torch.nn as nn
32

43
from model_constructor.net import Net, NewResBlock, ResBlock
54

6-
# from model_constructor.layers import SEModule, SimpleSelfAttention
5+
from .parameters import ids_fn
76

7+
# from model_constructor.layers import SEModule, SimpleSelfAttention
88

99
bs_test = 4
1010

@@ -24,19 +24,6 @@
2424
)
2525

2626

27-
def value_name(value) -> str: # pragma: no cover
28-
name = getattr(value, "__name__", None)
29-
if name is not None:
30-
return name
31-
if isinstance(value, nn.Module):
32-
return value._get_name() # pylint: disable=W0212
33-
return value
34-
35-
36-
def ids_fn(key, value):
37-
return [f"{key[:2]}_{value_name(v)}" for v in value]
38-
39-
4027
def pytest_generate_tests(metafunc):
4128
for key, value in params.items():
4229
if key in metafunc.fixturenames:

tests/test_block.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# import pytest
22
from functools import partial
3+
34
import torch
4-
import torch.nn as nn
5-
from model_constructor.layers import SEModule, SimpleSelfAttention
5+
from torch import nn
66

7+
from model_constructor.layers import SEModule, SimpleSelfAttention
78
from model_constructor.model_constructor import ResBlock
89
from model_constructor.yaresnet import YaResBlock
910

11+
from .parameters import ids_fn
12+
1013
bs_test = 4
1114
img_size = 16
1215

@@ -23,19 +26,6 @@
2326
)
2427

2528

26-
def value_name(value) -> str:
27-
name = getattr(value, "__name__", None)
28-
if name is not None:
29-
return name
30-
if isinstance(value, nn.Module):
31-
return value._get_name() # pylint: disable=W0212
32-
return value
33-
34-
35-
def ids_fn(key, value):
36-
return [f"{key[:2]}_{value_name(v)}" for v in value]
37-
38-
3929
def pytest_generate_tests(metafunc):
4030
for key, value in params.items():
4131
if key in metafunc.fixturenames:

tests/test_convmixer.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import torch
2-
import torch.nn as nn
32

43
from model_constructor.convmixer import ConvMixer, ConvMixerOriginal
54

5+
from .parameters import ids_fn
6+
67
bs_test = 4
78
img_size = 16
89

@@ -13,19 +14,6 @@
1314
)
1415

1516

16-
def value_name(value) -> str: # pragma: no cover
17-
name = getattr(value, "__name__", None)
18-
if name is not None:
19-
return name
20-
if isinstance(value, nn.Module):
21-
return value._get_name() # pylint: disable=W0212
22-
return value
23-
24-
25-
def ids_fn(key, value):
26-
return [f"{key[:2]}_{value_name(v)}" for v in value]
27-
28-
2917
def pytest_generate_tests(metafunc):
3018
for key, value in params.items():
3119
if key in metafunc.fixturenames:

tests/test_layers.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
import torch
2-
import torch.nn as nn
3-
4-
from model_constructor.layers import (
5-
ConvBnAct,
6-
Flatten,
7-
Noop,
8-
SEModule,
9-
SEModuleConv,
10-
SimpleSelfAttention,
11-
noop,
12-
)
132

3+
from model_constructor.layers import (ConvBnAct, Flatten, Noop, SEModule,
4+
SEModuleConv, SimpleSelfAttention, noop)
5+
6+
from .parameters import ids_fn
147

158
bs_test = 4
169

@@ -37,19 +30,6 @@
3730
)
3831

3932

40-
def value_name(value) -> str:
41-
name = getattr(value, "__name__", None)
42-
if name is not None:
43-
return name
44-
if isinstance(value, nn.Module): # pragma: no cover
45-
return value._get_name() # pylint: disable=W0212
46-
return value
47-
48-
49-
def ids_fn(key, value):
50-
return [f"{key[:2]}_{value_name(v)}" for v in value]
51-
52-
5333
def pytest_generate_tests(metafunc):
5434
for key, value in params.items():
5535
if key in metafunc.fixturenames:

tests/test_layers_depr.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# old (deprecated layers)
22
import torch
3-
import torch.nn as nn
43

54
from model_constructor.layers import ConvLayer, SEBlock, SEBlockConv
65

6+
from .parameters import ids_fn
77

88
bs_test = 4
99

@@ -28,19 +28,6 @@
2828
)
2929

3030

31-
def value_name(value) -> str: # pragma: no cover
32-
name = getattr(value, "__name__", None)
33-
if name is not None:
34-
return name
35-
if isinstance(value, nn.Module):
36-
return value._get_name() # pylint: disable=W0212
37-
return value
38-
39-
40-
def ids_fn(key, value):
41-
return [f"{key[:2]}_{value_name(v)}" for v in value]
42-
43-
4431
def pytest_generate_tests(metafunc):
4532
for key, value in params.items():
4633
if key in metafunc.fixturenames:

tests/test_mc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
22

33
from model_constructor import ModelConstructor
4-
from model_constructor.layers import SEModule, SEModuleConv, SimpleSelfAttention
5-
4+
from model_constructor.layers import (SEModule, SEModuleConv,
5+
SimpleSelfAttention)
66

77
bs_test = 4
88

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import torch
3-
import torch.nn as nn
3+
from torch import nn
44

55
from model_constructor.model_constructor import XResNet34, XResNet50
66
from model_constructor.yaresnet import YaResNet34, YaResNet50

tests/test_models_old.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import torch
3-
import torch.nn as nn
3+
from torch import nn
44

55
from model_constructor.mxresnet import mxresnet34, mxresnet50
66
from model_constructor.xresnet import xresnet18, xresnet34, xresnet50

0 commit comments

Comments
 (0)