# Adopted from https://github.com/tmp-iclr/convmixer
# Home for convmixer: https://github.com/locuslab/convmixer
from collections import OrderedDict
-from typing import Callable
+from typing import Callable, List, Optional, Union
+
import torch.nn as nn
+from torch import Tensor


class Residual(nn.Module):
-    def __init__(self, fn):
+    def __init__(self, fn: Callable[[Tensor], Tensor]):
        super().__init__()
        self.fn = fn

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
        return self.fn(x) + x

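`Residual` simply adds the wrapped module's output to its input, so the wrapped callable must preserve shape. A minimal sketch (the conv and tensor sizes are illustrative, not from this diff):

```python
import torch
import torch.nn as nn

block = Residual(nn.Conv2d(8, 8, kernel_size=3, padding="same"))
x = torch.randn(2, 8, 16, 16)
y = block(x)  # computes fn(x) + x
assert y.shape == x.shape  # shape must be preserved for the sum to work
```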
# Same as the original version, but with act_fn as an argument.
-def ConvMixerOriginal(dim, depth,
-                      kernel_size=9, patch_size=7, n_classes=1000,
-                      act_fn=nn.GELU()):
+def ConvMixerOriginal(
+    dim: int,
+    depth: int,
+    kernel_size: int = 9,
+    patch_size: int = 7,
+    n_classes: int = 1000,
+    act_fn: nn.Module = nn.GELU(),
+):
    return nn.Sequential(
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        act_fn,
        nn.BatchNorm2d(dim),
-        *[nn.Sequential(
-            Residual(nn.Sequential(
-                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+        *[
+            nn.Sequential(
+                Residual(
+                    nn.Sequential(
+                        nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+                        act_fn,
+                        nn.BatchNorm2d(dim),
+                    )
+                ),
+                nn.Conv2d(dim, dim, kernel_size=1),
                act_fn,
-                nn.BatchNorm2d(dim)
-            )),
-            nn.Conv2d(dim, dim, kernel_size=1),
-            act_fn,
-            nn.BatchNorm2d(dim)
-        ) for i in range(depth)],
+                nn.BatchNorm2d(dim),
+            )
+            for _i in range(depth)
+        ],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
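A quick smoke test of the factory above; dim=256 and depth=8 are illustrative values (a configuration commonly used for small ConvMixer variants), not prescribed by this diff:

```python
import torch

model = ConvMixerOriginal(dim=256, depth=8)
logits = model(torch.randn(1, 3, 224, 224))  # stem: 7x7 patches, stride 7
assert logits.shape == (1, 1000)  # AdaptiveAvgPool2d -> Flatten -> Linear head
```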
@@ -43,15 +55,35 @@ def ConvMixerOriginal(dim, depth,
class ConvLayer(nn.Sequential):
    """Basic conv layers block"""

-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 act_fn=nn.GELU(), padding=0, groups=1,
-                 bn_1st=False, pre_act=False):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, tuple[int, int]],
+        stride: int = 1,
+        act_fn: nn.Module = nn.GELU(),
+        padding: Union[int, str] = 0,
+        groups: int = 1,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+    ):

-        conv_layer = [('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
-                                         padding=padding, groups=groups))]
-        act_bn = [
-            ('act_fn', act_fn),
-            ('bn', nn.BatchNorm2d(out_channels))
+        conv_layer: List[tuple[str, nn.Module]] = [
+            (
+                "conv",
+                nn.Conv2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    groups=groups,
+                ),
+            )
+        ]
+        act_bn: List[tuple[str, nn.Module]] = [
+            ("act_fn", act_fn),
+            ("bn", nn.BatchNorm2d(out_channels)),
        ]
        if bn_1st:
            act_bn.reverse()
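The assembly of `conv_layer` and `act_bn` into the `nn.Sequential` base falls outside this hunk, but assuming the named layers end up in an `OrderedDict` in list order (consistent with the `OrderedDict` import at the top), `bn_1st` controls whether normalization precedes the activation. A hypothetical check:

```python
layer = ConvLayer(3, 64, kernel_size=3, padding=1, bn_1st=True)
# bn_1st=True reverses act_bn, so the expected child order is conv -> bn -> act_fn
print([name for name, _ in layer.named_children()])  # e.g. ['conv', 'bn', 'act_fn']
```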
@@ -64,45 +96,79 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,


class ConvMixer(nn.Sequential):
-
-    def __init__(self, dim: int, depth: int,
-                 kernel_size: int = 9, patch_size: int = 7, n_classes: int = 1000,
-                 act_fn: nn.Module = nn.GELU(),
-                 stem: nn.Module = None,
-                 bn_1st: bool = False, pre_act: bool = False,
-                 init_func: Callable = None):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        kernel_size: int = 9,
+        patch_size: int = 7,
+        n_classes: int = 1000,
+        act_fn: nn.Module = nn.GELU(),
+        stem: Optional[nn.Module] = None,
+        in_chans: int = 3,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+        init_func: Optional[Callable[[nn.Module], None]] = None,
+    ):
        """ConvMixer constructor.
        Adopted from https://github.com/tmp-iclr/convmixer

        Args:
-            dim (int): Dimention of model.
+            dim (int): Dimension of model.
            depth (int): Depth of model.
            kernel_size (int, optional): Kernel size. Defaults to 9.
            patch_size (int, optional): Patch size. Defaults to 7.
            n_classes (int, optional): Number of classes. Defaults to 1000.
            act_fn (nn.Module, optional): Activation function. Defaults to nn.GELU().
            stem (nn.Module, optional): You can pass a different first layer. Defaults to None.
-            stem_ks (int, optional): If stem_ch not 0 - kernel size for adittional layer. Defaults to 1.
-            bn_1st (bool, optional): If True - BatchNorm befor activation function. Defaults to False.
-            pre_act (bool, optional): If True - activatin function befor convolution layer. Defaults to False.
+            in_chans (int, optional): Number of input channels. Defaults to 3.
+            bn_1st (bool, optional): If True - BatchNorm before activation function. Defaults to False.
+            pre_act (bool, optional): If True - activation function before convolution layer. Defaults to False.
            init_func (Callable, optional): External function to initialize the model.

        """
        if pre_act:
            bn_1st = False
        if stem is None:
-            stem = ConvLayer(3, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
+            stem = ConvLayer(
+                in_chans,
+                dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                act_fn=act_fn,
+                bn_1st=bn_1st,
+            )

        super().__init__(
            stem,
-            *[nn.Sequential(
-                Residual(
-                    ConvLayer(dim, dim, kernel_size, act_fn=act_fn,
-                              groups=dim, padding="same", bn_1st=bn_1st, pre_act=pre_act)),
-                ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
-            for i in range(depth)],
+            *[
+                nn.Sequential(
+                    Residual(
+                        ConvLayer(
+                            dim,
+                            dim,
+                            kernel_size,
+                            act_fn=act_fn,
+                            groups=dim,
+                            padding="same",
+                            bn_1st=bn_1st,
+                            pre_act=pre_act,
+                        )
+                    ),
+                    ConvLayer(
+                        dim,
+                        dim,
+                        kernel_size=1,
+                        act_fn=act_fn,
+                        bn_1st=bn_1st,
+                        pre_act=pre_act,
+                    ),
+                )
+                for _ in range(depth)
+            ],
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
-            nn.Linear(dim, n_classes))
+            nn.Linear(dim, n_classes)
+        )
        if init_func is not None:  # pragma: no cover
            init_func(self)
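An end-to-end sketch of the refactored class; the sizes and the `kaiming_init` hook are illustrative, not from this diff (any `Callable[[nn.Module], None]` fits `init_func`):

```python
import torch

def kaiming_init(model: nn.Module) -> None:
    # illustrative init hook: Kaiming-init every conv weight
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)

model = ConvMixer(dim=256, depth=8, in_chans=1, n_classes=10,
                  bn_1st=True, init_func=kaiming_init)
out = model(torch.randn(2, 1, 64, 64))  # in_chans=1 handles single-channel input
assert out.shape == (2, 10)
```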