Skip to content

Commit 316ad5f

Browse files
authored
Merge pull request #280 from acifonelli/master
Add `BatchMatMul` operation
2 parents c62d9ee + 0f4cb1a commit 316ad5f

File tree

5 files changed

+122
-2
lines changed

5 files changed

+122
-2
lines changed

src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,5 +11,8 @@ public static Tensor diag(Tensor diagonal, string name = null)
1111

1212
/// <summary>
/// Multiplies matrix `a` by matrix `b`, producing `a` * `b`.
/// </summary>
/// <param name="a">Left-hand matrix.</param>
/// <param name="b">Right-hand matrix.</param>
/// <param name="transpose_a">If true, `a` is transposed before multiplication.</param>
/// <param name="transpose_b">If true, `b` is transposed before multiplication.</param>
/// <returns>The matrix product tensor.</returns>
public static Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
    => gen_math_ops.mat_mul(a, b, transpose_a, transpose_b);

/// <summary>
/// Multiplies the trailing 2-D slices of `x` and `y` as a batch of
/// matrix multiplications (see `gen_math_ops.batch_mat_mul`).
/// </summary>
/// <param name="x">Left-hand batch of matrices.</param>
/// <param name="y">Right-hand batch of matrices.</param>
/// <param name="adj_x">If true, the slices of `x` are adjointed first.</param>
/// <param name="adj_y">If true, the slices of `y` are adjointed first.</param>
/// <returns>The batched matrix product tensor.</returns>
public static Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false)
    => gen_math_ops.batch_mat_mul(x, y, adj_x, adj_y);
1417
}
1518
}

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,11 @@ public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
153153
return new Tensor[] { grad_a, grad_b };
154154
}
155155

156+
/// <summary>
/// Gradient for the `BatchMatMul` operation. Not implemented yet.
/// </summary>
/// <param name="op">The forward `BatchMatMul` operation.</param>
/// <param name="grads">Incoming gradients with respect to the op's outputs.</param>
/// <returns>Never returns; always throws.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
// NOTE(review): unlike the sibling gradients in this file (e.g. _MeanGrad,
// which carries [RegisterGradient("Mean")]), this method has no
// [RegisterGradient("BatchMatMul")] attribute, so it is never looked up by
// the gradient registry — confirm whether the omission is intentional.
public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads)
{
    // Descriptive message instead of a bare NotImplementedException, so the
    // failure is attributable when it surfaces deep inside gradient construction.
    throw new NotImplementedException("math_grad._BatchMatMul: gradient for BatchMatMul is not implemented.");
}
160+
156161
[RegisterGradient("Mean")]
157162
public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
158163
{

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -471,6 +471,41 @@ public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool
471471
return _op.outputs[0];
472472
}
473473

474+
/// <summary>
/// Batched matrix multiplication: multiplies corresponding 2-D slices
/// of `x` and `y`.
/// </summary>
/// <remarks>
/// On the native (DLL) side `BatchMatMul` is folded into the `MatMul`
/// kernel, but its attribute list differs (`adj_x`/`adj_y` rather than
/// `transpose_a`/`transpose_b`), so a dedicated wrapper is required for
/// `_apply_op_helper` to receive the correct argument list.
///
/// For inputs of rank > 2 the leading rank - 2 dimensions are treated as
/// batch dimensions and must agree between the two tensors; an ordinary
/// matrix multiplication is then applied over the remaining 2 dimensions.
///
/// Example:
///   x has shape (3, 6, 12); y has shape (3, 12, 6)
///   batch_mat_mul(x, y) has shape (3, 6, 6)
/// </remarks>
/// <param name="x">Left-hand batch of matrices.</param>
/// <param name="y">Right-hand batch of matrices.</param>
/// <param name="adj_x">If true, the slices of `x` are adjointed first.</param>
/// <param name="adj_y">If true, the slices of `y` are adjointed first.</param>
/// <param name="name">Optional name for the created operation.</param>
/// <returns>The batched matrix product tensor.</returns>
public static Tensor batch_mat_mul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
{
    var keywords = new { x, y, adj_x, adj_y };
    var op = _op_def_lib._apply_op_helper("BatchMatMul", name, args: keywords);

    return op.outputs[0];
}
508+
474509
/// <summary>
475510
/// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
476511
/// </summary>

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -497,6 +497,25 @@ public static Tensor matmul(Tensor a, Tensor b,
497497
return result;
498498
}
499499

