@@ -75,7 +75,7 @@ def __init__(
             bias=False, # FIXME is there a bias term?
         )
 
-        self.token_pruner = None
+        # self.token_pruner = None
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -105,7 +105,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N, C)
 
-
+        '''
         # FIXME messy way to handle
         if self.track_dependency_mask or self.token_pruner:
             dependency_mask = attn.detach().sum(1) # [B, N, N]
@@ -115,12 +115,17 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
         #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x # dependency mask weights(abs-sum)
         #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x # attn weights(abs-sum-abs-sum)
         #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x # m
-
-
+        '''
+        self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None # [B, N, N]
+
+        prune_mask = attn.detach().sum(1).sum(-1)
+        #prune_mask = attn.detach().sum(1).abs().sum(-1)
+        #prune_mask = attn.detach().abs().sum(1).sum(-1)
+        #prune_mask = m.reshape(B, N)
 
         x = self.proj(x)
         x = self.proj_drop(x)
-        return (x, m)
+        return (x, m, prune_mask)
 
 class LayerScale(nn.Module):
     def __init__(
@@ -166,6 +171,8 @@ def __init__(
         )
         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.token_pruner = None
 
         self.norm2 = norm_layer(dim)
         self.mlp = mlp_layer(
@@ -179,8 +186,9 @@ def __init__(
 
     def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x, m = in_tuple
-        x_new, m = self.attn((self.norm1(x), m))
+        x_new, m, prune_mask = self.attn((self.norm1(x), m))
         x = x + self.drop_path1(self.ls1(x_new))
+        x = self.token_pruner(x, prune_mask) if self.token_pruner else x
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return (x, m)
 
@@ -217,7 +225,7 @@ def __init__(
 
         self.prune_layers = [x - 1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indices
         for prune_index, layer in enumerate(prune_layers, 1):
-            self.blocks[layer].attn.token_pruner = TokenPruner(self.prune_ratio, prune_index)
+            self.blocks[layer].token_pruner = TokenPruner(self.prune_ratio, prune_index)
 
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
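
Note: the TokenPruner module attached to each pruning block above is defined elsewhere in the file and is not shown in this diff. The code below is only a minimal, hypothetical sketch that matches the call sites visible here: construction as TokenPruner(prune_ratio, prune_index) and invocation as token_pruner(x, prune_mask), where prune_mask = attn.detach().sum(1).sum(-1) has shape [B, N]. The compounding keep ratio and the keep-top-k-plus-class-token behavior are assumptions, not the repository's actual implementation.

# Hypothetical sketch of a TokenPruner compatible with the call sites in this diff.
# Assumptions (not confirmed by the source): prune_ratio is the fraction dropped per
# stage, prune_index compounds the keep fraction, and token 0 is a class token.
import torch
import torch.nn as nn


class TokenPruner(nn.Module):
    def __init__(self, prune_ratio: float, prune_index: int):
        super().__init__()
        # Keep fraction shrinks with each successive pruning stage (assumption).
        self.keep_ratio = (1.0 - prune_ratio) ** prune_index

    def forward(self, x: torch.Tensor, prune_mask: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        num_keep = max(1, int(N * self.keep_ratio))
        # Rank non-class tokens by their attention-derived score; always keep token 0.
        scores = prune_mask[:, 1:]                                # [B, N-1]
        keep = scores.topk(num_keep - 1, dim=-1).indices + 1      # [B, num_keep-1]
        keep = torch.cat([keep.new_zeros(B, 1), keep], dim=1)     # prepend class-token index
        keep = keep.sort(dim=-1).values                           # preserve original token order
        return x.gather(1, keep.unsqueeze(-1).expand(-1, -1, C))  # [B, num_keep, C]

Sorting the kept indices preserves the original token order for the remaining blocks; treat the use of prune_index here as an assumption, since it could equally select a per-stage ratio or serve another purpose in the real implementation.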