@@ -91,11 +91,69 @@ public bool Run()
9191 // graph: the two constants and matmul.
9292 //
9393 // The output of the op is returned in 'result' as a numpy `ndarray` object.
94- return with ( tf . Session ( ) , sess =>
94+ using ( sess = tf . Session ( ) )
9595 {
9696 var result = sess . run ( product ) ;
9797 Console . WriteLine ( result . ToString ( ) ) ; // ==> [[ 12.]]
98- return result . Data < int > ( ) [ 0 ] == 12 ;
98+ } ;
99+
100+ // `BatchMatMul` is actually embedded into the `MatMul` operation on the tensorflow.dll side. Every time we ask
101+ // for a multiplication between matrices with rank > 2, the first rank - 2 dimensions are checked to be consistent
102+ // across the two matrices and a common matrix multiplication is done on the residual 2 dimensions.
103+ //
104+ // np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3, 3)
105+ // array([[[1, 2, 3],
106+ // [4, 5, 6],
107+ // [7, 8, 9]],
108+ //
109+ // [[1, 2, 3],
110+ // [4, 5, 6],
111+ // [7, 8, 9]],
112+ //
113+ // [[1, 2, 3],
114+ // [4, 5, 6],
115+ // [7, 8, 9]]])
116+ var firstTensor = tf . convert_to_tensor (
117+ np . reshape (
118+ np . array < float > ( 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ) ,
119+ 3 , 3 , 3 ) ) ;
120+ //
121+ // np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]).reshape(3,3,2)
122+ // array([[[0, 1],
123+ // [0, 1],
124+ // [0, 1]],
125+ //
126+ // [[0, 1],
127+ // [0, 0],
128+ // [1, 0]],
129+ //
130+ // [[1, 0],
131+ // [1, 0],
132+ // [1, 0]]])
133+ var secondTensor = tf . convert_to_tensor (
134+ np . reshape (
135+ np . array < float > ( 0 , 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ) ,
136+ 3 , 3 , 2 ) ) ;
137+ var batchMul = tf . batch_matmul ( firstTensor , secondTensor ) ;
138+ var checkTensor = np . array < float > ( 0 , 6 , 0 , 15 , 0 , 24 , 3 , 1 , 6 , 4 , 9 , 7 , 6 , 0 , 15 , 0 , 24 , 0 ) ;
139+ return with ( tf . Session ( ) , sess =>
140+ {
141+ var result = sess . run ( batchMul ) ;
142+ Console . WriteLine ( result . ToString ( ) ) ;
143+ //
144+ // ==> array([[[0, 6],
145+ // [0, 15],
146+ // [0, 24]],
147+ //
148+ // [[ 3, 1],
149+ // [ 6, 4],
150+ // [ 9, 7]],
151+ //
152+ // [[ 6, 0],
153+ // [15, 0],
154+ // [24, 0]]])
155+ return np . reshape ( result , 18 )
156+ . array_equal ( checkTensor ) ;
99157 } ) ;
100158 }
101159
0 commit comments