500+
/// <summary>
/// Multiplies corresponding 2-D slices of `x` and `y` in batch, wrapping
/// `gen_math_ops.batch_mat_mul` in a name scope.
/// </summary>
/// <param name="x">Left-hand batch of matrices (rank >= 2).</param>
/// <param name="y">Right-hand batch of matrices (rank >= 2, batch dims matching `x`).</param>
/// <param name="adj_x">If true, the slices of `x` are adjointed first.</param>
/// <param name="adj_y">If true, the slices of `y` are adjointed first.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The batched matrix product tensor.</returns>
public static Tensor batch_matmul(Tensor x, Tensor y,
    bool adj_x = false, bool adj_y = false,
    string name = null)
{
    Tensor result = null;

    with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope =>
    {
        name = scope;

        // Fix: the converted inputs were previously named "a"/"b"
        // (copy-pasted from matmul); name them after the actual parameters.
        x = ops.convert_to_tensor(x, name: "x");
        y = ops.convert_to_tensor(y, name: "y");

        result = gen_math_ops.batch_mat_mul(x, y, adj_x, adj_y, name);
    });

    return result;
}
518+
500519
/// <summary>
501520
/// Returns the complex conjugate of a complex number.
502521
/// </summary>

test/TensorFlowNET.Examples/BasicOperations.cs

Lines changed: 60 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -91,11 +91,69 @@ public bool Run()
9191
// graph: the two constants and matmul.
9292
//
9393
// The output of the op is returned in 'result' as a numpy `ndarray` object.
94-
return with(tf.Session(), sess =>
94+
using (sess = tf.Session())
9595
{
9696
var result = sess.run(product);
9797
Console.WriteLine(result.ToString()); // ==> [[ 12.]]
98-
return result.Data<int>()[0] == 12;
98+
};
99+
100+
// `BatchMatMul` is actually embedded into the `MatMul` operation on the tensorflow.dll side. Every time we ask
101+
// for a multiplication between matrices with rank > 2, the first rank - 2 dimensions are checked to be consistent
102+
// across the two matrices and a common matrix multiplication is done on the residual 2 dimensions.
103+
//
104+
// np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3, 3)
105+
// array([[[1, 2, 3],
106+
// [4, 5, 6],
107+
// [7, 8, 9]],
108+
//
109+
// [[1, 2, 3],
110+
// [4, 5, 6],
111+
// [7, 8, 9]],
112+
//
113+
// [[1, 2, 3],
114+
// [4, 5, 6],
115+
// [7, 8, 9]]])
116+
var firstTensor = tf.convert_to_tensor(
117+
np.reshape(
118+
np.array<float>(1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9),
119+
3, 3, 3));
120+
//
121+
// np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]).reshape(3,3,2)
122+
// array([[[0, 1],
123+
// [0, 1],
124+
// [0, 1]],
125+
//
126+
// [[0, 1],
127+
// [0, 0],
128+
// [1, 0]],
129+
//
130+
// [[1, 0],
131+
// [1, 0],
132+
// [1, 0]]])
133+
var secondTensor = tf.convert_to_tensor(
134+
np.reshape(
135+
np.array<float>(0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0),
136+
3, 3, 2));
137+
var batchMul = tf.batch_matmul(firstTensor, secondTensor);
138+
var checkTensor = np.array<float>(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0);
139+
return with(tf.Session(), sess =>
140+
{
141+
var result = sess.run(batchMul);
142+
Console.WriteLine(result.ToString());
143+
//
144+
// ==> array([[[0, 6],
145+
// [0, 15],
146+
// [0, 24]],
147+
//
148+
// [[ 3, 1],
149+
// [ 6, 4],
150+
// [ 9, 7]],
151+
//
152+
// [[ 6, 0],
153+
// [15, 0],
154+
// [24, 0]]])
155+
return np.reshape(result, 18)
156+
.array_equal(checkTensor);
99157
});
100158
}
101159

0 commit comments

Comments
 (0)