 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Union, Tuple
+from typing import Optional, Union, Tuple
 
 import torch
 import torch.nn as nn
 
+from .config import use_fused_attn
 from .helpers import to_2tuple
+from .pos_embed import resample_abs_pos_embed
 from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_
 
@@ -27,51 +29,84 @@ class RotAttentionPool2d(nn.Module):
     NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
     train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            out_features: Optional[int] = None,
+            ref_feat_size: Union[int, Tuple[int, int]] = 7,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
     ):
         super().__init__()
         embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        ref_feat_size = to_2tuple(ref_feat_size)
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.num_heads = num_heads
-        assert embed_dim % num_heads == 0
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
         self.scale = self.head_dim ** -0.5
-        self.pos_embed = RotaryEmbedding(self.head_dim)
-
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
+        self.fused_attn = use_fused_attn()
+
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
 
     def forward(self, x):
         B, _, H, W = x.shape
         N = H * W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
-
+        x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-
-        qc, q = q[:, :, :1], q[:, :, 1:]
-        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
-        q = apply_rot_embed(q, sin_emb, cos_emb)
-        q = torch.cat([qc, q], dim=2)
-
-        kc, k = k[:, :, :1], k[:, :, 1:]
-        k = apply_rot_embed(k, sin_emb, cos_emb)
-        k = torch.cat([kc, k], dim=2)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
+
+        rse, rce = self.pos_embed.get_embed((H, W))
+        q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+        k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
         x = self.proj(x)
         return x[:, 0]
 
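For reference, a minimal usage sketch of the reworked RotAttentionPool2d from the hunk above. This is not part of the diff; the import path (timm.layers) and the 2048-channel, 7x7 feature map are assumptions for illustration only.

# Hypothetical usage sketch, not part of this diff.
# Assumes RotAttentionPool2d is exported from timm.layers once this change lands.
import torch
from timm.layers import RotAttentionPool2d

pool = RotAttentionPool2d(in_features=2048, head_dim=64, ref_feat_size=7)  # num_heads inferred as 2048 // 64 = 32
pool.init_weights()                 # explicit init added in this diff (not called automatically in __init__ here)
feats = torch.randn(2, 2048, 7, 7)  # NCHW feature map from a CNN backbone
pooled = pool(feats)                # rotary embed applied to q/k; the prepended mean token's output is returned
print(pooled.shape)                 # torch.Size([2, 2048]) since out_features defaults to in_features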
@@ -85,47 +120,90 @@ class AttentionPool2d(nn.Module):
 
     NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            feat_size: Union[int, Tuple[int, int]],
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            feat_size: Union[int, Tuple[int, int]] = 7,
+            out_features: Optional[int] = None,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
     ):
         super().__init__()
-
         embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.feat_size = to_2tuple(feat_size)
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
         self.scale = self.head_dim ** -0.5
-
-        spatial_dim = self.feat_size[0] * self.feat_size[1]
-        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+        self.fused_attn = use_fused_attn()
+
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.q = self.k = self.v = None
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
+
+        self.init_weights()
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
 
     def forward(self, x):
         B, _, H, W = x.shape
         N = H * W
-        assert self.feat_size[0] == H
-        assert self.feat_size[1] == W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
+        x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.seq_len != N:
+            pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
+        else:
+            pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
+        x = x + pos_embed
+
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
+
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
         x = self.proj(x)
         return x[:, 0]
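Similarly, a hedged sketch of the updated AttentionPool2d, exercising the new pos-embed resampling path when the input grid differs from feat_size. Again, the import location and tensor shapes are assumptions for illustration, not part of the diff.

# Hypothetical usage sketch, not part of this diff.
import torch
from timm.layers import AttentionPool2d  # assumed export path

pool = AttentionPool2d(in_features=768, feat_size=7, head_dim=64)  # seq_len = 49, num_heads = 768 // 64 = 12
out_native = pool(torch.randn(2, 768, 7, 7))     # 7x7 matches feat_size, pos_embed is used directly
out_resized = pool(torch.randn(2, 768, 12, 12))  # 12x12 triggers resample_abs_pos_embed interpolation
print(out_native.shape, out_resized.shape)       # torch.Size([2, 768]) torch.Size([2, 768])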