Skip to content

Commit 02d0f27

Browse files
committed
cleanup davit padding
1 parent c715c72 commit 02d0f27

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

timm/models/davit.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ def __init__(
7979

8080
def forward(self, x: Tensor):
8181
B, C, H, W = x.shape
82-
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
83-
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
82+
pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
83+
pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
84+
x = F.pad(x, (0, pad_r, 0, pad_b))
8485
x = self.conv(x)
8586
x = self.norm(x)
8687
return x
@@ -113,8 +114,9 @@ def forward(self, x: Tensor):
113114
x = self.norm(x)
114115
if self.even_k:
115116
k_h, k_w = self.conv.kernel_size
116-
x = F.pad(x, (0, (k_w - W % k_w) % k_w))
117-
x = F.pad(x, (0, 0, 0, (k_h - H % k_h) % k_h))
117+
pad_r = (k_w - W % k_w) % k_w
118+
pad_b = (k_h - H % k_h) % k_h
119+
x = F.pad(x, (0, pad_r , 0, pad_b))
118120
x = self.conv(x)
119121
return x
120122

0 commit comments

Comments
 (0)