@@ -55,6 +55,9 @@ public static (string, string, int) _build_attention_equation(int rank, Shape at
 {
     var target_notation = _CHR_IDX.Substring(0, rank);
     // `batch_dims` includes the head dim.
+    // batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
+    // Since range(rank) yields an IEnumerable (0, 1, 2, ...) in which every element equals its index,
+    // IEnumerable.Except can stand in for np.delete, which is unavailable here.
     var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
     var letter_offset = rank;
     var source_notation = "";
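For readers unfamiliar with the LINQ substitution described in the added comment, here is a minimal standalone sketch of the `Except`-based equivalent of `np.delete` (the values below are illustrative, not from the patch):

    using System;
    using System.Linq;

    // rank = 4, attn_axes = (2,): drop the attention axes plus the trailing head dim.
    int rank = 4;
    int[] attnAxes = { 2 };
    // Works because every element of Enumerable.Range(0, rank) equals its own index,
    // so removing *values* with Except matches np.delete's removal by *index*.
    var batchDims = Enumerable.Range(0, rank).Except(attnAxes.Concat(new[] { rank - 1 }));
    Console.WriteLine(string.Join(", ", batchDims));   // prints: 0, 1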
@@ -68,14 +71,14 @@ public static (string, string, int) _build_attention_equation(int rank, Shape at
             letter_offset += 1;
         }
     }
-    var product_notation = "".Insert(0, new string((from i in batch_dims
-                                                     select (char)(int)target_notation[i]).Concat(
-
-                                                     from i in attn_axes.as_int_list()
-                                                     select (char)(int)target_notation[i]).Concat(
-
-                                                     from i in attn_axes.as_int_list()
-                                                     select source_notation[i]).ToArray()));
+    var product_notation = new string((from i in batch_dims
+                                        select target_notation[i]).Concat(
+
+                                        from i in attn_axes.as_int_list()
+                                        select target_notation[i]).Concat(
+
+                                        from i in attn_axes.as_int_list()
+                                        select source_notation[i]).ToArray());
     var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
     var attn_scores_rank = product_notation.Count();
     var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
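As a sanity check on the rewritten `product_notation`, these are the values `_build_attention_equation` yields for a typical 4-D query attending over axis 2 (worked out by hand from the code above, assuming `_CHR_IDX` is the lowercase alphabet; not output copied from a run):

    // rank = 4, attn_axes = (2,)
    // target_notation      = "abcd"              (query / attention output)
    // source_notation      = "abed"              (key/value: axis 2 gets the fresh letter 'e')
    // product_notation     = "abce"              (batch dims + target attn axis + source attn axis)
    // dot_product_equation = "abed,abcd->abce"   (attn_scores_rank = 4)
    // combine_equation     = "abce,abed->abcd"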
@@ -163,7 +166,7 @@ public void _build_from_signature(Shape query, Shape value, Shape key = null)
         this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
     this._value_dense = _get_dense(einsum_equation,
         _get_output_shape(output_rank - 1,
-            (this.args.NumHeads, this.args.ValueDim ?? -1)),
+            (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)),
         this.args.UseBias ? bias_axes : null,
         "value");
     // Builds the attention computations for multi-head dot product attention.
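The one-line change above swaps the `-1` placeholder for a fallback to `KeyDim`, mirroring Keras, where `value_dim` defaults to `key_dim` when not supplied. A tiny illustration of the fallback with hypothetical values (the variable names below are not from the patch):

    int keyDim = 64;
    int? valueDim = null;                    // caller did not set ValueDim
    int valueHeadSize = valueDim ?? keyDim;  // 64: per-head width of the value projection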
@@ -235,7 +238,7 @@ public Tensors _compute_attention(
     // Note: Applying scalar multiply at the smaller end of einsum improves
     // XLA performance, but may introduce slight numeric differences in
     // the Transformer attention head.
-    query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
+    query = tf.multiply(query, 1f / tf.sqrt(tf.convert_to_tensor((float)this.args.KeyDim)));
     // Take the dot product between "query" and "key" to get the raw
     // attention scores.
     var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
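The changed line is the 1/sqrt(key_dim) factor of scaled dot-product attention, now built from float tensors instead of a double `Math.Sqrt`. A minimal sketch of the same scaling outside the layer, assuming the usual `using static Tensorflow.Binding;` alias and an existing `query` tensor (both assumptions, not from the patch):

    using static Tensorflow.Binding;

    var keyDim = 64;
    // 1 / sqrt(d_k); for keyDim = 64 this is 0.125f.
    var scale = 1f / tf.sqrt(tf.convert_to_tensor((float)keyDim));
    // Scaling the (smaller) query tensor up front avoids rescaling the larger
    // attention-score tensor produced by the einsum:
    // query = tf.multiply(query, scale);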
@@ -273,7 +276,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
             _inp = (inputs[0], inputs[1]);
             break;
         case 3:
-            if (inputs[2].shape[-1] != inputs[0].shape[-1])
+            if (inputs[2].shape[-1] == inputs[1].shape[-1])
                 _inp = new[] { inputs[0], inputs[1], inputs[2] };
             else
             {
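With three inputs the layer must decide whether `inputs[2]` is a key tensor or an attention mask; the corrected check compares its last dimension against the value (`inputs[1]`) rather than the query. A hypothetical shape walk-through using plain ints rather than tensors (all names and sizes are illustrative, not from the patch):

    // query: (batch, targetLen, dim), value: (batch, sourceLen, dim)
    int dim = 16, sourceLen = 10;

    int keyLastDim = dim;                         // a key tensor ends in dim, like value
    bool keyTakesKeyPath = keyLastDim == dim;     // true  -> treated as (query, value, key)

    int maskLastDim = sourceLen;                  // an attention mask is (batch, targetLen, sourceLen)
    bool maskTakesKeyPath = maskLastDim == dim;   // false -> falls through to the mask branch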