@@ -81,7 +81,7 @@ def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_laye
8181 self .groups = groups
8282 self .in_channels = in_channels
8383
84- def resize_mat (self , x , t ):
84+ def resize_mat (self , x , t : int ):
8585 B , C , block_size , block_size1 = x .shape
8686 assert block_size == block_size1
8787 if t <= 1 :
@@ -100,10 +100,8 @@ def forward(self, x):
100100 out = self .conv1 (x )
101101 rp = F .adaptive_max_pool2d (out , (self .block_size , 1 ))
102102 cp = F .adaptive_max_pool2d (out , (1 , self .block_size ))
103- p = self .conv_p (rp ).view (B , self .groups , self .block_size , self .block_size )
104- q = self .conv_q (cp ).view (B , self .groups , self .block_size , self .block_size )
105- p = F .sigmoid (p )
106- q = F .sigmoid (q )
103+ p = self .conv_p (rp ).view (B , self .groups , self .block_size , self .block_size ).sigmoid ()
104+ q = self .conv_q (cp ).view (B , self .groups , self .block_size , self .block_size ).sigmoid ()
107105 p = p / p .sum (dim = 3 , keepdim = True )
108106 q = q / q .sum (dim = 2 , keepdim = True )
109107 p = p .view (B , self .groups , 1 , self .block_size , self .block_size ).expand (x .size (
0 commit comments