Skip to content

Commit c6d4a42

Browse files
committed
mxresnet, tests
1 parent e3d5b6b commit c6d4a42

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

src/model_constructor/mxresnet.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from torch import nn
2+
3+
from .xresnet import XResNet, XResNet34, XResNet50
4+
5+
6+
class MxResNet(XResNet):
7+
stem_sizes: list[int] = [3, 32, 64, 64]
8+
act_fn: type[nn.Module] = nn.Mish
9+
10+
class MxResNet34(XResNet34):
11+
stem_sizes: list[int] = [3, 32, 64, 64]
12+
act_fn: type[nn.Module] = nn.Mish
13+
14+
15+
class MxResNet50(XResNet50):
16+
stem_sizes: list[int] = [3, 32, 64, 64]
17+
act_fn: type[nn.Module] = nn.Mish

src/model_constructor/yaresnet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import torch
88
from torch import nn
9-
from torch.nn import Mish
109

1110
from model_constructor.helpers import nn_seq
1211

@@ -207,7 +206,7 @@ class YaResNet(ModelConstructor):
207206
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = xresnet_stem
208207
stem_sizes: list[int] = [3, 32, 64, 64]
209208
block: type[nn.Module] = YaBasicBlock
210-
act_fn: type[nn.Module] = Mish
209+
act_fn: type[nn.Module] = nn.Mish
211210
pool: Optional[Callable[[Any], nn.Module]] = partial(
212211
nn.AvgPool2d, kernel_size=2, ceil_mode=True
213212
)

tests/test_models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from torch import nn
44

55
from model_constructor.model_constructor import ModelConstructor, ResNet34, ResNet50
6-
from model_constructor.yaresnet import YaResNet, YaResNet34, YaResNet50
6+
from model_constructor.mxresnet import MxResNet, MxResNet34, MxResNet50
77
from model_constructor.xresnet import XResNet, XResNet34, XResNet50
8+
from model_constructor.yaresnet import YaResNet, YaResNet34, YaResNet50
89

910
bs_test = 2
1011
img_size = 16
@@ -20,6 +21,9 @@
2021
YaResNet,
2122
YaResNet34,
2223
YaResNet50,
24+
MxResNet,
25+
MxResNet34,
26+
MxResNet50,
2327
]
2428
act_fn_list = [
2529
nn.ReLU,

0 commit comments

Comments
 (0)