using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Linq;

namespace Tensorflow.Keras.Layers
{
    public class MultiHeadAttention : Layer
    {
        static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";

        MultiHeadAttentionArgs args;
        Shape _query_shape = null;
        Shape _key_shape = null;
        Shape _value_shape = null;
        bool _built_from_signature = false;
        EinsumDense _query_dense = null;
        EinsumDense _key_dense = null;
        EinsumDense _value_dense = null;
        EinsumDense _output_dense = null;
        string _dot_product_equation = "";
        string _combine_equation = "";
        Softmax _softmax = null;
        Dropout _dropout_layer = null;

        /// <summary>
        /// Builds einsum equations for the attention computation.
        /// Query, key, value inputs after projection are expected to have a shape of
        /// `(bs, [non-attention dims], [attention dims], num_heads, channels)`.
        /// `bs` and `[non-attention dims]` are treated as `[batch dims]`.
        ///
        /// <para>
        /// The attention operations can be generalized as:
        /// </para>
        /// <para>
        /// (1) Query-key dot product:
        /// `([batch dims], [query attention dims], num_heads, channels), ([batch dims],
        /// [key attention dims], num_heads, channels) -> ([batch dims],
        /// num_heads, [query attention dims], [key attention dims])`
        /// </para><para>
        /// (2) Combination:
        /// `([batch dims], num_heads, [query attention dims], [key attention dims]),
        /// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims],
        /// [query attention dims], num_heads, channels)`
        /// </para>
        /// </summary>
        /// <param name="rank">Rank of query, key, value tensors.</param>
        /// <param name="attn_axes">List/tuple of axes, `[-1, rank)`, that attention will be applied to.</param>
        /// <returns>A tuple of `(dot_product_equation, combine_equation, attn_scores_rank)`.</returns>
        public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
        {
            var target_notation = _CHR_IDX.Substring(0, rank);
            // `batch_dims` includes the head dim.
            var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
            var letter_offset = rank;
            var source_notation = "";
            for (int i = 0; i < rank; i++)
            {
                if (batch_dims.Contains(i) || i == rank - 1)
                    source_notation += target_notation[i];
                else
                {
                    source_notation += _CHR_IDX[letter_offset];
                    letter_offset += 1;
                }
            }
            var product_notation = new string((from i in batch_dims
                                               select (char)(int)target_notation[i])
                                              .Concat(from i in attn_axes.as_int_list()
                                                      select (char)(int)target_notation[i])
                                              .Concat(from i in attn_axes.as_int_list()
                                                      select source_notation[i])
                                              .ToArray());
            var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
            var attn_scores_rank = product_notation.Count();
            var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
            return (dot_product_equation, combine_equation, attn_scores_rank);
        }
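        // Worked example (illustrative only; the values follow from the logic above):
        // for the default case of rank-4 projected query/key tensors, i.e. [B, T, N, H]
        // with attention applied over axis 1, the call
        //     _build_attention_equation(rank: 4, attn_axes: new Shape(1))
        // yields
        //     dot_product_equation == "aecd,abcd->acbe"   // (key, query) -> scores [B, N, T, S]
        //     combine_equation     == "acbe,aecd->abcd"   // (scores, value) -> output [B, T, N, H]
        //     attn_scores_rank     == 4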

        /// <summary>
        /// Builds an einsum equation for projections inside multi-head attention.
        /// </summary>
        public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims)
        {
            char _char;
            var input_str = "";
            var kernel_str = "";
            var output_str = "";
            var bias_axes = "";
            var letter_offset = 0;
            foreach (var i in range(free_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                output_str += _char;
            }
            letter_offset += free_dims;
            foreach (var i in range(bound_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                kernel_str += _char;
            }
            letter_offset += bound_dims;
            foreach (var i in range(output_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                kernel_str += _char;
                output_str += _char;
                bias_axes += _char;
            }
            var equation = $"{input_str},{kernel_str}->{output_str}";
            return (equation, bias_axes, output_str.Count());
        }
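        // Worked example (illustrative only): for the query projection built in
        // `_build_from_signature` from a rank-3 [B, T, dim] query (free_dims: 2,
        // bound_dims: 1, output_dims: 2), the call
        //     _build_proj_equation(2, bound_dims: 1, output_dims: 2)
        // yields
        //     equation  == "abc,cde->abde"   // [B, T, dim] x [dim, N, H] -> [B, T, N, H]
        //     bias_axes == "de"
        //     rank      == 4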

        static Shape _get_output_shape(int output_rank, Shape known_last_dims)
            => (from _ in range(output_rank - known_last_dims.rank)
                select -1).Concat(known_last_dims.as_int_list()).ToArray();

        public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args)
        {
            this.args = args;
        }

        public void _build_from_signature(Tensor query, Tensor value, Tensor key = null)
            => this._build_from_signature(query.shape, value.shape, key?.shape);

        public void _build_from_signature(Shape query, Shape value, Shape key = null)
        {
            this._built_from_signature = true;
            this._query_shape = query;
            this._value_shape = value;
            if (key == null)
                this._key_shape = this._value_shape;
            else
                this._key_shape = key;
            // Any setup work performed only once should happen in an `init_scope`
            // to avoid creating symbolic Tensors that will later pollute any eager
            // operations.
            tf_with(tf.init_scope(), _ =>
            {
                var free_dims = this._query_shape.rank - 1;
                var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    free_dims, bound_dims: 1, output_dims: 2);
                this._query_dense = _get_dense(einsum_equation,
                    _get_output_shape(output_rank - 1, (this.args.NumHeads, this.args.KeyDim)),
                    this.args.UseBias ? bias_axes : null,
                    "query");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._key_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._key_dense = _get_dense(einsum_equation,
                    _get_output_shape(output_rank - 1, (this.args.NumHeads, this.args.KeyDim)),
                    this.args.UseBias ? bias_axes : null,
                    "key");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._value_dense = _get_dense(einsum_equation,
                    _get_output_shape(output_rank - 1, (this.args.NumHeads, this.args.ValueDim ?? -1)),
                    this.args.UseBias ? bias_axes : null,
                    "value");
                // Builds the attention computations for multi-head dot product attention.
                // These computations could be wrapped into the keras attention layer once
                // it supports multi-head einsum computations.
                this._build_attention(output_rank);
                this._output_dense = _build_output_dense(free_dims, "attention_output");
            });
            this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense);
        }

        EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name)
            => new EinsumDense(new EinsumDenseArgs()
            {
                Equation = equation,
                OutputShape = output_shape,
                BiasAxes = bias_axes,
                Name = name,
                KernelInitializer = this.args.KernelInitializer,
                BiasInitializer = this.args.BiasInitializer,
                KernelRegularizer = this.args.KernelRegularizer,
                BiasRegularizer = this.args.BiasRegularizer,
                KernelConstraint = this.args.KernelConstraint,
                BiasConstraint = this.args.BiasConstraint
            });

        EinsumDense _build_output_dense(int free_dims, string name)
        {
            if (this.args.OutputShape == null)
                this.args.OutputShape = new(this._query_shape[-1]);
            var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape));
            return _get_dense(einsum_equation,
                _get_output_shape(output_rank - 1, this.args.OutputShape),
                this.args.UseBias ? bias_axes : null,
                name);
        }

        void _build_attention(int rank)
        {
            if (this.args.AttentionAxis == null)
                this.args.AttentionAxis = new(range(1, rank - 2).ToArray());
            int attn_scores_rank;
            (this._dot_product_equation, this._combine_equation, attn_scores_rank)
                = _build_attention_equation(rank, this.args.AttentionAxis);
            var norm_axes = range(attn_scores_rank - len(this.args.AttentionAxis),
                                  attn_scores_rank).ToArray();
            this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes });
            this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout });
        }
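        // Illustration (assumes rank-3 [B, T, dim] / [B, S, dim] inputs, so the rank passed
        // in here is 4): AttentionAxis defaults to range(1, 4 - 2) = [1] (the query's T axis),
        // the attention scores have shape [B, N, T, S] (rank 4), and norm_axes = [3],
        // i.e. the softmax is taken over the key axis S.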

        Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null)
        {
            if (attention_mask != null)
            {
                var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1;
                for (int i = 0; i < len(attention_scores.shape) - len(attention_mask.shape); i++)
                    attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis);
            }
            return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask));
        }

        public Tensors _compute_attention(
            Tensor query,
            Tensor key,
            Tensor value,
            Tensor attention_mask = null,
            bool training = false)
        {
            // Note: Applying scalar multiply at the smaller end of einsum improves
            // XLA performance, but may introduce slight numeric differences in
            // the Transformer attention head.
            query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
            // Take the dot product between "query" and "key" to get the raw
            // attention scores.
            var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
            attention_scores = this._masked_softmax(attention_scores, attention_mask);
            // This is actually dropping out entire tokens to attend to, which might
            // seem a bit unusual, but is taken from the original Transformer paper.
            var attention_scores_dropout = this._dropout_layer.Apply(attention_scores, training: training);
            // `context_layer` = [B, T, N, H]
            var attention_output = tf.linalg.einsum(this._combine_equation, (attention_scores_dropout, value));
            return (attention_output, attention_scores);
        }

        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
        {
            Tensors _inp;
            Tensor _mask = null;

            int count = inputs.Count();
            if (count < 2 || count > 5) throw new ValueError(
                $"{this.name} layer accepts an inputs list of length 2 to 5, " +
                $"namely [query, value, (key), (attention_mask), (return_attention_scores)]. " +
                $"Received length: {count}.");

            bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
            bool return_attention_scores = false;
            if (has_bool)
            {
                return_attention_scores = (bool)inputs[count - 1];
                count--;
            }

            switch (count)
            {
                case 2:
                    _inp = (inputs[0], inputs[1]);
                    break;
                case 3:
                    if (inputs[2].shape[-1] != inputs[0].shape[-1])
                        _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    else
                    {
                        _inp = (inputs[0], inputs[1]);
                        _mask = inputs[2];
                    }
                    break;
                case 4:
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    _mask = inputs[3];
                    break;
                default:
                    throw new ValueError(); // TODO: add a description for this error.
            }

            return call(_inp, _mask, training, return_attention_scores);
        }

        protected Tensors call(Tensors inputs,
            Tensor attention_mask,
            bool? training = null,
            bool return_attention_scores = false)
        {
            var (query, value, key) = (inputs[0], inputs[1], inputs.Length == 3 ? inputs[2] : null);
            if (!this._built_from_signature)
                this._build_from_signature(query: query, value: value, key: key);
            if (key == null)
                key = value;

            // TODO: Add RaggedTensor support
            //var query_is_ragged = query is tf.RaggedTensor;
            //if (query_is_ragged)
            //{
            //    var query_lengths = query.nested_row_lengths();
            //    query = query.to_tensor();
            //}
            //var key_is_ragged = key is tf.RaggedTensor;
            //var value_is_ragged = value is tf.RaggedTensor;
            //if (key_is_ragged && value_is_ragged)
            //{
            //    // Ensure they have the same shape.
            //    var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape());
            //    key = key.to_tensor(shape: bounding_shape);
            //    value = value.to_tensor(shape: bounding_shape);
            //}
            //else if (key_is_ragged)
            //{
            //    key = key.to_tensor(shape: tf.shape(value));
            //}
            //else if (value_is_ragged)
            //{
            //    value = value.to_tensor(shape: tf.shape(key));
            //}

            // N = `num_attention_heads`
            // H = `size_per_head`
            // `query` = [B, T, N, H]
            query = this._query_dense.Apply(query);
            // `key` = [B, S, N, H]
            key = this._key_dense.Apply(key);
            // `value` = [B, S, N, H]
            value = this._value_dense.Apply(value);
            var (attention_output, attention_scores) = this._compute_attention(query, key, value, attention_mask, training ?? false);
            attention_output = this._output_dense.Apply(attention_output);

            //if (query_is_ragged)
            //{
            //    attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths);
            //}

            if (return_attention_scores)
                return (attention_output, attention_scores);
            return attention_output;
        }
    }
}
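
// Usage sketch (illustrative only; assumes the layer is constructed directly from its
// args object and that `Tensors` accepts a params list of tensors — in application code
// the layer would normally be obtained through the keras layers API):
//
//     var mha = new MultiHeadAttention(new MultiHeadAttentionArgs
//     {
//         NumHeads = 2,
//         KeyDim = 16
//     });
//     var query = tf.random.normal((4, 8, 32));   // [B, T, dim]
//     var value = tf.random.normal((4, 10, 32));  // [B, S, dim]
//     var output = mha.Apply(new Tensors(query, value));  // -> [B, T, dim]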