11# YaResBlock - former NewResBlock.
22# Yet another ResNet.
33
4- import torch .nn as nn
5- from functools import partial
64from collections import OrderedDict
7- from .layers import ConvBnAct
8- from .net import Net
5+ from typing import Union
6+
7+ import torch .nn as nn
98from torch .nn import Mish
109
10+ from .layers import ConvBnAct
11+ from .model_constructor import CfgMC , ModelConstructor
1112
12- __all__ = ['YaResBlock' , 'yaresnet_parameters' , 'yaresnet34' , 'yaresnet50' ]
13+ __all__ = [
14+ 'YaResBlock' ,
15+ 'yaresnet34' ,
16+ 'yaresnet50' ,
17+ ]
1318
1419
1520act_fn = nn .ReLU (inplace = True )
1823class YaResBlock (nn .Module ):
1924 '''YaResBlock. Reduce by pool instead of stride 2'''
2025
21- def __init__ (self , expansion , in_channels , mid_channels , stride = 1 ,
22- conv_layer = ConvBnAct , act_fn = act_fn , zero_bn = True , bn_1st = True ,
23- groups = 1 , dw = False , div_groups = None ,
24- pool = None ,
25- se = None , sa = None ,
26- ):
26+ def __init__ (
27+ self ,
28+ expansion : int ,
29+ in_channels : int ,
30+ mid_channels : int ,
31+ stride : int = 1 ,
32+ conv_layer = ConvBnAct ,
33+ act_fn : nn .Module = act_fn ,
34+ zero_bn : bool = True ,
35+ bn_1st : bool = True ,
36+ groups : int = 1 ,
37+ dw : bool = False ,
38+ div_groups : Union [None , int ] = None ,
39+ pool : Union [nn .Module , None ] = None ,
40+ se : Union [nn .Module , None ] = None ,
41+ sa : Union [nn .Module , None ] = None ,
42+ ):
2743 super ().__init__ ()
44+ # pool defined at ModelConstructor.
2845 out_channels , in_channels = mid_channels * expansion , in_channels * expansion
2946 if div_groups is not None : # check if groups != 1 and div_groups
3047 groups = int (mid_channels / div_groups )
48+
3149 if stride != 1 :
3250 if pool is None :
3351 self .reduce = conv_layer (in_channels , in_channels , 1 , stride = 2 )
@@ -36,23 +54,69 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
3654 self .reduce = pool
3755 else :
3856 self .reduce = None
39- layers = [("conv_0" , conv_layer (in_channels , mid_channels , 3 , stride = 1 ,
40- act_fn = act_fn , bn_1st = bn_1st , groups = in_channels if dw else groups )),
41- ("conv_1" , conv_layer (mid_channels , out_channels , 3 , zero_bn = zero_bn ,
42- act_fn = False , bn_1st = bn_1st , groups = mid_channels if dw else groups ))
43- ] if expansion == 1 else [
44- ("conv_0" , conv_layer (in_channels , mid_channels , 1 , act_fn = act_fn , bn_1st = bn_1st )),
45- ("conv_1" , conv_layer (mid_channels , mid_channels , 3 , stride = 1 , act_fn = act_fn , bn_1st = bn_1st ,
46- groups = mid_channels if dw else groups )),
47- ("conv_2" , conv_layer (
48- mid_channels , out_channels , 1 , zero_bn = zero_bn , act_fn = False , bn_1st = bn_1st ))
49- ]
57+ if expansion == 1 :
58+ layers = [
59+ ("conv_0" , conv_layer (
60+ in_channels ,
61+ mid_channels ,
62+ 3 ,
63+ stride = 1 ,
64+ act_fn = act_fn ,
65+ bn_1st = bn_1st ,
66+ groups = in_channels if dw else groups ,
67+ ),),
68+ ("conv_1" , conv_layer (
69+ mid_channels ,
70+ out_channels ,
71+ 3 ,
72+ zero_bn = zero_bn ,
73+ act_fn = False ,
74+ bn_1st = bn_1st ,
75+ groups = mid_channels if dw else groups ,
76+ ),),
77+ ]
78+ else :
79+ layers = [
80+ ("conv_0" , conv_layer (
81+ in_channels ,
82+ mid_channels ,
83+ 1 ,
84+ act_fn = act_fn ,
85+ bn_1st = bn_1st ,
86+ ),),
87+ ("conv_1" , conv_layer (
88+ mid_channels ,
89+ mid_channels ,
90+ 3 ,
91+ stride = 1 ,
92+ act_fn = act_fn ,
93+ bn_1st = bn_1st ,
94+ groups = mid_channels if dw else groups ,
95+ ),),
96+ ("conv_2" , conv_layer (
97+ mid_channels ,
98+ out_channels ,
99+ 1 ,
100+ zero_bn = zero_bn ,
101+ act_fn = False ,
102+ bn_1st = bn_1st ,
103+ ),), # noqa E501
104+ ]
50105 if se :
51- layers .append (('se' , se (out_channels )))
106+ layers .append (("se" , se (out_channels )))
52107 if sa :
53- layers .append (('sa' , sa (out_channels )))
108+ layers .append (("sa" , sa (out_channels )))
54109 self .convs = nn .Sequential (OrderedDict (layers ))
55- self .id_conv = None if in_channels == out_channels else conv_layer (in_channels , out_channels , 1 , act_fn = False )
110+ if in_channels != out_channels :
111+ self .id_conv = conv_layer (
112+ in_channels ,
113+ out_channels ,
114+ 1 ,
115+ stride = 1 ,
116+ act_fn = False ,
117+ )
118+ else :
119+ self .id_conv = None
56120 self .merge = act_fn
57121
58122 def forward (self , x ):
@@ -62,6 +126,21 @@ def forward(self, x):
62126 return self .merge (self .convs (x ) + identity )
63127
64128
65- yaresnet_parameters = {'block' : YaResBlock , 'stem_sizes' : [3 , 32 , 64 , 64 ], 'act_fn' : Mish (), 'stem_stride_on' : 1 }
66- yaresnet34 = partial (Net , name = 'YaResnet34' , expansion = 1 , layers = [3 , 4 , 6 , 3 ], ** yaresnet_parameters )
67- yaresnet50 = partial (Net , name = 'YaResnet50' , expansion = 4 , layers = [3 , 4 , 6 , 3 ], ** yaresnet_parameters )
129+ yaresnet34 = ModelConstructor .from_cfg (
130+ CfgMC (
131+ name = 'YaResnet34' ,
132+ block = YaResBlock ,
133+ expansion = 1 ,
134+ layers = [3 , 4 , 6 , 3 ],
135+ act_fn = Mish (),
136+ )
137+ )
138+ yaresnet50 = ModelConstructor .from_cfg (
139+ CfgMC (
140+ name = 'YaResnet50' ,
141+ block = YaResBlock ,
142+ act_fn = Mish (),
143+ expansion = 4 ,
144+ layers = [3 , 4 , 6 , 3 ],
145+ )
146+ )
0 commit comments