@@ -2,13 +2,20 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..common.blocks import Conv2dReLU
+from ..common.blocks import Conv2dReLU, SCSEModule
 from ..base.model import Model
 
 
 class DecoderBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None):
         super().__init__()
+        if attention_type is None:
+            self.attention1 = nn.Identity()
+            self.attention2 = nn.Identity()
+        elif attention_type == 'scse':
+            self.attention1 = SCSEModule(in_channels)
+            self.attention2 = SCSEModule(out_channels)
+
         self.block = nn.Sequential(
             Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
             Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
@@ -19,7 +26,10 @@ def forward(self, x):
         x = F.interpolate(x, scale_factor=2, mode='nearest')
         if skip is not None:
             x = torch.cat([x, skip], dim=1)
+        x = self.attention1(x)
+
         x = self.block(x)
+        x = self.attention2(x)
         return x
 
 
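SCSEModule is imported from ..common.blocks, but its definition is outside this diff. For reference, below is a minimal sketch of a concurrent spatial and channel squeeze & excitation block in the sense of Roy et al. (2018); the body and the reduction parameter are illustrative assumptions, not the repository's actual implementation.

# Hedged sketch of an scSE block; 'channels' and 'reduction' are assumed names.
import torch.nn as nn


class SCSEModule(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        # Channel squeeze & excitation: global average pool, bottleneck MLP,
        # then a per-channel sigmoid gate.
        self.cse = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1),
            nn.Sigmoid(),
        )
        # Spatial squeeze & excitation: a 1x1 conv producing a per-pixel sigmoid gate.
        self.sse = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Recalibrate along channels and along space, then sum the two paths.
        return x * self.cse(x) + x * self.sse(x)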
@@ -38,6 +48,7 @@ def __init__(
             final_channels=1,
             use_batchnorm=True,
             center=False,
+            attention_type=None
     ):
         super().__init__()
 
@@ -50,11 +61,16 @@ def __init__(
         in_channels = self.compute_channels(encoder_channels, decoder_channels)
         out_channels = decoder_channels
 
-        self.layer1 = DecoderBlock(in_channels[0], out_channels[0], use_batchnorm=use_batchnorm)
-        self.layer2 = DecoderBlock(in_channels[1], out_channels[1], use_batchnorm=use_batchnorm)
-        self.layer3 = DecoderBlock(in_channels[2], out_channels[2], use_batchnorm=use_batchnorm)
-        self.layer4 = DecoderBlock(in_channels[3], out_channels[3], use_batchnorm=use_batchnorm)
-        self.layer5 = DecoderBlock(in_channels[4], out_channels[4], use_batchnorm=use_batchnorm)
+        self.layer1 = DecoderBlock(in_channels[0], out_channels[0],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
+        self.layer2 = DecoderBlock(in_channels[1], out_channels[1],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
+        self.layer3 = DecoderBlock(in_channels[2], out_channels[2],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
+        self.layer4 = DecoderBlock(in_channels[3], out_channels[3],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
+        self.layer5 = DecoderBlock(in_channels[4], out_channels[4],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
         self.final_conv = nn.Conv2d(out_channels[4], final_channels, kernel_size=(1, 1))
 
         self.initialize()
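To sanity-check the new attention path end to end, the snippet below can be pasted into one script together with the patched DecoderBlock and the SCSEModule sketch above. The Conv2dReLU stand-in matches the call signature used in the diff but is an assumption, not the repository's implementation; the example also assumes, based on the if skip is not None check, that DecoderBlock.forward unpacks its input as x, skip = x.

# Assumed stand-in for ..common.blocks.Conv2dReLU; signature matches the
# calls in the diff, the body is a guess (conv -> optional BN -> ReLU).
import torch
import torch.nn as nn


class Conv2dReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, use_batchnorm=True):
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size,
                            padding=padding, bias=not use_batchnorm)]
        if use_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        super().__init__(*layers)


# With DecoderBlock and SCSEModule defined in the same script:
block = DecoderBlock(in_channels=192, out_channels=64, attention_type='scse')
x = torch.randn(1, 128, 16, 16)    # low-resolution decoder feature
skip = torch.randn(1, 64, 32, 32)  # encoder skip; 128 + 64 = 192 channels after concat
out = block([x, skip])             # assumes forward does: x, skip = x
print(out.shape)                   # torch.Size([1, 64, 32, 32])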