|
2 | 2 |
|
3 | 3 | import torch |
4 | 4 |
|
5 | | -from model_constructor.blocks import BottleneckBlock |
| 5 | +from model_constructor.blocks import BasicBlock, BottleneckBlock |
6 | 6 | from model_constructor.layers import SEModule, SEModuleConv, SimpleSelfAttention |
7 | | -from model_constructor.model_constructor import ModelConstructor |
| 7 | +from model_constructor.model_constructor import ModelCfg, ModelConstructor |
8 | 8 |
|
9 | 9 | bs_test = 4 |
10 | 10 | in_chans = 3 |
@@ -94,3 +94,46 @@ def test_MC_bottleneck(): |
94 | 94 | assert model.body.l_0.bl_0.convs.conv_0.conv.in_channels == 64 |
95 | 95 | assert model.body.l_0.bl_0.convs.conv_0.conv.out_channels == 128 |
96 | 96 | assert model.body.l_0.bl_1.convs.conv_0.conv.in_channels == 256 |
| 97 | + |
| 98 | + |
| 99 | +def test_ModelCfg(): |
| 100 | + """test ModelCfg""" |
| 101 | + # default - just create config with custom name |
| 102 | + cfg = ModelCfg(name="custom_name") |
| 103 | + repr_str = cfg.__repr__() |
| 104 | + assert repr_str.startswith("custom_name") |
| 105 | + # initiate from string |
| 106 | + cfg = ModelCfg(act_fn="torch.nn.Mish") |
| 107 | + assert cfg.act_fn is torch.nn.Mish |
| 108 | + # wrong name |
| 109 | + try: |
| 110 | + cfg = ModelCfg(act_fn="wrong_name") |
| 111 | + except ImportError as err: |
| 112 | + assert str(err) == "Module wrong_name not found at torch.nn" |
| 113 | + cfg = ModelCfg(act_fn="nn.Tanh") |
| 114 | + assert cfg.act_fn is torch.nn.Tanh |
| 115 | + cfg = ModelCfg(block="model_constructor.blocks.BottleneckBlock") |
| 116 | + assert cfg.block is BottleneckBlock |
| 117 | + |
| 118 | + |
| 119 | +def test_create_model_class_methods(): |
| 120 | + """test class methods ModelConstructor""" |
| 121 | + # create model |
| 122 | + model = ModelConstructor.create_model(act_fn="Mish", num_classes=10) |
| 123 | + assert str(model.body.l_0.bl_0.convs.conv_0.act_fn) == "Mish(inplace=True)" |
| 124 | + pred = model(xb) |
| 125 | + assert pred.shape == torch.Size([bs_test, 10]) |
| 126 | + # from cfg |
| 127 | + cfg = ModelCfg(block=BottleneckBlock, num_classes=10) |
| 128 | + mc = ModelConstructor.from_cfg(cfg) |
| 129 | + model = mc() |
| 130 | + assert isinstance(model.body.l_0.bl_0, BottleneckBlock) |
| 131 | + pred = model(xb) |
| 132 | + assert pred.shape == torch.Size([bs_test, 10]) |
| 133 | + |
| 134 | + cfg.block = BasicBlock |
| 135 | + cfg.num_classes = 2 |
| 136 | + model = ModelConstructor.create_model(cfg) |
| 137 | + assert isinstance(model.body.l_0.bl_0, BasicBlock) |
| 138 | + pred = model(xb) |
| 139 | + assert pred.shape == torch.Size([bs_test, 2]) |
0 commit comments