 # YaResBlock - former NewResBlock.
 # Yet another ResNet.

-import torch.nn as nn
-from functools import partial
 from collections import OrderedDict
-from .layers import ConvBnAct
-from .net import Net
+from functools import partial
+from typing import Union
+
+import torch.nn as nn
 from torch.nn import Mish

+from .layers import ConvBnAct
+from .net import Net

-__all__ = ['YaResBlock', 'yaresnet_parameters', 'yaresnet34', 'yaresnet50']
+__all__ = [
+    'YaResBlock',
+    # 'yaresnet_parameters',
+    # 'yaresnet34',
+    # 'yaresnet50',
+]


 act_fn = nn.ReLU(inplace=True)
 class YaResBlock(nn.Module):
     '''YaResBlock. Reduce by pool instead of stride 2'''

-    def __init__(self, expansion, in_channels, mid_channels, stride=1,
-                 conv_layer=ConvBnAct, act_fn=act_fn, zero_bn=True, bn_1st=True,
-                 groups=1, dw=False, div_groups=None,
-                 pool=None,
-                 se=None, sa=None,
-                 ):
+    def __init__(
+        self,
+        expansion: int,
+        in_channels: int,
+        mid_channels: int,
+        stride: int = 1,
+        conv_layer=ConvBnAct,
+        act_fn: nn.Module = act_fn,
+        zero_bn: bool = True,
+        bn_1st: bool = True,
+        groups: int = 1,
+        dw: bool = False,
+        div_groups: Union[None, int] = None,
+        pool: Union[nn.Module, None] = None,
+        se: Union[nn.Module, None] = None,
+        sa: Union[nn.Module, None] = None,
+    ):
         super().__init__()
+        # pool defined at ModelConstructor.
         out_channels, in_channels = mid_channels * expansion, in_channels * expansion
         if div_groups is not None:  # check if groups != 1 and div_groups
             groups = int(mid_channels / div_groups)
+
         if stride != 1:
             if pool is None:
                 self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
@@ -36,23 +56,69 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
                 self.reduce = pool
         else:
             self.reduce = None
-        layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=1,
-                   act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
-                  ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
-                   act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
-                  ] if expansion == 1 else [
-                  ("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
-                  ("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
-                   groups=mid_channels if dw else groups)),
-                  ("conv_2", conv_layer(
-                      mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))
-                  ]
+        if expansion == 1:
+            layers = [
+                ("conv_0", conv_layer(
+                    in_channels,
+                    mid_channels,
+                    3,
+                    stride=1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=in_channels if dw else groups,
+                ),),
+                ("conv_1", conv_layer(
+                    mid_channels,
+                    out_channels,
+                    3,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),),
+            ]
+        else:
+            layers = [
+                ("conv_0", conv_layer(
+                    in_channels,
+                    mid_channels,
+                    1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                ),),
+                ("conv_1", conv_layer(
+                    mid_channels,
+                    mid_channels,
+                    3,
+                    stride=1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),),
+                ("conv_2", conv_layer(
+                    mid_channels,
+                    out_channels,
+                    1,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                ),),  # noqa E501
+            ]
         if se:
-            layers.append(('se', se(out_channels)))
+            layers.append(("se", se(out_channels)))
         if sa:
-            layers.append(('sa', sa(out_channels)))
+            layers.append(("sa", sa(out_channels)))
         self.convs = nn.Sequential(OrderedDict(layers))
-        self.id_conv = None if in_channels == out_channels else conv_layer(in_channels, out_channels, 1, act_fn=False)
+        if in_channels != out_channels:
+            self.id_conv = conv_layer(
+                in_channels,
+                out_channels,
+                1,
+                stride=1,
+                act_fn=False,
+            )
+        else:
+            self.id_conv = None
         self.merge = act_fn

@@ -62,6 +128,6 @@ def forward(self, x):
         return self.merge(self.convs(x) + identity)


-yaresnet_parameters = {'block': YaResBlock, 'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish(), 'stem_stride_on': 1}
-yaresnet34 = partial(Net, name='YaResnet34', expansion=1, layers=[3, 4, 6, 3], **yaresnet_parameters)
-yaresnet50 = partial(Net, name='YaResnet50', expansion=4, layers=[3, 4, 6, 3], **yaresnet_parameters)
+# yaresnet_parameters = {'block': YaResBlock, 'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish(), 'stem_stride_on': 1}
+# yaresnet34 = partial(Net, name='YaResnet34', expansion=1, layers=[3, 4, 6, 3], **yaresnet_parameters)
+# yaresnet50 = partial(Net, name='YaResnet50', expansion=4, layers=[3, 4, 6, 3], **yaresnet_parameters)
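
Usage sketch (not part of the commit): a minimal smoke test of the refactored block. The import path, the AvgPool2d choice, and the expected output shape are assumptions: they rely on the module living at model_constructor/yaresnet.py, on `pool` being passed as an nn.Module instance, and on ConvBnAct padding its 3x3 convolutions so only the pool changes the spatial size.

import torch
import torch.nn as nn

from model_constructor.yaresnet import YaResBlock  # assumed import path

# Basic (expansion=1) block that downsamples with a pool instead of a stride-2 conv.
block = YaResBlock(
    expansion=1,
    in_channels=64,
    mid_channels=64,
    stride=2,                              # stride != 1 activates self.reduce
    pool=nn.AvgPool2d(2, ceil_mode=True),  # assumed: the reduce module is passed as an instance
)

x = torch.randn(2, 64, 32, 32)
out = block(x)
print(out.shape)  # expected: torch.Size([2, 64, 16, 16])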