import torch
import torch.nn as nn

from model_constructor.layers import ConvBnAct, Flatten, Noop, SEModule, SEModuleConv, SimpleSelfAttention, noop


bs_test = 4  # batch size for all dummy input tensors


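# Grid of constructor arguments exercised by the tests: any test argument whose
# name matches a key below is parametrized over the listed values
# (see `pytest_generate_tests`).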
params = dict(
    kernel_size=[3, 1],
    stride=[1, 2],
    padding=[None, 1],
    bias=[False, True],
    groups=[1, 2],
    # act_fn=act_fn,
    pre_act=[False, True],
    bn_layer=[True, False],
    bn_1st=[True, False],
    zero_bn=[False, True],
    # SA
    sym=[False, True],
    # SE
    se_module=[SEModule, SEModuleConv],
    reduction=[16, 2],
    rd_channels=[None, 2],
    rd_max=[False, True],
    use_bias=[True, False],
)


def value_name(value) -> str:
    """Return a readable name for a parametrized value (class, module or plain value)."""
    name = getattr(value, "__name__", None)
    if name is not None:
        return name
    if isinstance(value, nn.Module):
        return value._get_name()
    return str(value)


def ids_fn(key, value):
    """Build pytest ids: first two letters of the key plus the value name."""
    return [f"{key[:2]}_{value_name(v)}" for v in value]


def pytest_generate_tests(metafunc):
    """Parametrize every test argument that matches a key in `params`."""
    for key, value in params.items():
        if key in metafunc.fixturenames:
            metafunc.parametrize(key, value, ids=ids_fn(key, value))

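# With the grid above, the generated ids look like "ke_3" / "ke_1" for
# `kernel_size` and "se_SEModule" / "se_SEModuleConv" for `se_module`.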

def test_Flatten():
    """test Flatten"""
    flatten = Flatten()
    channels = 4
    xb = torch.randn(bs_test, channels, channels)
    out = flatten(xb)
    assert out.shape == torch.Size([bs_test, channels * channels])


def test_noop():
    """test Noop, noop"""
    xb = torch.randn(bs_test)
    xb_copy = xb.clone().detach()
    out = noop(xb)
    assert out is xb
    assert all(out.eq(xb_copy))
    noop_module = Noop()
    out = noop_module(xb)
    assert out is xb
    assert all(out.eq(xb_copy))


def test_ConvBnAct(kernel_size, stride, bias, groups, pre_act, bn_layer, bn_1st, zero_bn):
    """test ConvBnAct"""
    in_channels = out_channels = 4
    channel_size = 4
    block = ConvBnAct(
        in_channels, out_channels, kernel_size, stride,
        padding=None, bias=bias, groups=groups,
        pre_act=pre_act, bn_layer=bn_layer, bn_1st=bn_1st, zero_bn=zero_bn)
    xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
    out = block(xb)
    out_size = channel_size
    if stride == 2:
        out_size = channel_size // stride
    assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])


def test_SimpleSelfAttention(sym):
    """test SimpleSelfAttention"""
    in_channels = 4
    kernel_size = 1  # ? can be 3? if so check sym hack.
    channel_size = 4
    sa = SimpleSelfAttention(in_channels, kernel_size, sym)
    xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
    out = sa(xb)
    assert out.shape == torch.Size([bs_test, in_channels, channel_size, channel_size])


def test_SE(se_module, reduction, rd_channels, rd_max, use_bias):
    """test SE"""
    in_channels = 8
    channel_size = 4
    se = se_module(in_channels, reduction, rd_channels, rd_max, use_bias=use_bias)
    xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
    out = se(xb)
    assert out.shape == torch.Size([bs_test, in_channels, channel_size, channel_size])
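

# Illustrative smoke check: compose the layers the way a model block might use
# them. Only argument values already exercised above are assumed; other
# defaults or keyword names may differ in the library.
def test_layers_compose():
    """ConvBnAct -> SEModule -> Flatten keeps the expected shapes when stacked."""
    in_channels, out_channels, channel_size = 4, 8, 4
    head = nn.Sequential(
        ConvBnAct(
            in_channels, out_channels, 3, 1,
            padding=None, bias=False, groups=1,
            pre_act=False, bn_layer=True, bn_1st=True, zero_bn=False),
        SEModule(out_channels, 16, None, False, use_bias=True),
        Flatten(),
    )
    xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
    out = head(xb)
    # stride 1 keeps the spatial size, SEModule keeps the shape, Flatten merges the rest
    assert out.shape == torch.Size([bs_test, out_channels * channel_size * channel_size])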