
Commit 2e94ed3

Andy-Elizabeth-mouse authored and Oceania2018 committed
multi-head attention
1 parent 9a57947 commit 2e94ed3

5 files changed: +409 −161 lines changed

src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs

Lines changed: 2 additions & 1 deletion
```diff
@@ -12,7 +12,8 @@ public Softmax ( SoftmaxArgs args ) : base(args) {
     axis = args.axis;
 }
 protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
-    Tensor x = inputs;
+    Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9)
+                                  : inputs;
     Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true)));
     Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true);
     return tf.div(e, s);
```
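With this change, `Softmax.Call` accepts an optional second tensor that is treated as a mask and folded into the logits before normalization. A minimal sketch of the call pattern follows; `scores` and `mask` are hypothetical same-shaped tensors, and the tuple-to-`Tensors` conversion is the same one `MultiHeadAttention._masked_softmax` uses below. Note that the additive-mask trick conventionally uses a large negative constant (upstream Keras uses `-1e9`) so masked positions vanish after the exponential:

```csharp
// Sketch only: softmax with and without a 0/1 mask as the second input.
var softmax = new Softmax(new SoftmaxArgs { axis = new[] { -1 } });
var masked = softmax.Apply((scores, mask));  // two inputs: logits + mask
var plain  = softmax.Apply(scores);          // one input: ordinary softmax
```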

src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs

Lines changed: 1 addition & 1 deletion
```diff
@@ -120,7 +120,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
 
     int count = inputs.Count();
     if (count < 2 || count > 6) throw new ValueError(
-        $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
+        $"{ this.name } layer accepts inputs list of length from 2 to 6, " +
         $"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." +
         $"Received length: {count}.");
 
```
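The widened upper bound makes room for the optional trailing `return_attention_scores` flag already listed in the message: the accepted layout is `[query, value, (key), (query_mask), (value_mask), (return_attention_scores)]`, i.e. from two required entries up to six in total. The new `MultiHeadAttention.Call` below uses the same trailing-bool convention.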

src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs

Lines changed: 352 additions & 0 deletions
```csharp
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Linq;

namespace Tensorflow.Keras.Layers
{
    public class MultiHeadAttention : Layer
    {
        static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";

        MultiHeadAttentionArgs args;
        Shape _query_shape = null;
        Shape _key_shape = null;
        Shape _value_shape = null;
        bool _built_from_signature = false;
        EinsumDense _query_dense = null;
        EinsumDense _key_dense = null;
        EinsumDense _value_dense = null;
        EinsumDense _output_dense = null;
        string _dot_product_equation = "";
        string _combine_equation = "";
        Softmax _softmax = null;
        Dropout _dropout_layer = null;

        /// <summary>
        /// Builds einsum equations for the attention computation.
        /// Query, key, value inputs after projection are expected to have the shape
        /// `(bs, [non-attention dims], [attention dims], num_heads, channels)`.
        /// `bs` and `[non-attention dims]` are treated as `[batch dims]`.
        ///
        /// <para>
        /// The attention operations can be generalized:
        /// </para>
        /// <para>
        /// (1) Query-key dot product:
        /// `([batch dims], [query attention dims], num_heads, channels), ([batch dims],
        /// [key attention dims], num_heads, channels) -> ([batch dims],
        /// num_heads, [query attention dims], [key attention dims])`
        /// </para><para>
        /// (2) Combination:
        /// `([batch dims], num_heads, [query attention dims], [key attention dims]),
        /// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims],
        /// [query attention dims], num_heads, channels)`
        /// </para>
        /// </summary>
        /// <param name="rank">Rank of query, key, value tensors.</param>
        /// <param name="attn_axes">List/tuple of axes, `[-1, rank)`,
        /// that attention will be applied to.</param>
        /// <returns>The two einsum equations and the rank of the attention scores.</returns>
        public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
        {
            var target_notation = _CHR_IDX.Substring(0, rank);
            // `batch_dims` includes the head dim.
            var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
            var letter_offset = rank;
            var source_notation = "";
            for (int i = 0; i < rank; i++)
            {
                if (batch_dims.Contains(i) || i == rank - 1)
                    source_notation += target_notation[i];
                else
                {
                    source_notation += _CHR_IDX[letter_offset];
                    letter_offset += 1;
                }
            }
            var product_notation = new string((from i in batch_dims
                                               select target_notation[i]).Concat(
                                              from i in attn_axes.as_int_list()
                                              select target_notation[i]).Concat(
                                              from i in attn_axes.as_int_list()
                                              select source_notation[i]).ToArray());
            var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
            var attn_scores_rank = product_notation.Count();
            var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
            return (dot_product_equation, combine_equation, attn_scores_rank);
        }
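        // Worked example (hand-derived sketch of the code above): for projected
        // tensors of rank 4 with attn_axes = { 1 }, i.e. query [B, T, N, H] and
        // key/value [B, S, N, H]:
        //   dot_product_equation = "aecd,abcd->acbe"  // key x query -> scores [B, N, T, S]
        //   combine_equation     = "acbe,aecd->abcd"  // scores x value -> [B, T, N, H]
        //   attn_scores_rank     = 4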

        /// <summary>
        /// Builds an einsum equation for projections inside multi-head attention.
        /// </summary>
        public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims)
        {
            char _char;
            var input_str = "";
            var kernel_str = "";
            var output_str = "";
            var bias_axes = "";
            var letter_offset = 0;
            foreach (var i in range(free_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                output_str += _char;
            }
            letter_offset += free_dims;
            foreach (var i in range(bound_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                kernel_str += _char;
            }
            letter_offset += bound_dims;
            foreach (var i in range(output_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                kernel_str += _char;
                output_str += _char;
                bias_axes += _char;
            }
            var equation = $"{input_str},{kernel_str}->{output_str}";
            return (equation, bias_axes, output_str.Count());
        }
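        // Worked example (hand-derived sketch): for a rank-3 query [B, T, F],
        // _build_proj_equation(free_dims: 2, bound_dims: 1, output_dims: 2) yields
        //   equation  = "abc,cde->abde"  // [B, T, F] x [F, N, H] -> [B, T, N, H]
        //   bias_axes = "de", output rank = 4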

        static Shape _get_output_shape(int output_rank, Shape known_last_dims)
            => (from _ in range(output_rank - known_last_dims.rank)
                select -1).Concat(known_last_dims.as_int_list()).ToArray();

        public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args)
        {
            this.args = args;
        }

        public void _build_from_signature(Tensor query, Tensor value, Tensor key = null)
            => this._build_from_signature(query.shape, value.shape, key?.shape);

        public void _build_from_signature(Shape query, Shape value, Shape key = null)
        {
            this._built_from_signature = true;
            this._query_shape = query;
            this._value_shape = value;
            if (key == null)
                this._key_shape = this._value_shape;
            else
                this._key_shape = key;
            // Any setup work performed only once should happen in an `init_scope`
            // to avoid creating symbolic Tensors that will later pollute any eager
            // operations.
            tf_with(tf.init_scope(), _ =>
            {
                var free_dims = this._query_shape.rank - 1;
                var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    free_dims, bound_dims: 1, output_dims: 2);
                this._query_dense = _get_dense(einsum_equation,
                                               _get_output_shape(output_rank - 1,
                                                   (this.args.NumHeads, this.args.KeyDim)),
                                               this.args.UseBias ? bias_axes : null,
                                               "query");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._key_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._key_dense = _get_dense(einsum_equation,
                                             _get_output_shape(output_rank - 1,
                                                 (this.args.NumHeads, this.args.KeyDim)),
                                             this.args.UseBias ? bias_axes : null,
                                             "key");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._value_dense = _get_dense(einsum_equation,
                                               _get_output_shape(output_rank - 1,
                                                   (this.args.NumHeads, this.args.ValueDim ?? -1)),
                                               this.args.UseBias ? bias_axes : null,
                                               "value");
                // Builds the attention computations for multi-head dot product attention.
                // These computations could be wrapped into the keras attention layer once
                // it supports multi-head einsum computations.
                this._build_attention(output_rank);
                this._output_dense = _build_output_dense(free_dims, "attention_output");
            });
            this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense);
        }

        EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name)
            => new EinsumDense(new EinsumDenseArgs()
            {
                Equation = equation,
                OutputShape = output_shape,
                BiasAxes = bias_axes,
                Name = name,
                KernelInitializer = this.args.KernelInitializer,
                BiasInitializer = this.args.BiasInitializer,
                KernelRegularizer = this.args.KernelRegularizer,
                BiasRegularizer = this.args.BiasRegularizer,
                KernelConstraint = this.args.KernelConstraint,
                BiasConstraint = this.args.BiasConstraint
            });

        EinsumDense _build_output_dense(int free_dims, string name)
        {
            if (this.args.OutputShape == null) this.args.OutputShape = new(this._query_shape[-1]);
            var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape));
            return _get_dense(einsum_equation,
                              _get_output_shape(output_rank - 1, this.args.OutputShape),
                              this.args.UseBias ? bias_axes : null,
                              name);
        }

        void _build_attention(int rank)
        {
            if (this.args.AttentionAxis == null)
                this.args.AttentionAxis = new(range(1, rank - 2).ToArray());
            int attn_scores_rank;
            (this._dot_product_equation, this._combine_equation, attn_scores_rank)
                = _build_attention_equation(rank, this.args.AttentionAxis);
            var norm_axes = range(attn_scores_rank - len(this.args.AttentionAxis),
                                  attn_scores_rank).ToArray();
            this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes });
            this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout });
        }

        Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null)
        {
            if (attention_mask != null)
            {
                var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1;
                // Compute the number of missing dims once up front: `expand_dims`
                // grows the mask's rank on every pass, so re-evaluating the bound
                // inside the loop would stop expanding too early.
                var rank_diff = len(attention_scores.shape) - len(attention_mask.shape);
                for (int i = 0; i < rank_diff; i++)
                    attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis);
            }
            return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask));
        }

        public Tensors _compute_attention(
            Tensor query,
            Tensor key,
            Tensor value,
            Tensor attention_mask = null,
            bool training = false)
        {
            // Note: Applying scalar multiply at the smaller end of einsum improves
            // XLA performance, but may introduce slight numeric differences in
            // the Transformer attention head.
            query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
            // Take the dot product between "query" and "key" to get the raw
            // attention scores.
            var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
            attention_scores = this._masked_softmax(attention_scores, attention_mask);
            // This is actually dropping out entire tokens to attend to, which might
            // seem a bit unusual, but is taken from the original Transformer paper.
            var attention_scores_dropout = this._dropout_layer.Apply(attention_scores, training: training);
            // `context_layer` = [B, T, N, H]
            var attention_output = tf.linalg.einsum(this._combine_equation, (attention_scores_dropout, value));
            return (attention_output, attention_scores);
        }
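        // Math sketch: per head, _compute_attention evaluates
        //   Attention(Q, K, V) = dropout(softmax(Q K^T / sqrt(KeyDim), mask)) V
        // with the einsum equations generalizing the usual [T, S] score matrix
        // to arbitrary attention axes.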

        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
        {
            Tensors _inp;
            Tensor _mask = null;

            int count = inputs.Count();
            if (count < 2 || count > 5) throw new ValueError(
                $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
                $"namely [query, value, (key), (attention_mask), (return_attention_scores)]." +
                $"Received length: {count}.");

            bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
            bool return_attention_scores = false;
            if (has_bool)
            {
                return_attention_scores = (bool)inputs[count - 1];
                count--;
            }

            switch (count)
            {
                case 2:
                    _inp = (inputs[0], inputs[1]);
                    break;
                case 3:
                    // Three inputs are either [query, value, key] or
                    // [query, value, attention_mask]; disambiguate by the last dim.
                    if (inputs[2].shape[-1] != inputs[0].shape[-1])
                        _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    else
                    {
                        _inp = (inputs[0], inputs[1]);
                        _mask = inputs[2];
                    }
                    break;
                case 4:
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    _mask = inputs[3];
                    break;
                default:
                    throw new ValueError(); // TODO: add a description for this error
            }

            return call(_inp, _mask, training, return_attention_scores);
        }
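        // Usage sketch: a trailing bool tensor toggles returning the scores, e.g.
        //   mha.Apply((query, value, tf.constant(true)))
        // pops the flag first, then dispatches on the remaining inputs.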

        protected Tensors call(Tensors inputs,
                               Tensor attention_mask,
                               bool? training = null,
                               bool return_attention_scores = false)
        {
            var (query, value, key) = (inputs[0], inputs[1], inputs.Length == 3 ? inputs[2] : null);
            if (!this._built_from_signature)
                this._build_from_signature(query: query, value: value, key: key);
            if (key == null)
                key = value;

            // TODO: Add RaggedTensor support
            //var query_is_ragged = query is tf.RaggedTensor;
            //if (query_is_ragged)
            //{
            //    var query_lengths = query.nested_row_lengths();
            //    query = query.to_tensor();
            //}
            //var key_is_ragged = key is tf.RaggedTensor;
            //var value_is_ragged = value is tf.RaggedTensor;
            //if (key_is_ragged && value_is_ragged)
            //{
            //    // Ensure they have the same shape.
            //    var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape());
            //    key = key.to_tensor(shape: bounding_shape);
            //    value = value.to_tensor(shape: bounding_shape);
            //}
            //else if (key_is_ragged)
            //{
            //    key = key.to_tensor(shape: tf.shape(value));
            //}
            //else if (value_is_ragged)
            //{
            //    value = value.to_tensor(shape: tf.shape(key));
            //}

            // N = `num_attention_heads`
            // H = `size_per_head`
            // `query` = [B, T, N, H]
            query = this._query_dense.Apply(query);
            // `key` = [B, S, N, H]
            key = this._key_dense.Apply(key);
            // `value` = [B, S, N, H]
            value = this._value_dense.Apply(value);
            var (attention_output, attention_scores) = this._compute_attention(query, key, value, attention_mask, training ?? false);
            attention_output = this._output_dense.Apply(attention_output);

            //if (query_is_ragged)
            //{
            //    attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths);
            //}

            if (return_attention_scores)
                return (attention_output, attention_scores);
            return attention_output;
        }
    }
}
```
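A minimal end-to-end sketch of the new layer (not part of the commit): it assumes `MultiHeadAttentionArgs` exposes settable `NumHeads` and `KeyDim` properties, which the layer reads above, and it leans on `tf.zeros` and the tuple-to-`Tensors` conversions used elsewhere in this diff:

```csharp
// Two-head self-attention over an 8-token sequence of 16 features.
var mha = new MultiHeadAttention(new MultiHeadAttentionArgs
{
    NumHeads = 2,
    KeyDim = 4
});
var query = tf.zeros((1, 8, 16));        // [batch, target_seq, features]
var value = tf.zeros((1, 8, 16));        // [batch, source_seq, features]
var output = mha.Apply((query, value));  // key omitted: defaults to value
// output shape: [1, 8, 16]; OutputShape defaults to the query's last dim.
```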
