 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Union, Tuple
+from typing import Optional, Union, Tuple
 
 import torch
 import torch.nn as nn
 
+from .config import use_fused_attn
 from .helpers import to_2tuple
+from .pos_embed import resample_abs_pos_embed
 from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_
 
@@ -27,53 +29,122 @@ class RotAttentionPool2d(nn.Module):
     NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
     train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            out_features: Optional[int] = None,
+            ref_feat_size: Union[int, Tuple[int, int]] = 7,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
+            pool_type: str = 'token',
+            class_token: bool = False,
+            drop_rate: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        assert pool_type in ('', 'token')
+        self.embed_dim = embed_dim = embed_dim or in_features
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        ref_feat_size = to_2tuple(ref_feat_size)
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.num_heads = num_heads
-        assert embed_dim % num_heads == 0
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
+        self.pool_type = pool_type.lower()
         self.scale = self.head_dim ** -0.5
-        self.pos_embed = RotaryEmbedding(self.head_dim)
+        self.fused_attn = use_fused_attn()
 
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
 
-    def forward(self, x):
-        B, _, H, W = x.shape
-        N = H * W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
 
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
 
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
 
-        qc, q = q[:, :, :1], q[:, :, 1:]
-        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
-        q = apply_rot_embed(q, sin_emb, cos_emb)
-        q = torch.cat([qc, q], dim=2)
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
 
-        kc, k = k[:, :, :1], k[:, :, 1:]
-        k = apply_rot_embed(k, sin_emb, cos_emb)
-        k = torch.cat([kc, k], dim=2)
+    def forward(self, x, pre_logits: bool = False):
+        B, _, H, W = x.shape
+        N = H * W
+        x = x.flatten(2).transpose(1, 2)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
+        rse, rce = self.pos_embed.get_embed((H, W))
+        q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+        k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = self.drop(x)
+        if pre_logits:
+            x = self._pool(x, H, W)
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        x = self._pool(x, H, W)
+        return x
 
 
 class AttentionPool2d(nn.Module):
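# --- Illustrative usage (editor's sketch, not part of the diff above) --------
# A minimal example of how the reworked RotAttentionPool2d could be driven as a
# pooling head. The import path (timm.layers) and the concrete sizes here are
# assumptions; argument names mirror the signature added in the hunk above.
import torch
from timm.layers import RotAttentionPool2d  # assumed import path

rot_pool = RotAttentionPool2d(
    in_features=768,      # channels of the incoming NCHW feature map
    num_heads=12,         # or leave num_heads=None and rely on head_dim=64
    pool_type='token',    # 'token' -> pooled token, '' -> spatial map without the token
)
feats = torch.randn(2, 768, 7, 7)          # (B, C, H, W)
pooled = rot_pool(feats)                   # (2, 768): attention-pooled, projected token
hidden = rot_pool(feats, pre_logits=True)  # (2, 768): pooled token before self.proj
# -----------------------------------------------------------------------------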
@@ -85,47 +156,123 @@ class AttentionPool2d(nn.Module):
 
     NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            feat_size: Union[int, Tuple[int, int]],
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            feat_size: Union[int, Tuple[int, int]] = 7,
+            out_features: Optional[int] = None,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
+            pool_type: str = 'token',
+            class_token: bool = False,
+            drop_rate: float = 0.,
     ):
         super().__init__()
-
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
+        assert pool_type in ('', 'token')
+        self.embed_dim = embed_dim = embed_dim or in_features
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.feat_size = to_2tuple(feat_size)
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
+        self.pool_type = pool_type
         self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
 
-        spatial_dim = self.feat_size[0] * self.feat_size[1]
-        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
+
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.q = self.k = self.v = None
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
+
+        self.init_weights()
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
 
-    def forward(self, x):
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
-        assert self.feat_size[0] == H
-        assert self.feat_size[1] == W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        x = x.flatten(2).transpose(1, 2)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
+        pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
+        x = x + pos_embed
+
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
+
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = self.drop(x)
+        if pre_logits:
+            x = self._pool(x, H, W)
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        x = self._pool(x, H, W)
+        return x
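# --- Illustrative usage (editor's sketch, not part of the diff above) --------
# A minimal example of AttentionPool2d as a classifier head. With the switch to
# resample_abs_pos_embed above, the learned position embedding is resampled to
# the incoming grid, so the old feat_size assertions no longer apply. The
# import path (timm.layers) and the concrete sizes are assumptions.
import torch
from timm.layers import AttentionPool2d  # assumed import path

attn_pool = AttentionPool2d(
    in_features=2048,     # channels of the incoming NCHW feature map
    feat_size=7,          # reference grid for the learned pos_embed
    num_heads=32,         # embed_dim defaults to in_features -> head_dim = 64
    out_features=1024,
    pool_type='',         # '' -> spatial output without the pooled token
)
feats = torch.randn(2, 2048, 14, 14)  # grid differs from feat_size; pos_embed is resampled
spatial = attn_pool(feats)            # (2, 1024, 14, 14)

# reset() re-targets the head in place, e.g. for classification:
attn_pool.reset(num_classes=1000, pool_type='token')
logits = attn_pool(feats)             # (2, 1000)
# -----------------------------------------------------------------------------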