1- """ DependencyViT (FIXME WIP)
1+ """ DependencyViT
22
33From-scratch implementation of DependencyViT in PyTorch
44
@@ -106,19 +106,13 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N, C)
 
-        '''
-        # FIXME messy way to handle
-        if self.track_dependency_mask or self.token_pruner:
-            dependency_mask = attn.detach().sum(1)  # [B, N, N]
-            self.dependency_mask = dependency_mask if self.track_dependency_mask else None
-            # FIXME how to prune
-            x = self.token_pruner(x, dependency_mask.sum(-1)) if self.token_pruner else x  # dependency mask weights (sum)
-            #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x  # dependency mask weights (abs-sum)
-            #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x  # attn weights (abs-sum-abs-sum)
-            #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x  # m
-        '''
+
+        # FIXME absolute value?
         self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None  # [B, N, N]
 
+        # FIXME which pruning mask?
+
+        # [B, N]
         #prune_mask = attn.detach().sum(1).sum(-1)
         #prune_mask = attn.detach().sum(1).abs().sum(-1)
         #prune_mask = attn.detach().abs().sum((1, -1))
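For reference, a minimal self-contained sketch of the candidate score reductions compared in the hunk above. It assumes attn is a post-softmax attention tensor of shape [B, H, N, N]; the sizes and random input are placeholders, only the shapes and reductions mirror the diff.

    import torch

    B, H, N = 2, 4, 16
    attn = torch.softmax(torch.randn(B, H, N, N), dim=-1)

    # Head-summed attention: entry (i, j) aggregates the attention
    # weight between tokens i and j across all heads.
    dependency_mask = attn.detach().sum(1)  # [B, N, N]

    # Candidate per-token pruning scores, all reducing to [B, N]:
    prune_mask_sum = dependency_mask.sum(-1)             # plain sum over the last token axis
    prune_mask_abs = dependency_mask.abs().sum(-1)       # abs before the final reduction
    prune_mask_joint = attn.detach().abs().sum((1, -1))  # abs, then reduce heads and keys together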
@@ -196,9 +190,9 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return (x, m)
 
-# FIXME lite model variants
-# FIXME toggle and retrieve dependency masks
+
 # FIXME verify against reference impl
+# FIXME train weights that meet or exceed results from paper
 
 class DependencyViT(VisionTransformer):
     def __init__(
@@ -207,24 +201,23 @@ def __init__(
             prune_ratio: Optional[float] = None,
             *args,
             **kwargs
-    ):
+    ) -> None:
         super().__init__(
-            *args,
+            *args,
             **kwargs,
-            block_fn=DependencyViTBlock,
+            block_fn=DependencyViTBlock,
             class_token=False,
-            global_pool='avg',
-            qkv_bias=False,
-            init_values=1e-6,
+            global_pool='avg',
+            qkv_bias=False,
+            init_values=1e-6,
             fc_norm=False,
         )
 
         if prune_layers is not None:
             self.prune_layers = sorted(list(dict.fromkeys(prune_layers)))
             self.prune_ratio = prune_ratio
 
-            # FIXME reword these assertions
-            assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices are greater than model depth"
+            assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices exceed model depth"
             assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1"
 
             self.prune_layers = [x - 1 for x in self.prune_layers]  # convert 1-based layer numbers to nn.Sequential indices
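For reference, a hypothetical usage sketch of the pruning arguments. The backbone kwargs assume the standard timm VisionTransformer signature and the values are illustrative only; the comments mirror the two assertions above.

    model = DependencyViT(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        prune_layers=[3, 6, 9],  # 1-based layer numbers; max(prune_layers) must not exceed depth
        prune_ratio=0.1,         # len(prune_layers) * prune_ratio = 0.3, which must stay below 1
    )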