Commit 424ac06

Andy-Elizabeth-mouse authored and Oceania2018 committed
fixed bugs
1 parent 2e94ed3 commit 424ac06

File tree

src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs

3 files changed: +26 −21 lines changed

3 files changed

+26
-21
lines changed

src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs

Lines changed: 14 additions & 11 deletions
@@ -55,6 +55,9 @@ public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
 {
     var target_notation = _CHR_IDX.Substring(0, rank);
     // `batch_dims` includes the head dim.
+    // batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
+    // Since range(rank) is an IEnumerable like (0, 1, 2, ...) whose index equals its value,
+    // use IEnumerable.Except instead of np.delete, which is unavailable in .NET.
     var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
     var letter_offset = rank;
     var source_notation = "";
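To make the `Except` substitution concrete, here is a minimal standalone sketch (my example, using plain LINQ in place of the repo's `range` and `concat` helpers):

    using System;
    using System.Linq;

    // range(rank) is the identity sequence (index == value), so deleting entries
    // *at* the indices in attn_axes + (rank - 1,) is the same as filtering those
    // *values* out, which is exactly what Except does.
    var rank = 4;
    var attn_axes = new[] { 1 };
    var batch_dims = Enumerable.Range(0, rank)
                               .Except(attn_axes.Concat(new[] { rank - 1 }))
                               .ToArray();
    Console.WriteLine(string.Join(", ", batch_dims));   // prints: 0, 2

One caveat this relies on: Except treats its arguments as sets and drops duplicates, which is harmless here only because range(rank) never repeats a value.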
@@ -68,14 +71,14 @@ public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
             letter_offset += 1;
         }
     }
-    var product_notation = "".Insert(0, new string((from i in batch_dims
-                                                    select (char)(int)target_notation[i]).Concat(
-
-                                                    from i in attn_axes.as_int_list()
-                                                    select (char)(int)target_notation[i]).Concat(
-
-                                                    from i in attn_axes.as_int_list()
-                                                    select source_notation[i]).ToArray()));
+    var product_notation = new string((from i in batch_dims
+                                       select target_notation[i]).Concat(
+
+                                       from i in attn_axes.as_int_list()
+                                       select target_notation[i]).Concat(
+
+                                       from i in attn_axes.as_int_list()
+                                       select source_notation[i]).ToArray());
     var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
     var attn_scores_rank = product_notation.Count();
     var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
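For reference, a worked example of the equations this method builds, assuming the usual rank-4 projected query of shape (batch, target, num_heads, head_dim) and attn_axes = (1,) (my walkthrough of the code above, not part of the commit):

    // target_notation      = "abcd"
    // batch_dims           = { 0, 2 }   // range(4) minus attn axis 1 and head dim 3
    // source_notation      = "aecd"     // 'e' stands in for the attended target axis
    // product_notation     = "acbe"     // batch dims, then target attn axes, then source attn axes
    // dot_product_equation = "aecd,abcd->acbe"
    // combine_equation     = "acbe,aecd->abcd"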
@@ -163,7 +166,7 @@ public void _build_from_signature(Shape query, Shape value, Shape key = null)
         this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
     this._value_dense = _get_dense(einsum_equation,
         _get_output_shape(output_rank - 1,
-            (this.args.NumHeads, this.args.ValueDim ?? -1)),
+            (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)),
         this.args.UseBias ? bias_axes : null,
         "value");
     // Builds the attention computations for multi-head dot product attention.
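For context: Keras' MultiHeadAttention defaults value_dim to key_dim when it is not supplied, so `this.args.ValueDim ?? this.args.KeyDim` mirrors the reference behavior, whereas the old `?? -1` injected a placeholder dimension into the value projection's output shape.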
@@ -235,7 +238,7 @@ public Tensors _compute_attention(
     // Note: Applying scalar multiply at the smaller end of einsum improves
     // XLA performance, but may introduce slight numeric differences in
     // the Transformer attention head.
-    query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
+    query = tf.multiply(query, 1f / tf.sqrt(tf.convert_to_tensor((float)this.args.KeyDim)));
     // Take the dot product between "query" and "key" to get the raw
     // attention scores.
     var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
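A minimal sketch of the dtype issue the new line sidesteps (my reading of the change; only tf calls that already appear in this diff are used):

    // Scaled dot-product attention divides by sqrt(key_dim). Math.Sqrt returns a
    // double, so `1d / Math.Sqrt(...)` is a float64 scalar, and multiplying it
    // into a float32 query risks a dtype mismatch or a silent upcast. Building
    // the scale as a float32 tensor keeps the whole product in float32.
    var key_dim = 2;
    var query = tf.ones((3, 4, 2, 2));                               // float32 by default
    var scale = 1f / tf.sqrt(tf.convert_to_tensor((float)key_dim));  // float32 scalar tensor
    query = tf.multiply(query, scale);                               // stays float32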
@@ -273,7 +276,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
         _inp = (inputs[0], inputs[1]);
         break;
     case 3:
-        if (inputs[2].shape[-1] != inputs[0].shape[-1])
+        if (inputs[2].shape[-1] == inputs[1].shape[-1])
            _inp = new[] { inputs[0], inputs[1], inputs[2] };
        else
        {
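In words, the corrected branch (my reading): with three inputs, the third tensor is taken to be a key when its feature dimension matches value's, and an attention mask otherwise. Using the shapes from the unit test below:

    // query: (B, 4, 8), value: (B, 2, 8)
    // inputs[2] shaped (B, 2, 8) -> last dim matches value's -> used as key
    // inputs[2] shaped (B, 4, 2) -> no match                 -> used as attention mask
    // The old check compared against the query's last dim and inverted the test,
    // so a mask (last dim 2 != 8) would have been misrouted as a key.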

src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs

Lines changed: 6 additions & 6 deletions
@@ -228,7 +228,7 @@ public static (Shape, Shape, Shape) _analyze_split_string(Match split_string,
     Shape output_shape,
     bool left_elided = false)
 {
-    List<long> bias_shape;
+    List<int> bias_shape;
     Dictionary<char, int> output_dim_map;
     Dictionary<char, int> input_dim_map;

@@ -275,8 +275,8 @@ public static (Shape, Shape, Shape) _analyze_split_string(Match split_string,
     var input_shape_at_dim = input_shape[input_dim_map[dim]];
     if (output_dim_map.TryGetValue(dim, out int index))
     {
-        var output_shape_at_dim = output_shape[index];
-        if (output_shape_at_dim != input_shape_at_dim)
+        var output_shape_at_dim = _output_shape[index];
+        if (output_shape_at_dim != -1 && output_shape_at_dim != input_shape_at_dim)
            throw new ValueError($"Input shape and output shape do not match at shared dimension '{dim}'. " +
                $"Input shape is {input_shape_at_dim}, " +
                $"and output shape is {output_shape[output_dim_map[dim]]}.");
@@ -299,7 +299,7 @@ public static (Shape, Shape, Shape) _analyze_split_string(Match split_string,
     if (input_dim_map.ContainsKey(dim))
         weight_shape.append(input_shape[input_dim_map[dim]]);
     else if (output_dim_map.ContainsKey(dim))
-        weight_shape.append(output_shape[output_dim_map[dim]]);
+        weight_shape.append(_output_shape[output_dim_map[dim]]);
     else throw new ValueError($"Weight dimension '{dim}' did not have a match in " +
         $"either the input spec '{input_spec}' " +
         $"or the output spec '{output_spec}'. " +
@@ -310,7 +310,7 @@ public static (Shape, Shape, Shape) _analyze_split_string(Match split_string,
 {
     var num_left_elided = left_elided ? elided : 0;
     var idx_map = output_spec.Select((_char, i) => (i, _char))
-        .ToDictionary(_ => _._char, _ => output_shape[_.i + num_left_elided]);
+        .ToDictionary(_ => _._char, _ => _output_shape[_.i + num_left_elided]);
     foreach (var _char in bias_axes)
         if (!output_spec.Contains(_char))
             throw new ValueError($"Bias dimension '{_char}' was requested," +
@@ -327,7 +327,7 @@ public static (Shape, Shape, Shape) _analyze_split_string(Match split_string,
     else bias_shape = null;

     return (weight_shape.ToArray(),
-        (bias_shape ?? new List<long>()).ToArray(),
+        (bias_shape ?? new List<int>()).ToArray(),
         _output_shape.ToArray());
     }
 }

test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs

Lines changed: 6 additions & 4 deletions
@@ -151,19 +151,21 @@ public void test_calculate_scores_multi_dim_concat()
 [TestMethod]
 public void test_masked_attention()
 {
+    var batch_size = 3;
+
     var query = keras.Input(shape: (4, 8));
     var value = keras.Input(shape: (2, 8));
     var mask_tensor = keras.Input(shape: (4, 2));
     var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
     attention_layer.Apply(new[] { query, value, mask_tensor });

-    var from_data = 10 * np.random.randn(3, 4, 8);
-    var to_data = 10 * np.random.randn(3, 2, 8);
+    var from_data = 10 * np.random.randn(batch_size, 4, 8);
+    var to_data = 10 * np.random.randn(batch_size, 2, 8);

-    var mask_data = np.random.randint(2, size: (3, 4, 2));
+    var mask_data = np.random.randint(2, size: (batch_size, 4, 2));
     var masked_output_data = attention_layer.Apply(new[] { from_data, to_data, mask_data });

-    var null_mask_data = np.ones((3, 4, 2));
+    var null_mask_data = np.ones((batch_size, 4, 2));
     var unmasked_output_data = attention_layer.Apply(new[] { from_data, to_data, null_mask_data });

     Assert.AreNotEqual(masked_output_data, unmasked_output_data);
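A shape recap of what the test exercises (restating the code above, nothing new):

    // query/from_data: (batch_size, 4, 8), value/to_data: (batch_size, 2, 8)
    // mask: (batch_size, 4, 2) -- one 0/1 entry per (target, source) position
    // mask_data randomly zeroes some positions while null_mask_data (all ones)
    // blocks none, so the two outputs must differ, which AreNotEqual asserts.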
