@@ -217,6 +217,8 @@ class ParallelScalingBlock(nn.Module):
     Based on:
       'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
     """
+    fast_attn: Final[bool]
+
     def __init__(
             self,
             dim,
@@ -232,33 +234,76 @@ def __init__(
             norm_layer=nn.LayerNorm
     ):
         super().__init__()
-        self.norm1 = norm_layer(dim)
-        self.attn = Attention(
-            dim,
-            num_heads=num_heads,
-            qkv_bias=qkv_bias,
-            qk_norm=qk_norm,
-            attn_drop=attn_drop,
-            proj_drop=drop,
-            norm_layer=norm_layer,
-        )
-        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        mlp_hidden_dim = int(mlp_ratio * dim)
+        in_proj_out_dim = mlp_hidden_dim + 3 * dim
+        out_proj_in_dim = mlp_hidden_dim + dim
+
+        self.in_norm = norm_layer(dim)
+        self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias)
+        self.in_split = [mlp_hidden_dim] + [dim] * 3
+        if qkv_bias:
+            self.register_buffer('qkv_bias', None)
+            self.register_parameter('mlp_bias', None)
+        else:
+            self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False)
+            self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim))

-        self.norm2 = norm_layer(dim)
-        self.mlp = Mlp(
-            in_features=dim,
-            hidden_features=int(dim * mlp_ratio),
-            act_layer=act_layer,
-            drop=drop,
-        )
-        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_out_proj = nn.Linear(dim, dim)
+
+        self.mlp_drop = nn.Dropout(drop)
+        self.mlp_act = act_layer()
+        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
+
+        self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
-        y1 = self.drop_path1(self.ls1(self.attn(self.norm1(x))))
-        y2 = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
-        x = x + y1 + y2
+        B, N, C = x.shape
+
+        # Combined MLP fc1 & qkv projections
+        y = self.in_norm(x)
+        if self.mlp_bias is not None:
+            # Concat constant zero-bias for qkv w/ trainable mlp_bias.
+            # Appears faster than adding to x_mlp separately
+            y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
+        else:
+            y = self.in_proj(y)
+        x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
+
+        # Dot product attention w/ qk norm
+        q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
+        k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
+        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
+        if self.fast_attn:
+            x_attn = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x_attn = attn @ v
+        x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
+        x_attn = self.attn_out_proj(x_attn)
+
+        # MLP activation, dropout, fc2
+        x_mlp = self.mlp_act(x_mlp)
+        x_mlp = self.mlp_drop(x_mlp)
+        x_mlp = self.mlp_out_proj(x_mlp)
+
+        # Add residual w/ drop path & layer scale applied
+        y = self.drop_path(self.ls(x_attn + x_mlp))
+        x = x + y
         return x


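# The new block folds the MLP's fc1 and the qkv projection into one in_proj
# matmul behind a single shared norm, runs the attention and MLP branches in
# parallel, and adds their sum back through one LayerScale/DropPath residual,
# per the ViT-22B design cited in the docstring. A minimal shape-check sketch,
# assuming a timm checkout containing this commit; module path and constructor
# signature are taken from the diff above:
import torch
from timm.models.vision_transformer import ParallelScalingBlock

block = ParallelScalingBlock(dim=768, num_heads=12, qkv_bias=False, qk_norm=True)
x = torch.randn(2, 197, 768)        # (batch, tokens, dim) for a 224px ViT-B/16
with torch.no_grad():
    out = block(x)
assert out.shape == x.shape         # residual formulation preserves token shape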
@@ -1249,6 +1294,7 @@ def _cfg(url='', **kwargs):
         hf_hub_id='timm/',
         input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),

+    'vit_base_patch16_xp_224.untrained': _cfg(url=''),
     'vit_large_patch14_xp_224.untrained': _cfg(url=''),
     'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
 })
@@ -1750,6 +1796,19 @@ def flexivit_large(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_base_patch16_xp_224(pretrained=False, **kwargs):
+    """ ViT-Base model (ViT-B/16) w/ parallel blocks and qk norm enabled.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True,
+        norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
+
+
 @register_model
 def vit_large_patch14_xp_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
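# With this commit installed, the newly registered model can be built through
# the standard timm factory. The cfg added above is '.untrained' (url=''), so
# only random weights are available; a hedged usage sketch rather than code
# from this commit:
import timm
import torch

model = timm.create_model('vit_base_patch16_xp_224', pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)                 # torch.Size([1, 1000]) with the default head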