Skip to content

Commit 4731e4e

Browse files
committed
Modified ViT get_intermediate_layers() to support dynamic image size
1 parent 6e6f368 commit 4731e4e

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

timm/models/vision_transformer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,9 +667,12 @@ def get_intermediate_layers(
667667
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
668668

669669
if reshape:
670-
grid_size = self.patch_embed.grid_size
670+
patch_size = self.patch_embed.patch_size
671+
batch, _, height, width = x.size()
671672
outputs = [
672-
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
673+
out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
674+
.permute(0, 3, 1, 2)
675+
.contiguous()
673676
for out in outputs
674677
]
675678

0 commit comments

Comments
 (0)