We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1b0fb07 commit ee3fa6eCopy full SHA for ee3fa6e
timm/models/dependencyvit.py
@@ -38,7 +38,7 @@ def __init__(
38
39
def forward(self, x: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, N]
40
_, N, C = x.shape
41
- topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False) # [B, N']
+ topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1] # [B, N']
42
topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C) # [B, N', C]
43
return x.gather(1, topk_indices)
44
0 commit comments