Skip to content

Commit b7a568f

Browse files
committed
Fix torchscript issue in bat
1 parent d17b374 commit b7a568f

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

timm/models/layers/non_local_attn.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_laye
8181
self.groups = groups
8282
self.in_channels = in_channels
8383

84-
def resize_mat(self, x, t):
84+
def resize_mat(self, x, t: int):
8585
B, C, block_size, block_size1 = x.shape
8686
assert block_size == block_size1
8787
if t <= 1:
@@ -100,10 +100,8 @@ def forward(self, x):
100100
out = self.conv1(x)
101101
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
102102
cp = F.adaptive_max_pool2d(out, (1, self.block_size))
103-
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size)
104-
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size)
105-
p = F.sigmoid(p)
106-
q = F.sigmoid(q)
103+
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
104+
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
107105
p = p / p.sum(dim=3, keepdim=True)
108106
q = q / q.sum(dim=2, keepdim=True)
109107
p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(

0 commit comments

Comments
 (0)