Skip to content

Commit 9a57947

Browse files
Andy-Elizabeth-mouseOceania2018
authored andcommitted
Add EinsumDense support and simply test it
1 parent 430ac93 commit 9a57947

File tree

6 files changed

+469
-1
lines changed

6 files changed

+469
-1
lines changed

src/TensorFlowNET.Core/APIs/tf.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public List<T> get_collection<T>(string key, string scope = "")
4444
/// When eager execution is enabled, code inside an init_scope block runs with
4545
/// eager execution enabled even when tracing a `tf.function`.
4646
/// </summary>
47-
public void init_scope()
47+
public ops.NameScope init_scope()
4848
=> ops.init_scope();
4949

5050
/// <summary>
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using System;
2+
using static Tensorflow.Binding;
3+
4+
namespace Tensorflow.Keras.ArgsDefinition
5+
{
6+
public class EinsumDenseArgs : LayerArgs
7+
{
8+
/// <summary>
9+
/// An equation describing the einsum to perform. This equation must
10+
/// be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or
11+
/// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis
12+
/// expression sequence.
13+
/// </summary>
14+
public string Equation { get; set; }
15+
16+
/// <summary>
17+
/// The expected shape of the output tensor (excluding the batch
18+
/// dimension and any dimensions represented by ellipses). You can specify
19+
/// None for any dimension that is unknown or can be inferred from the input
20+
/// shape.
21+
/// </summary>
22+
public Shape OutputShape { get; set; }
23+
24+
/// <summary>
25+
/// A string containing the output dimension(s) to apply a bias to.
26+
/// Each character in the `bias_axes` string should correspond to a character
27+
/// in the output portion of the `equation` string.
28+
/// </summary>
29+
public string BiasAxes { get; set; } = null;
30+
31+
/// <summary>
32+
/// Activation function to use.
33+
/// </summary>
34+
public Activation Activation { get; set; }
35+
36+
/// <summary>
37+
/// Initializer for the `kernel` weights matrix.
38+
/// </summary>
39+
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
40+
41+
/// <summary>
42+
/// Initializer for the bias vector.
43+
/// </summary>
44+
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
45+
46+
/// <summary>
47+
/// Regularizer function applied to the `kernel` weights matrix.
48+
/// </summary>
49+
public IRegularizer KernelRegularizer { get; set; }
50+
51+
/// <summary>
52+
/// Regularizer function applied to the bias vector.
53+
/// </summary>
54+
public IRegularizer BiasRegularizer { get; set; }
55+
56+
/// <summary>
57+
/// Constraint function applied to the `kernel` weights matrix.
58+
/// </summary>
59+
public Action KernelConstraint { get; set; }
60+
61+
/// <summary>
62+
/// Constraint function applied to the bias vector.
63+
/// </summary>
64+
public Action BiasConstraint { get; set; }
65+
}
66+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using static Tensorflow.Binding;
3+
4+
namespace Tensorflow.Keras.ArgsDefinition
5+
{
6+
public class MultiHeadAttentionArgs : LayerArgs
7+
{
8+
public int NumHeads { get; set; }
9+
public int KeyDim { get; set; }
10+
public int? ValueDim { get; set; } = null;
11+
public float Dropout { get; set; } = 0f;
12+
public bool UseBias { get; set; } = true;
13+
public Shape OutputShape { get; set; } = null;
14+
public Shape AttentionAxis { get; set; } = null;
15+
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
16+
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
17+
public IRegularizer KernelRegularizer { get; set; } = null;
18+
public IRegularizer BiasRegularizer { get; set; } = null;
19+
public Action KernelConstraint { get; set; } = null;
20+
public Action BiasConstraint { get; set; } = null;
21+
}
22+
}

0 commit comments

Comments
 (0)