Skip to content

Commit d3a077c

Browse files
committed
black
1 parent c3a2086 commit d3a077c

File tree

1 file changed

+87
-48
lines changed

1 file changed

+87
-48
lines changed

src/model_constructor/convmixer.py

Lines changed: 87 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,28 @@ def forward(self, x):
1717

1818

1919
# As original version, act_fn as argument.
20-
def ConvMixerOriginal(dim, depth,
21-
kernel_size=9, patch_size=7, n_classes=1000,
22-
act_fn=nn.GELU()):
20+
def ConvMixerOriginal(
21+
dim, depth, kernel_size=9, patch_size=7, n_classes=1000, act_fn=nn.GELU()
22+
):
2323
return nn.Sequential(
2424
nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
2525
act_fn,
2626
nn.BatchNorm2d(dim),
27-
*[nn.Sequential(
28-
Residual(nn.Sequential(
29-
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
27+
*[
28+
nn.Sequential(
29+
Residual(
30+
nn.Sequential(
31+
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
32+
act_fn,
33+
nn.BatchNorm2d(dim),
34+
)
35+
),
36+
nn.Conv2d(dim, dim, kernel_size=1),
3037
act_fn,
31-
nn.BatchNorm2d(dim)
32-
)),
33-
nn.Conv2d(dim, dim, kernel_size=1),
34-
act_fn,
35-
nn.BatchNorm2d(dim)
36-
) for _i in range(depth)],
38+
nn.BatchNorm2d(dim),
39+
)
40+
for _i in range(depth)
41+
],
3742
nn.AdaptiveAvgPool2d((1, 1)),
3843
nn.Flatten(),
3944
nn.Linear(dim, n_classes)
@@ -44,24 +49,32 @@ class ConvLayer(nn.Sequential):
4449
"""Basic conv layers block"""
4550

4651
def __init__(
47-
self,
48-
in_channels: int,
49-
out_channels: int,
50-
kernel_size: Union[int, tuple[int, int]],
51-
stride: int = 1,
52-
act_fn: nn.Module = nn.GELU(),
53-
padding: Union[int, str] = 0,
54-
groups: int = 1,
55-
bn_1st: bool = False,
56-
pre_act: bool = False,
52+
self,
53+
in_channels: int,
54+
out_channels: int,
55+
kernel_size: Union[int, tuple[int, int]],
56+
stride: int = 1,
57+
act_fn: nn.Module = nn.GELU(),
58+
padding: Union[int, str] = 0,
59+
groups: int = 1,
60+
bn_1st: bool = False,
61+
pre_act: bool = False,
5762
):
5863

59-
conv_layer = [('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
60-
padding=padding, groups=groups))]
61-
act_bn = [
62-
('act_fn', act_fn),
63-
('bn', nn.BatchNorm2d(out_channels))
64+
conv_layer = [
65+
(
66+
"conv",
67+
nn.Conv2d(
68+
in_channels,
69+
out_channels,
70+
kernel_size,
71+
stride=stride,
72+
padding=padding,
73+
groups=groups,
74+
),
75+
)
6476
]
77+
act_bn = [("act_fn", act_fn), ("bn", nn.BatchNorm2d(out_channels))]
6578
if bn_1st:
6679
act_bn.reverse()
6780
if pre_act:
@@ -73,20 +86,19 @@ def __init__(
7386

7487

7588
class ConvMixer(nn.Sequential):
76-
7789
def __init__(
78-
self,
79-
dim: int,
80-
depth: int,
81-
kernel_size: int = 9,
82-
patch_size: int = 7,
83-
n_classes: int = 1000,
84-
act_fn: nn.Module = nn.GELU(),
85-
stem: Optional[nn.Module] = None,
86-
in_chans: int = 3,
87-
bn_1st: bool = False,
88-
pre_act: bool = False,
89-
init_func: Optional[Callable[[nn.Module], None]] = None
90+
self,
91+
dim: int,
92+
depth: int,
93+
kernel_size: int = 9,
94+
patch_size: int = 7,
95+
n_classes: int = 1000,
96+
act_fn: nn.Module = nn.GELU(),
97+
stem: Optional[nn.Module] = None,
98+
in_chans: int = 3,
99+
bn_1st: bool = False,
100+
pre_act: bool = False,
101+
init_func: Optional[Callable[[nn.Module], None]] = None,
90102
):
91103
"""ConvMixer constructor.
92104
Adopted from https://github.com/tmp-iclr/convmixer
@@ -108,18 +120,45 @@ def __init__(
108120
if pre_act:
109121
bn_1st = False
110122
if stem is None:
111-
stem = ConvLayer(in_chans, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
123+
stem = ConvLayer(
124+
in_chans,
125+
dim,
126+
kernel_size=patch_size,
127+
stride=patch_size,
128+
act_fn=act_fn,
129+
bn_1st=bn_1st,
130+
)
112131

113132
super().__init__(
114133
stem,
115-
*[nn.Sequential(
116-
Residual(
117-
ConvLayer(dim, dim, kernel_size, act_fn=act_fn,
118-
groups=dim, padding="same", bn_1st=bn_1st, pre_act=pre_act)),
119-
ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
120-
for _ in range(depth)],
134+
*[
135+
nn.Sequential(
136+
Residual(
137+
ConvLayer(
138+
dim,
139+
dim,
140+
kernel_size,
141+
act_fn=act_fn,
142+
groups=dim,
143+
padding="same",
144+
bn_1st=bn_1st,
145+
pre_act=pre_act,
146+
)
147+
),
148+
ConvLayer(
149+
dim,
150+
dim,
151+
kernel_size=1,
152+
act_fn=act_fn,
153+
bn_1st=bn_1st,
154+
pre_act=pre_act,
155+
),
156+
)
157+
for _ in range(depth)
158+
],
121159
nn.AdaptiveAvgPool2d((1, 1)),
122160
nn.Flatten(),
123-
nn.Linear(dim, n_classes))
161+
nn.Linear(dim, n_classes)
162+
)
124163
if init_func is not None: # pragma: no cover
125164
init_func(self)

0 commit comments

Comments
 (0)