33# Adopted from https://github.com/tmp-iclr/convmixer
44# Home for convmixer: https://github.com/locuslab/convmixer
55from collections import OrderedDict
6- from typing import Callable , Optional
6+ from typing import Callable , Optional , Union
77import torch .nn as nn
88
99
@@ -33,7 +33,7 @@ def ConvMixerOriginal(dim, depth,
3333 nn .Conv2d (dim , dim , kernel_size = 1 ),
3434 act_fn ,
3535 nn .BatchNorm2d (dim )
36- ) for i in range (depth )],
36+ ) for _i in range (depth )],
3737 nn .AdaptiveAvgPool2d ((1 , 1 )),
3838 nn .Flatten (),
3939 nn .Linear (dim , n_classes )
@@ -45,16 +45,16 @@ class ConvLayer(nn.Sequential):
4545
4646 def __init__ (
4747 self ,
48- in_channels ,
49- out_channels ,
50- kernel_size ,
51- stride = 1 ,
52- act_fn = nn .GELU (),
53- padding = 0 ,
54- groups = 1 ,
55- bn_1st = False ,
56- pre_act = False
57- ):
48+ in_channels : int ,
49+ out_channels : int ,
50+ kernel_size : Union [ int , tuple [ int , int ]] ,
51+ stride : int = 1 ,
52+ act_fn : nn . Module = nn .GELU (),
53+ padding : Union [ int , str ] = 0 ,
54+ groups : int = 1 ,
55+ bn_1st : bool = False ,
56+ pre_act : bool = False ,
57+ ):
5858
5959 conv_layer = [('conv' , nn .Conv2d (in_channels , out_channels , kernel_size , stride = stride ,
6060 padding = padding , groups = groups ))]
@@ -86,8 +86,8 @@ def __init__(
8686 in_chans : int = 3 ,
8787 bn_1st : bool = False ,
8888 pre_act : bool = False ,
89- init_func : Optional [Callable ] = None
90- ):
89+ init_func : Optional [Callable [[ nn . Module ], None ] ] = None
90+ ):
9191 """ConvMixer constructor.
9292 Adopted from https://github.com/tmp-iclr/convmixer
9393
0 commit comments