@@ -52,22 +52,15 @@ class EcaModule(nn.Module):
5252 def __init__ (self , channels = None , kernel_size = 3 , gamma = 2 , beta = 1 ):
5353 super (EcaModule , self ).__init__ ()
5454 assert kernel_size % 2 == 1
55-
5655 if channels is not None :
5756 t = int (abs (math .log (channels , 2 ) + beta ) / gamma )
5857 kernel_size = max (t if t % 2 else t + 1 , 3 )
5958
60- self .avg_pool = nn .AdaptiveAvgPool2d (1 )
6159 self .conv = nn .Conv1d (1 , 1 , kernel_size = kernel_size , padding = (kernel_size - 1 ) // 2 , bias = False )
6260
6361 def forward (self , x ):
64- # Feature descriptor on the global spatial information
65- y = self .avg_pool (x )
66- # Reshape for convolution
67- y = y .view (x .shape [0 ], 1 , - 1 )
68- # Two different branches of ECA module
62+ y = x .mean ((2 , 3 )).view (x .shape [0 ], 1 , - 1 ) # view for 1d conv
6963 y = self .conv (y )
70- # Multi-scale information fusion
7164 y = y .view (x .shape [0 ], - 1 , 1 , 1 ).sigmoid ()
7265 return x * y .expand_as (x )
7366
@@ -95,30 +88,20 @@ class CecaModule(nn.Module):
9588 def __init__ (self , channels = None , kernel_size = 3 , gamma = 2 , beta = 1 ):
9689 super (CecaModule , self ).__init__ ()
9790 assert kernel_size % 2 == 1
98-
9991 if channels is not None :
10092 t = int (abs (math .log (channels , 2 ) + beta ) / gamma )
10193 kernel_size = max (t if t % 2 else t + 1 , 3 )
10294
103- self .avg_pool = nn .AdaptiveAvgPool2d (1 )
104- #pytorch circular padding mode is buggy as of pytorch 1.4
105- #see https://github.com/pytorch/pytorch/pull/17240
106-
107- #implement manual circular padding
95+ # PyTorch circular padding mode is buggy as of pytorch 1.4
96+ # see https://github.com/pytorch/pytorch/pull/17240
97+ # implement manual circular padding
10898 self .conv = nn .Conv1d (1 , 1 , kernel_size = kernel_size , padding = 0 , bias = False )
10999 self .padding = (kernel_size - 1 ) // 2
110100
111101 def forward (self , x ):
112- # Feature descriptor on the global spatial information
113- y = self .avg_pool (x )
114-
102+ y = x .mean ((2 , 3 )).view (x .shape [0 ], 1 , - 1 )
115103 # Manually implement circular padding, F.pad does not seemed to be bugged
116- y = F .pad (y .view (x .shape [0 ], 1 , - 1 ), (self .padding , self .padding ), mode = 'circular' )
117-
118- # Two different branches of ECA module
104+ y = F .pad (y , (self .padding , self .padding ), mode = 'circular' )
119105 y = self .conv (y )
120-
121- # Multi-scale information fusion
122106 y = y .view (x .shape [0 ], - 1 , 1 , 1 ).sigmoid ()
123-
124107 return x * y .expand_as (x )
0 commit comments