We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4731e4e · commit db06b56 — Copy full SHA for db06b56
timm/models/vision_transformer.py
@@ -635,13 +635,14 @@ def _intermediate_layers(
635
) -> List[torch.Tensor]:
636
outputs, num_blocks = [], len(self.blocks)
637
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
638
+ last_index_to_take = max(take_indices)
639
640
# forward pass
641
x = self.patch_embed(x)
642
x = self._pos_embed(x)
643
x = self.patch_drop(x)
644
x = self.norm_pre(x)
- for i, blk in enumerate(self.blocks):
645
+ for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
646
x = blk(x)
647
if i in take_indices:
648
outputs.append(x)
0 commit comments