@@ -3,7 +3,7 @@
 # Adopted from https://github.com/tmp-iclr/convmixer
 # Home for convmixer: https://github.com/locuslab/convmixer
 from collections import OrderedDict
-from typing import Callable
+from typing import Callable, Optional
 import torch.nn as nn


@@ -43,9 +43,18 @@ def ConvMixerOriginal(dim, depth,
 class ConvLayer(nn.Sequential):
     """Basic conv layers block"""

-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 act_fn=nn.GELU(), padding=0, groups=1,
-                 bn_1st=False, pre_act=False):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        act_fn=nn.GELU(),
+        padding=0,
+        groups=1,
+        bn_1st=False,
+        pre_act=False
+    ):

         conv_layer = [('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                                          padding=padding, groups=groups))]
@@ -65,12 +74,20 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,

 class ConvMixer(nn.Sequential):

-    def __init__(self, dim: int, depth: int,
-                 kernel_size: int = 9, patch_size: int = 7, n_classes: int = 1000,
-                 act_fn: nn.Module = nn.GELU(),
-                 stem: nn.Module = None,
-                 bn_1st: bool = False, pre_act: bool = False,
-                 init_func: Callable = None):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        kernel_size: int = 9,
+        patch_size: int = 7,
+        n_classes: int = 1000,
+        act_fn: nn.Module = nn.GELU(),
+        stem: Optional[nn.Module] = None,
+        in_chans: int = 3,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+        init_func: Optional[Callable] = None
+    ):
         """ConvMixer constructor.
         Adopted from https://github.com/tmp-iclr/convmixer

@@ -91,7 +108,7 @@ def __init__(self, dim: int, depth: int,
         if pre_act:
             bn_1st = False
         if stem is None:
-            stem = ConvLayer(3, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
+            stem = ConvLayer(in_chans, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)

         super().__init__(
             stem,
@@ -100,7 +117,7 @@ def __init__(self, dim: int, depth: int,
                 ConvLayer(dim, dim, kernel_size, act_fn=act_fn,
                           groups=dim, padding="same", bn_1st=bn_1st, pre_act=pre_act)),
                 ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
-            for i in range(depth)],
+            for _ in range(depth)],
             nn.AdaptiveAvgPool2d((1, 1)),
             nn.Flatten(),
             nn.Linear(dim, n_classes))
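
For reference, a minimal usage sketch of the updated constructor (not part of the diff): it exercises the new in_chans argument so the default stem accepts non-RGB input. The import path below is a placeholder assumption; adjust it to wherever ConvMixer lives in this repo.

import torch

from convmixer import ConvMixer  # placeholder import path, adjust to this repo's layout

# Small ConvMixer for grayscale input: in_chans=1 is forwarded to the default stem.
model = ConvMixer(dim=64, depth=4, kernel_size=9, patch_size=7, n_classes=10, in_chans=1)

x = torch.randn(2, 1, 224, 224)  # batch of two 1-channel images
out = model(x)                   # expected shape: (2, 10)
print(out.shape)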