
Commit 430ac93

Andy-Elizabeth-mouse authored and Oceania2018 committed
Add Attention support and test it
1 parent ce3ddb2 commit 430ac93

File tree

7 files changed: +800 -0 lines changed
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
namespace Tensorflow.Keras.ArgsDefinition
{
    public class AttentionArgs : BaseDenseAttentionArgs
    {

        /// <summary>
        /// If `true`, will create a scalar variable to scale the attention scores.
        /// </summary>
        public bool use_scale { get; set; } = false;

        /// <summary>
        /// Function to use to compute attention scores, one of
        /// `{"dot", "concat"}`. `"dot"` refers to the dot product between the query
        /// and key vectors. `"concat"` refers to the hyperbolic tangent of the
        /// concatenation of the query and key vectors.
        /// </summary>
        public string score_mode { get; set; } = "dot";

    }
}
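
A minimal usage sketch (not part of this diff) showing how these args can be fed straight into the new layer. It relies only on the `Attention(AttentionArgs)` constructor and the tuple-based `Apply` that appear later in this commit; the `query` and `value` tensors are placeholders you would supply yourself:

// Build a scaled dot-product attention layer from an explicit args object.
var attention = new Tensorflow.Keras.Layers.Attention(new AttentionArgs
{
    use_scale = true,      // adds a trainable scalar multiplier on the scores
    score_mode = "dot"     // or "concat" for the tanh-of-sums variant
});
// query: [batch_size, Tq, dim], value: [batch_size, Tv, dim]
var context = attention.Apply((query, value));   // -> [batch_size, Tq, dim]
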
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
namespace Tensorflow.Keras.ArgsDefinition
{
    public class BaseDenseAttentionArgs : LayerArgs
    {

        /// <summary>
        /// Boolean. Set to `true` for decoder self-attention. Adds a mask such
        /// that position `i` cannot attend to positions `j > i`. This prevents the
        /// flow of information from the future towards the past.
        /// </summary>
        public bool causal { get; set; } = false;

        /// <summary>
        /// Float between 0 and 1. Fraction of the units to drop for the
        /// attention scores.
        /// </summary>
        public float dropout { get; set; } = 0f;

    }
}
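
These base args are what `AttentionArgs` above inherits, so a decoder-style self-attention configuration is just a matter of setting them on the derived type. A hedged sketch (the layer wiring and tensor names are illustrative, not from this commit):

// Causal self-attention: query and value are the same tensor, and
// future positions are masked out of the attention scores.
var self_attention = new Tensorflow.Keras.Layers.Attention(new AttentionArgs
{
    causal = true,     // position i cannot attend to positions j > i
    dropout = 0.1f     // drop 10% of the attention scores during training
});
var encoded = self_attention.Apply((x, x));   // x: [batch_size, T, dim]
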

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 9 additions & 0 deletions
@@ -275,6 +275,15 @@ public List<IVariableV1> weights
                 weights.AddRange(non_trainable_weights);
                 return weights;
             }
+            set
+            {
+                if (weights.Count() != value.Count()) throw new ValueError(
+                    $"You called `set_weights` on layer \"{this.name}\" " +
+                    $"with a weight list of length {len(value)}, but the layer was " +
+                    $"expecting {len(weights)} weights.");
+                foreach (var (this_w, v_w) in zip(weights, value))
+                    this_w.assign(v_w, read_value: true);
+            }
         }

         public virtual LayerArgs get_config()
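
The new setter mirrors Keras's `set_weights` semantics: the assigned list must match the layer's existing weight count, and each variable is overwritten in place via `assign`. A rough sketch of how it could be exercised; the Dense layers, the `tf.ones` shape overload, and the build-by-Apply step are assumptions, not part of this diff:

// Hedged sketch: copy the weights of one built layer into another with the same structure.
var source = keras.layers.Dense(10);
var target = keras.layers.Dense(10);
var x = tf.ones((1, 4));
source.Apply(x);   // building the layers creates their kernel/bias variables
target.Apply(x);
target.weights = source.weights;   // assigns variable-by-variable; throws ValueError on a count mismatch
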
Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Dot-product attention layer, a.k.a. Luong-style attention.
    /// Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
    /// shape `[batch_size, Tv, dim]` and `key` tensor of shape
    /// `[batch_size, Tv, dim]`. The calculation follows the steps:
    /// <para>
    /// 1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
    /// product: `scores = tf.matmul(query, key, transpose_b=True)`.
    /// </para>
    /// <para>
    /// 2. Use scores to calculate a distribution with shape
    /// `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
    /// </para>
    /// <para>
    /// 3. Use `distribution` to create a linear combination of `value` with
    /// shape `[batch_size, Tq, dim]`:
    /// `return tf.matmul(distribution, value)`.
    /// </para>
    /// </summary>
    /// <example>
    /// <code>
    /// // Variable-length int sequences.
    /// var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
    /// var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
    /// // Embedding lookup.
    /// var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64);
    /// // Query embeddings of shape [batch_size, Tq, dimension].
    /// var query_embeddings = token_embedding.Apply(query_input);
    /// // Value embeddings of shape [batch_size, Tv, dimension].
    /// var value_embeddings = token_embedding.Apply(value_input);
    /// // CNN layer.
    /// var cnn_layer = keras.layers.Conv1D(
    ///     filters: 100,
    ///     kernel_size: 4,
    ///     // Use 'same' padding so outputs have the same shape as inputs.
    ///     padding: "same");
    /// var cnn_layer2 = keras.layers.Conv1D(
    ///     filters: 100,
    ///     kernel_size: 4,
    ///     // Use 'same' padding so outputs have the same shape as inputs.
    ///     padding: "same");
    /// // Query encoding of shape [batch_size, Tq, filters].
    /// var query_seq_encoding = cnn_layer.Apply(query_embeddings);
    /// // Value encoding of shape [batch_size, Tv, filters].
    /// var value_seq_encoding = cnn_layer.Apply(value_embeddings);
    /// // Query-value attention of shape [batch_size, Tq, filters].
    /// var query_value_attention_seq = keras.layers.Attention().Apply(
    ///     (query_seq_encoding, value_seq_encoding));
    /// // Reduce over the sequence axis to produce encodings of shape
    /// // [batch_size, filters].
    /// var query_encoding = keras.layers.GlobalAveragePooling1D().Apply(
    ///     query_seq_encoding);
    /// var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply(
    ///     query_value_attention_seq);
    /// // Concatenate query and document encodings to produce a DNN input layer.
    /// var input_layer = keras.layers.Concatenate().Apply(
    ///     (query_encoding, query_value_attention));
    /// // Add DNN layers, and create Model.
    /// // ...
    /// </code>
    /// </example>
    public class Attention : BaseDenseAttention
    {

        public IVariableV1 concat_score_weight;

        public IVariableV1 scale;

        AttentionArgs args;

        string score_mode { get => args.score_mode; }

        bool use_scale { get => args.use_scale; }

        public Attention(AttentionArgs args) : base(args)
        {
            this.args = args;
            if (!new List<string> {
                    "dot",
                    "concat"
                }.Contains(this.score_mode))
                throw new ValueError($"Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]");
        }

        // Creates variable when `use_scale` is True or `score_mode` is `concat`.
        protected override void build(Tensors inputs) {
            if (this.use_scale)
                this.scale = this.add_weight(name: "scale",
                    shape: 1,
                    initializer: tf.ones_initializer,
                    dtype: this.DType,
                    trainable: true);
            else
                this.scale = null;

            if (this.score_mode == "concat")
                this.concat_score_weight = this.add_weight(name: "concat_score_weight",
                    shape: 1,
                    initializer: tf.ones_initializer,
                    dtype: this.DType,
                    trainable: true);
            else
                this.concat_score_weight = null;
            base.build(inputs);
        }

        /// <summary>
        /// Calculates attention scores as a query-key dot product.
        /// </summary>
        /// <param name="query">Query tensor of shape `[batch_size, Tq, dim]`.</param>
        /// <param name="key">Key tensor of shape `[batch_size, Tv, dim]`.</param>
        /// <returns>Tensor of shape `[batch_size, Tq, Tv]`.</returns>
        public override Tensor _calculate_scores(Tensor query, Tensor key)
        {
            Tensor scores = null;
            if (this.score_mode == "dot")
            {
                //scores = tf.matmul(query, key, transpose_b: true);
                //scores = tf.matmul(tf.squeeze(query), tf.squeeze(key), transpose_b: true);
                scores = tf.linalg.einsum("bij,bkj->bik", (query, key));
                if (this.scale != null)
                    scores *= this.scale.AsTensor();
            } else if (this.score_mode == "concat") {
                // Reshape tensors to enable broadcasting.
                // Reshape into [batch_size, Tq, 1, dim].
                var q_reshaped = tf.expand_dims(query, axis: -2);
                // Reshape into [batch_size, 1, Tv, dim].
                var k_reshaped = tf.expand_dims(key, axis: -3);
                if (this.scale != null)
                    scores = this.concat_score_weight.AsTensor() *
                        tf.reduce_sum(tf.tanh(this.scale.AsTensor() * (q_reshaped + k_reshaped)), axis: -1);
                else
                    scores = this.concat_score_weight.AsTensor() *
                        tf.reduce_sum(tf.tanh(q_reshaped + k_reshaped), axis: -1);
            }
            return scores;
        }

        public override LayerArgs get_config() => this.args;
        //var config = new Dictionary<object, object> {
        //    {
        //        "use_scale",
        //        this.use_scale},
        //    {
        //        "score_mode",
        //        this.score_mode}};
        //var base_config = base.get_config();
        //return new dict(base_config.items().ToList() + config.items().ToList());
    }
}
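
To make the `_calculate_scores` contract concrete, here is a small trace of the "dot" path on toy values. It is not part of the commit; the nested-array `tf.constant` overload is an assumption, while the einsum call is the same one the layer uses, and `BaseDenseAttention` is what subsequently applies the softmax and the weighted sum over `value`:

// query: [batch_size=1, Tq=2, dim=2], key: [batch_size=1, Tv=2, dim=2]
var query = tf.constant(new float[,,] { { { 1f, 0f }, { 0f, 1f } } });
var key   = tf.constant(new float[,,] { { { 1f, 1f }, { 0f, 2f } } });
// Same contraction as the layer: scores[b, i, k] = sum_j query[b, i, j] * key[b, k, j]
var scores = tf.linalg.einsum("bij,bkj->bik", (query, key));   // [[[1, 0], [1, 2]]]
// The base class then softmaxes over the last axis and matmuls the result with `value`.
var distribution = tf.nn.softmax(scores, axis: -1);
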

0 commit comments
