Implementation for timm by / Copyright 2023, Fredo Guan
"""

-from typing import Any, Dict, Optional, Tuple
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

__all__ = ['DependencyViT']

+class TokenPruner(nn.Module):
+    def __init__(
+            self,
+            prune_ratio: float,
+            prune_index: int,
+    ):
+        super().__init__()
+        self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)
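+        # each pruning stage keeps (1 - k * prune_ratio) / (1 - (k - 1) * prune_ratio) of its input tokens,
+        # so after the k-th stage roughly (1 - k * prune_ratio) of the original tokens remain (up to flooring),
+        # e.g. with prune_ratio=0.16 as in the lite config below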
+
+    def forward(self, x: torch.Tensor, scores: torch.Tensor):  # [B, N, C], [B, N]
+        _, N, C = x.shape
+        # topk returns (values, indices); keep the indices of the highest-scoring tokens
+        topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False).indices  # [B, N']
+        topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C)  # [B, N', C]
+        return x.gather(1, topk_indices)
+

-# FIXME there is nearly no difference between this and stock attn, allowing sdpa to be used if a workaround can be found
class ReversedAttention(nn.Module):
    dependency_mask: Optional[torch.Tensor]

@@ -48,9 +63,9 @@ def __init__(
        self.scale = self.head_dim ** -0.5
        self.track_dependency_mask = False
        self.dependency_mask = None
-        self.head_selector_temperature = 0.1  # appendix D.1, causes nan when 0.1, 0 when 10.0
+        self.head_selector_temperature = 0.1  # appendix D.1

-        self.head_selector = nn.Linear(dim, num_heads, bias=False)
+        self.head_selector = nn.Linear(dim, num_heads, bias=False)  # FIXME is there a bias term?

        self.message_controller = Mlp(
            in_features=dim,
@@ -59,7 +74,9 @@ def __init__(
            act_layer=nn.GELU,
            bias=False,  # FIXME is there a bias term?
        )
-
+
+        self.token_pruner = None
+
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -86,8 +103,17 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
        attn = self.attn_drop(attn).transpose(-2, -1)  # this transpose prevents use of sdpa
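+        # reweight the reversed attention with the head selector weights (p) and message controller gates (m)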
        attn = attn * p * m  # [B, n_h, N, N]
        x = attn @ v
-
-        self.dependency_mask = attn.sum(1) if self.track_dependency_mask else None
-
-        x = x.transpose(1, 2).reshape(B, N, C)
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+
+        # FIXME messy way to handle
+        if self.track_dependency_mask or self.token_pruner is not None:
+            dependency_mask = attn.detach().sum(1)  # [B, N, N]
+            self.dependency_mask = dependency_mask if self.track_dependency_mask else None
+            # FIXME how to prune: candidate token scores below, pruning applied to x as [B, N, C]
+            x = self.token_pruner(x, dependency_mask.sum(-1)) if self.token_pruner else x  # dependency mask weights (sum)
+            # x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x  # dependency mask weights (abs-sum)
+            # x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x  # attn weights (abs-sum-abs-sum)
+            # x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x  # m
+
        x = self.proj(x)
@@ -161,7 +187,13 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
# FIXME verify against reference impl

class DependencyViT(VisionTransformer):
-    def __init__(self, *args, **kwargs):
+    def __init__(
+            self,
+            prune_layers: Optional[Union[List[int], Tuple[int, ...]]] = None,
+            prune_ratio: Optional[float] = None,
+            *args,
+            **kwargs
+    ):
        super().__init__(
            *args,
            **kwargs,
@@ -172,6 +204,19 @@ def __init__(self, *args, **kwargs):
            init_values=1e-6,
            fc_norm=False,
        )
+
+        if prune_layers is not None:
+            self.prune_layers = sorted(dict.fromkeys(prune_layers))
+            self.prune_ratio = prune_ratio
+
+            assert max(self.prune_layers) <= len(self.blocks), "one or more prune_layers indices exceed the model depth"
+            assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too large, len(prune_layers) * prune_ratio must be less than 1"
+
+            self.prune_layers = [x - 1 for x in self.prune_layers]  # convert 1-based layer numbers to nn.Sequential indices
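+            # with the lite config below (prune_layers=[2, 5, 8, 11], prune_ratio=0.16), roughly 84%, 68%,
+            # 52% and 36% of the original tokens remain after each successive pruning block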
+            for prune_index, layer in enumerate(self.prune_layers, 1):
+                self.blocks[layer].attn.token_pruner = TokenPruner(self.prune_ratio, prune_index)
+

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
@@ -191,6 +236,23 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)
        x = x * m.transpose(1, 3).squeeze(-1)
        return x
+
+    def track_dependency_mask(self, track: bool = True):
+        for block in self.blocks:
+            if block.attn.track_dependency_mask is not track:
+                block.attn.dependency_mask = None
+                block.attn.track_dependency_mask = track
+
+    def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int, ...]]] = None):
+        # returns a list of dependency masks for the selected layers (default: all),
+        # each [B, N, N] ([B, N', N'] for blocks that receive pruned tokens)
+        result = []
+        layers = layers if layers is not None else range(len(self.blocks))
+        for layer in layers:
+            result.append(self.blocks[layer].attn.dependency_mask)
+        return result
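+    # usage sketch (hypothetical `model` and input batch `imgs`):
+    #   model.track_dependency_mask(True)
+    #   _ = model.forward_features(imgs)
+    #   masks = model.get_dependency_mask()  # per-block [B, N, N] ([B, N', N'] where tokens were pruned)
+    #   model.track_dependency_mask(False)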
+
+


def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
@@ -212,6 +274,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:

default_cfgs = {
    'dependencyvit_tiny_patch16_224.untrained': _cfg(url=''),
+    'dependencyvit_small_patch16_224.untrained': _cfg(url=''),
+
+    'dependencyvit_lite_tiny_patch16_224.untrained': _cfg(url=''),
}


@@ -240,4 +305,10 @@ def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> Depend
def dependencyvit_small_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12)
-    model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_dependencyvit('dependencyvit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def dependencyvit_lite_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
+    model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, prune_layers=[2, 5, 8, 11], prune_ratio=0.16)
+    model = _create_dependencyvit('dependencyvit_lite_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
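+
+# e.g. timm.create_model('dependencyvit_lite_tiny_patch16_224') builds the token-pruned (lite) variant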