# Adapted from https://github.com/tmp-iclr/convmixer
# Home for convmixer: https://github.com/locuslab/convmixer
from collections import OrderedDict
from typing import Callable, List, Optional, Union

import torch.nn as nn
from torch import Tensor, TensorType
810
911
class Residual(nn.Module):
    """Wrap a callable with a skip connection, computing ``fn(x) + x``.

    Args:
        fn: Callable applied to the input tensor. Its result is added to the
            unchanged input, so the two must be addable.
    """

    def __init__(self, fn: Callable[[Tensor], Tensor]):
        super().__init__()
        # If fn is an nn.Module, this assignment registers it as a submodule.
        self.fn = fn

    def forward(self, x: Tensor) -> Tensor:
        """Return ``self.fn(x) + x``."""
        return self.fn(x) + x
1719
1820
# Same as the original version, but with act_fn exposed as an argument.
2022def ConvMixerOriginal (
21- dim , depth , kernel_size = 9 , patch_size = 7 , n_classes = 1000 , act_fn = nn .GELU ()
23+ dim : int ,
24+ depth : int ,
25+ kernel_size : int = 9 ,
26+ patch_size : int = 7 ,
27+ n_classes : int = 1000 ,
28+ act_fn : nn .Module = nn .GELU (),
2229):
2330 return nn .Sequential (
2431 nn .Conv2d (3 , dim , kernel_size = patch_size , stride = patch_size ),
@@ -61,7 +68,7 @@ def __init__(
6168 pre_act : bool = False ,
6269 ):
6370
64- conv_layer = [
71+ conv_layer : List [ tuple [ str , nn . Module ]] = [
6572 (
6673 "conv" ,
6774 nn .Conv2d (
@@ -74,7 +81,10 @@ def __init__(
7481 ),
7582 )
7683 ]
77- act_bn = [("act_fn" , act_fn ), ("bn" , nn .BatchNorm2d (out_channels ))]
84+ act_bn : List [tuple [str , nn .Module ]] = [
85+ ("act_fn" , act_fn ),
86+ ("bn" , nn .BatchNorm2d (out_channels )),
87+ ]
7888 if bn_1st :
7989 act_bn .reverse ()
8090 if pre_act :
0 commit comments