@@ -17,23 +17,28 @@ def forward(self, x):
 
 
 # As original version, act_fn as argument.
-def ConvMixerOriginal(dim, depth,
-                      kernel_size=9, patch_size=7, n_classes=1000,
-                      act_fn=nn.GELU()):
+def ConvMixerOriginal(
+    dim, depth, kernel_size=9, patch_size=7, n_classes=1000, act_fn=nn.GELU()
+):
     return nn.Sequential(
         nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
         act_fn,
         nn.BatchNorm2d(dim),
-        *[nn.Sequential(
-            Residual(nn.Sequential(
-                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+        *[
+            nn.Sequential(
+                Residual(
+                    nn.Sequential(
+                        nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+                        act_fn,
+                        nn.BatchNorm2d(dim),
+                    )
+                ),
+                nn.Conv2d(dim, dim, kernel_size=1),
                 act_fn,
-                nn.BatchNorm2d(dim)
-            )),
-            nn.Conv2d(dim, dim, kernel_size=1),
-            act_fn,
-            nn.BatchNorm2d(dim)
-        ) for _i in range(depth)],
+                nn.BatchNorm2d(dim),
+            )
+            for _i in range(depth)
+        ],
         nn.AdaptiveAvgPool2d((1, 1)),
         nn.Flatten(),
         nn.Linear(dim, n_classes)
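The hunk above is purely cosmetic: the builder is re-wrapped in what looks like Black style (exploded arguments, trailing commas), so the resulting network is unchanged. As a quick sanity check, a minimal smoke test might look like this (illustrative only; the small dim/depth values are arbitrary and not from the commit):

    import torch
    import torch.nn as nn  # assumed already imported at the top of this file

    model = ConvMixerOriginal(dim=64, depth=2, n_classes=10)
    x = torch.randn(2, 3, 224, 224)    # two RGB images
    assert model(x).shape == (2, 10)   # 7x7 patch stem, pooled to 1x1, then Linear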
@@ -44,24 +49,32 @@ class ConvLayer(nn.Sequential):
4449 """Basic conv layers block"""
4550
4651 def __init__ (
47- self ,
48- in_channels : int ,
49- out_channels : int ,
50- kernel_size : Union [int , tuple [int , int ]],
51- stride : int = 1 ,
52- act_fn : nn .Module = nn .GELU (),
53- padding : Union [int , str ] = 0 ,
54- groups : int = 1 ,
55- bn_1st : bool = False ,
56- pre_act : bool = False ,
52+ self ,
53+ in_channels : int ,
54+ out_channels : int ,
55+ kernel_size : Union [int , tuple [int , int ]],
56+ stride : int = 1 ,
57+ act_fn : nn .Module = nn .GELU (),
58+ padding : Union [int , str ] = 0 ,
59+ groups : int = 1 ,
60+ bn_1st : bool = False ,
61+ pre_act : bool = False ,
5762 ):
5863
59- conv_layer = [('conv' , nn .Conv2d (in_channels , out_channels , kernel_size , stride = stride ,
60- padding = padding , groups = groups ))]
61- act_bn = [
62- ('act_fn' , act_fn ),
63- ('bn' , nn .BatchNorm2d (out_channels ))
64+ conv_layer = [
65+ (
66+ "conv" ,
67+ nn .Conv2d (
68+ in_channels ,
69+ out_channels ,
70+ kernel_size ,
71+ stride = stride ,
72+ padding = padding ,
73+ groups = groups ,
74+ ),
75+ )
6476 ]
77+ act_bn = [("act_fn" , act_fn ), ("bn" , nn .BatchNorm2d (out_channels ))]
6578 if bn_1st :
6679 act_bn .reverse ()
6780 if pre_act :
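The (name, module) tuples make the ordering flags easy to reason about: act_bn defaults to activation-then-norm, and bn_1st simply reverses that pair before it is appended after the conv entry. A hedged illustration of that reversal, using the same defaults (not part of the commit):

    from torch import nn

    act_bn = [("act_fn", nn.GELU()), ("bn", nn.BatchNorm2d(64))]
    act_bn.reverse()                     # what bn_1st=True triggers
    print([name for name, _ in act_bn])  # -> ['bn', 'act_fn']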
@@ -73,20 +86,19 @@ def __init__(
 
 
 class ConvMixer(nn.Sequential):
-
     def __init__(
-            self,
-            dim: int,
-            depth: int,
-            kernel_size: int = 9,
-            patch_size: int = 7,
-            n_classes: int = 1000,
-            act_fn: nn.Module = nn.GELU(),
-            stem: Optional[nn.Module] = None,
-            in_chans: int = 3,
-            bn_1st: bool = False,
-            pre_act: bool = False,
-            init_func: Optional[Callable[[nn.Module], None]] = None
+        self,
+        dim: int,
+        depth: int,
+        kernel_size: int = 9,
+        patch_size: int = 7,
+        n_classes: int = 1000,
+        act_fn: nn.Module = nn.GELU(),
+        stem: Optional[nn.Module] = None,
+        in_chans: int = 3,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+        init_func: Optional[Callable[[nn.Module], None]] = None,
     ):
         """ConvMixer constructor.
         Adopted from https://github.com/tmp-iclr/convmixer
@@ -108,18 +120,45 @@ def __init__(
         if pre_act:
             bn_1st = False
         if stem is None:
-            stem = ConvLayer(in_chans, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
+            stem = ConvLayer(
+                in_chans,
+                dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                act_fn=act_fn,
+                bn_1st=bn_1st,
+            )
 
         super().__init__(
             stem,
-            *[nn.Sequential(
-                Residual(
-                    ConvLayer(dim, dim, kernel_size, act_fn=act_fn,
-                              groups=dim, padding="same", bn_1st=bn_1st, pre_act=pre_act)),
-                ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
-                for _ in range(depth)],
+            *[
+                nn.Sequential(
+                    Residual(
+                        ConvLayer(
+                            dim,
+                            dim,
+                            kernel_size,
+                            act_fn=act_fn,
+                            groups=dim,
+                            padding="same",
+                            bn_1st=bn_1st,
+                            pre_act=pre_act,
+                        )
+                    ),
+                    ConvLayer(
+                        dim,
+                        dim,
+                        kernel_size=1,
+                        act_fn=act_fn,
+                        bn_1st=bn_1st,
+                        pre_act=pre_act,
+                    ),
+                )
+                for _ in range(depth)
+            ],
             nn.AdaptiveAvgPool2d((1, 1)),
             nn.Flatten(),
-            nn.Linear(dim, n_classes))
+            nn.Linear(dim, n_classes)
+        )
         if init_func is not None:  # pragma: no cover
             init_func(self)
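Since init_func receives the fully assembled nn.Sequential, custom weight initialisation can be plugged in without subclassing. A minimal sketch, assuming ConvMixer is importable from this module and using a hypothetical zero_bn_bias helper:

    import torch
    from torch import nn

    def zero_bn_bias(module: nn.Module) -> None:
        # Illustrative init_func: zero every BatchNorm bias in the model.
        for m in module.modules():
            if isinstance(m, nn.BatchNorm2d):
                nn.init.zeros_(m.bias)

    model = ConvMixer(dim=64, depth=2, pre_act=True, init_func=zero_bn_bias)
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])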