Skip to content

Commit dd3b96c

Browse files
committed
Fix features intermediates for NCHW inputs, patch variable size inputs need more code
1 parent b3ca8fd commit dd3b96c

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

timm/models/naflexvit.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,8 +1007,11 @@ def forward_intermediates(
10071007
patch_coord = x['patch_coord']
10081008
patch_valid = x['patch_valid']
10091009
patches = x['patches']
1010+
assert False, 'WIP, patch mode needs more work'
10101011
else:
10111012
patches = x
1013+
height, width = x.shape[-2:]
1014+
H, W = self.embeds.dynamic_feat_size((height, width))
10121015

10131016
# Create attention mask if patch_type is provided and mask is not
10141017
if mask is None and patch_valid is not None:
@@ -1040,12 +1043,6 @@ def forward_intermediates(
10401043

10411044
if reshape:
10421045
# reshape to BCHW output format
1043-
grid_size = self.embeds.pos_embed_grid_size
1044-
if hasattr(self.embeds, 'dynamic_feat_size') and len(x.shape) >= 4:
1045-
_, height, width, _ = x.shape if len(x.shape) == 4 else (None, *x.shape[-3:-1], None)
1046-
H, W = self.embeds.dynamic_feat_size((height, width))
1047-
else:
1048-
H, W = grid_size
10491046
intermediates = [
10501047
y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()
10511048
for y in intermediates

0 commit comments

Comments
 (0)