11from collections import OrderedDict
22from functools import partial
3- from typing import Callable , Union
3+ from typing import Callable , List , Sequence , Union
44
55import torch .nn as nn
66
@@ -24,14 +24,25 @@ def init_cnn(module: nn.Module):
2424
2525
2626class ResBlock (nn .Module ):
27- '''Resnet block'''
28-
29- def __init__ (self , expansion , in_channels , mid_channels , stride = 1 ,
30- conv_layer = ConvBnAct , act_fn = act_fn , zero_bn = True , bn_1st = True ,
31- groups = 1 , dw = False , div_groups = None ,
32- pool = None ,
33- se = None , sa = None
34- ):
27+ '''Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck.'''
28+
29+ def __init__ (
30+ self ,
31+ expansion : int ,
32+ in_channels : int ,
33+ mid_channels : int ,
34+ stride : int = 1 ,
35+ conv_layer : Union [nn .Module , nn .Sequential ] = ConvBnAct ,
36+ act_fn : nn .Module = act_fn ,
37+ zero_bn : bool = True ,
38+ bn_1st : bool = True ,
39+ groups : int = 1 ,
40+ dw : bool = False ,
41+ div_groups : Union [None , int ] = None ,
42+ pool : Union [nn .Module , None ] = None ,
43+ se : Union [nn .Module , None ] = None ,
44+ sa : Union [nn .Module , None ] = None ,
45+ ):
3546 super ().__init__ ()
3647 # pool defined at ModelConstructor.
3748 out_channels , in_channels = mid_channels * expansion , in_channels * expansion
@@ -124,28 +135,38 @@ def _make_head(self):
124135
125136class ModelConstructor ():
126137 """Model constructor. As default - xresnet18"""
127- def __init__ (self , name = 'MC' , in_chans = 3 , num_classes = 1000 ,
128- block = ResBlock , conv_layer = ConvBnAct ,
129- block_sizes = [64 , 128 , 256 , 512 ], layers = [2 , 2 , 2 , 2 ],
130- norm = nn .BatchNorm2d ,
131- act_fn = nn .ReLU (inplace = True ),
132- pool = nn .AvgPool2d (2 , ceil_mode = True ),
133- expansion = 1 , groups = 1 , dw = False , div_groups = None ,
134- sa : Union [bool , int , Callable ] = False ,
135- se : Union [bool , int , Callable ] = False ,
136- se_module = None , se_reduction = None ,
137- bn_1st = True ,
138- zero_bn = True ,
139- stem_stride_on = 0 ,
140- stem_sizes = [32 , 32 , 64 ],
141- stem_pool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ),
142- stem_bn_end = False ,
143- _init_cnn = init_cnn ,
144- _make_stem = _make_stem ,
145- _make_layer = _make_layer ,
146- _make_body = _make_body ,
147- _make_head = _make_head ,
148- ):
138+ def __init__ (
139+ self ,
140+ name : str = 'MC' ,
141+ in_chans : int = 3 ,
142+ num_classes : int = 1000 ,
143+ block = ResBlock ,
144+ conv_layer = ConvBnAct ,
145+ block_sizes : List [int ] = [64 , 128 , 256 , 512 ],
146+ layers : List [int ] = [2 , 2 , 2 , 2 ],
147+ norm : nn .Module = nn .BatchNorm2d ,
148+ act_fn : nn .Module = nn .ReLU (inplace = True ),
149+ pool : nn .Module = nn .AvgPool2d (2 , ceil_mode = True ),
150+ expansion : int = 1 ,
151+ groups : int = 1 ,
152+ dw : bool = False ,
153+ div_groups = None ,
154+ sa : Union [bool , int , Callable ] = False ,
155+ se : Union [bool , int , Callable ] = False ,
156+ se_module = None ,
157+ se_reduction = None ,
158+ bn_1st = True ,
159+ zero_bn = True ,
160+ stem_stride_on = 0 ,
161+ stem_sizes = [32 , 32 , 64 ],
162+ stem_pool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ),
163+ stem_bn_end = False ,
164+ _init_cnn = init_cnn ,
165+ _make_stem = _make_stem ,
166+ _make_layer = _make_layer ,
167+ _make_body = _make_body ,
168+ _make_head = _make_head ,
169+ ):
149170 super ().__init__ ()
150171 # se can be bool, int (0, 1) or nn.Module
151172 # se_module - deprecated. Leaved for warning and checks.
0 commit comments