@@ -32,11 +32,12 @@ def __init__(
         self,
         prune_ratio: float,
         prune_index: int,
-    ):
+    ) -> None:
         super().__init__()
         self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)

-    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor):  # [B, N, C], [B, 1, 1, N], [B, N]
+    # [B, N, C], [B, 1, 1, N], [B, N] -> [B, N', C], [B, 1, 1, N']
+    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         B, N, C = x.shape
         topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1]  # [B, N']
         x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C))  # [B, N', C]
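Aside (not part of the diff): the per-stage keep fraction (1 - i*r) / (1 - (i-1)*r) telescopes, so after the i-th pruning stage the surviving share of the original tokens is 1 - i*r. A minimal sketch with an assumed prune_ratio, just to check that invariant:

```python
# Sketch with assumed values; verifies that the per-stage keep ratios telescope.
prune_ratio = 0.25  # hypothetical value, not from the diff
kept = 1.0
for prune_index in (1, 2, 3):
    pct_kept = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)
    kept *= pct_kept
    # cumulative kept fraction equals 1 - prune_index * prune_ratio
    assert abs(kept - (1 - prune_index * prune_ratio)) < 1e-9
```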
@@ -86,8 +87,8 @@ def __init__(
         self.proj_drop = nn.Dropout(proj_drop)

     # m is cumulative over all layers (m = m_i * m_i-1 * ... * m_1)
-    def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        x, m = in_tuple  # [B, N, C], [B, 1, 1, N]
+    # [B, N, C], [B, 1, 1, N] -> [B, N, C], [B, 1, 1, N], [B, N]
+    def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
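Aside (not part of the diff): the [B, 1, 1, N] layout of m is what lets one per-token mask broadcast against per-head attention tensors of shape [B, num_heads, N, N] along the key axis. This hunk does not show how m is applied inside the forward, so the sketch below only illustrates the broadcasting, with assumed sizes:

```python
# Broadcasting sketch (sizes assumed; how m actually enters the attention is not shown in the diff).
import torch

B, H, N = 2, 6, 10
attn = torch.randn(B, H, N, N)   # [B, num_heads, N_queries, N_keys]
m = torch.ones(B, 1, 1, N)       # [B, 1, 1, N] per-token mask
print((attn * m).shape)          # torch.Size([2, 6, 10, 10])
```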
@@ -112,7 +113,6 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te

         #FIXME which pruning mask?

-        # [B, N]
         #prune_mask = attn.detach().sum(1).sum(-1)
         #prune_mask = attn.detach().sum(1).abs().sum(-1)
         #prune_mask = attn.detach().abs().sum((1, -1))
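Aside (not part of the diff): each commented-out candidate reduces the detached attention tensor [B, num_heads, N, N] to a per-token score of shape [B, N] (the shape the removed comment documented); they differ only in where abs() is taken. A quick shape check with assumed sizes:

```python
# Shape check (sizes assumed): all three candidate scores come out as [B, N].
import torch

B, H, N = 2, 6, 10
attn = torch.randn(B, H, N, N).softmax(dim=-1)

print(attn.detach().sum(1).sum(-1).shape)        # torch.Size([2, 10])
print(attn.detach().sum(1).abs().sum(-1).shape)  # torch.Size([2, 10])
print(attn.detach().abs().sum((1, -1)).shape)    # torch.Size([2, 10])
```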
@@ -184,7 +184,7 @@ def __init__(

     def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x, m = in_tuple
-        x_new, m, prune_mask = self.attn((self.norm1(x), m))
+        x_new, m, prune_mask = self.attn(self.norm1(x), m)
         x = x + self.drop_path1(self.ls1(x_new))
         x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m)
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
@@ -201,7 +201,7 @@ def __init__(
         prune_ratio: Optional[float] = None,
         *args,
         **kwargs
-    ): -> None:
+    ) -> None:
         super().__init__(
             *args,
             **kwargs,
@@ -244,13 +244,13 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = x * m.transpose(1, 3).squeeze(-1)
         return x

-    def track_dependency_mask(self, track: bool = True):
+    def track_dependency_mask(self, track: bool = True) -> None:
         for block in self.blocks:
             if block.attn.track_dependency_mask is not track:
                 block.attn.dependency_mask = None
                 block.attn.track_dependency_mask = track

-    def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None):
+    def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None) -> List[torch.Tensor]:
         # L' * [B, N, N]
         # L' * [B, N', N']
         result = []
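Aside (not part of the diff): a hypothetical usage sketch of the dependency-mask API touched above; the model variable, the input batch, and the assumption that a forward pass populates block.attn.dependency_mask are not shown in this commit.

```python
# Hypothetical usage (names and the populate-on-forward behaviour are assumptions):
model.track_dependency_mask(True)              # start recording per-block dependency masks
_ = model(images)                              # assumed to populate block.attn.dependency_mask
masks = model.get_dependency_mask([0, 5, 11])  # list of per-layer masks, L' * [B, N, N]
model.track_dependency_mask(False)             # stop recording; stored masks are reset to None
```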