@@ -36,11 +36,12 @@ def __init__(
         super().__init__()
         self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)

-    def forward(self, x: torch.Tensor, scores: torch.Tensor):  # [B, N, C], [B, N]
-        _, N, C = x.shape
+    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor):  # [B, N, C], [B, 1, 1, N], [B, N]
+        B, N, C = x.shape
         topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1]  # [B, N']
-        topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C)  # [B, N', C]
-        return x.gather(1, topk_indices)
+        x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C))  # [B, N', C]
+        m = m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1))  # [B, 1, 1, N']
+        return (x, m)


 class ReversedAttention(nn.Module):
@@ -188,7 +189,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
         x, m = in_tuple
         x_new, m, prune_mask = self.attn((self.norm1(x), m))
         x = x + self.drop_path1(self.ls1(x_new))
-        x = self.token_pruner(x, prune_mask) if self.token_pruner else x
+        x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m)
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return (x, m)

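The change threads the attention mask m through the pruner so that, after tokens are dropped, the mask keeps the same [B, 1, 1, N'] key dimension as the surviving tokens. Below is a minimal, self-contained sketch of the updated contract; the constructor arguments (prune_ratio, prune_index) and the toy shapes are assumptions reconstructed from the diff, not the repository's exact API.

import math
import torch
import torch.nn as nn

class TokenPruner(nn.Module):
    def __init__(self, prune_ratio: float, prune_index: int):
        super().__init__()
        # Fraction of currently kept tokens to retain at this pruning stage.
        self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)

    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor):
        # x: [B, N, C] tokens, m: [B, 1, 1, N] attention mask, scores: [B, N] keep scores
        B, N, C = x.shape
        topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1]  # [B, N']
        x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C))  # [B, N', C]
        m = m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1))        # [B, 1, 1, N']
        return (x, m)

# Usage sketch: prune a toy batch and check the mask stays aligned with the tokens.
pruner = TokenPruner(prune_ratio=0.25, prune_index=1)   # keeps 75% of tokens
x = torch.randn(2, 16, 8)       # [B, N, C]
m = torch.zeros(2, 1, 1, 16)    # [B, 1, 1, N] additive attention mask
scores = torch.rand(2, 16)      # [B, N]
x, m = pruner(x, m, scores)
print(x.shape, m.shape)         # torch.Size([2, 12, 8]) torch.Size([2, 1, 1, 12])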