Skip to content

Commit d38b12e

Browse files
authored
Merge pull request #156 from evanhaldane/toArray
readArray methods
2 parents 6a48e9e + 511bbc5 commit d38b12e

File tree

3 files changed

+266
-3
lines changed

3 files changed

+266
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ So `my2DArray` is a 2D array of 2x3 size.
110110
Note that a `Tensor` can be a zero dimensional array, which is simply a scalar value.
111111

112112
```
113-
val scalar = Tensor(42.f)
113+
val scalar = Tensor(42.0f)
114114
println(scalar.shape.length) // 0
115115
```
116116

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ trait Tensors extends OpenCL {
10921092
private[compute] def doBuffer: Do[PendingBuffer[closure.JvmValue]]
10931093

10941094
/** Returns a RAII managed asynchronous task to read this [[Tensor]] into an off-heap memory,
1095-
* which is linearized in row-majoy order.
1095+
* which is linearized in row-major order.
10961096
*
10971097
* @group slow
10981098
*/
@@ -1109,14 +1109,142 @@ trait Tensors extends OpenCL {
11091109
}
11101110

11111111
/** Returns an asynchronous task to read this [[Tensor]] into a [[scala.Array]],
1112-
* which is linearized in row-majoy order.
1112+
* which is linearized in row-major order.
11131113
*
11141114
* @group slow
11151115
*/
11161116
def flatArray: Future[Array[closure.JvmValue]] = {
11171117
flatBuffer.intransitiveMap(closure.valueType.memory.toArray).run
11181118
}
11191119

1120+
// Convert flat arrays into multidimensional ones
1121+
private[Tensors] def make2DArray[A:ClassTag](flat:Array[A], shape:Array[Int]): Array[Array[A]] = {
1122+
Array.tabulate(shape(0), shape(1))((i,j) => flat(j + i*shape(1)))
1123+
}
1124+
1125+
private[Tensors] def make3DArray[A:ClassTag](flat:Array[A], shape:Array[Int]): Array[Array[Array[A]]] = {
1126+
Array.tabulate(shape(0), shape(1), shape(2))((i,j,k) => flat(k + j*shape(2) + i*shape(1)*shape(2)))
1127+
}
1128+
1129+
private[Tensors] def make4DArray[A:ClassTag](flat:Array[A], shape:Array[Int]): Array[Array[Array[Array[A]]]] = {
1130+
Array.tabulate(shape(0), shape(1), shape(2), shape(3))(
1131+
(i,j,k,l) => flat(l + k*shape(3) + j*shape(2)*shape(3) + i*shape(1)*shape(2)*shape(3)))
1132+
}
1133+
1134+
private[Tensors] def make5DArray[A:ClassTag](flat:Array[A], shape:Array[Int]): Array[Array[Array[Array[Array[A]]]]] = {
1135+
Array.tabulate(shape(0), shape(1), shape(2), shape(3), shape(4))(
1136+
(i,j,k,l,m) => flat(m + l*shape(4) + k*shape(3)*shape(4) + j*shape(2)*shape(3)*shape(4) + i*shape(1)*shape(2)*shape(3)*shape(4)))
1137+
}
1138+
1139+
private[Tensors] def make2DSeq[A:ClassTag](flat:Seq[A], shape:Seq[Int]): Seq[Seq[A]] = {
1140+
Seq.tabulate(shape(0), shape(1))((i,j) => flat(j + i*shape(1)))
1141+
}
1142+
1143+
private[Tensors] def make3DSeq[A:ClassTag](flat:Seq[A], shape:Seq[Int]): Seq[Seq[Seq[A]]] = {
1144+
Seq.tabulate(shape(0), shape(1), shape(2))((i,j,k) => flat(k + j*shape(2) + i*shape(1)*shape(2)))
1145+
}
1146+
1147+
private[Tensors] def make4DSeq[A:ClassTag](flat:Seq[A], shape:Seq[Int]): Seq[Seq[Seq[Seq[A]]]] = {
1148+
Seq.tabulate(shape(0), shape(1), shape(2), shape(3))(
1149+
(i,j,k,l) => flat(l + k*shape(3) + j*shape(2)*shape(3) + i*shape(1)*shape(2)*shape(3)))
1150+
}
1151+
1152+
private[Tensors] def make5DSeq[A:ClassTag](flat:Seq[A], shape:Seq[Int]): Seq[Seq[Seq[Seq[Seq[A]]]]] = {
1153+
Seq.tabulate(shape(0), shape(1), shape(2), shape(3), shape(4))(
1154+
(i,j,k,l,m) => flat(m + l*shape(4) + k*shape(3)*shape(4) + j*shape(2)*shape(3)*shape(4) + i*shape(1)*shape(2)*shape(3)*shape(4)))
1155+
}
1156+
1157+
/** Returns an asynchronous task to read this [[Tensor]] into a
1158+
* [[scala.Float]]
1159+
*
1160+
* @group slow
1161+
*/
1162+
def readScalar : Future[Float] = {
1163+
flatArray.map(z => z(0))
1164+
}
1165+
1166+
/** Returns an asynchronous task to read this [[Tensor]] into a
1167+
* [[scala.Array]]
1168+
*
1169+
* @group slow
1170+
*/
1171+
def read1DArray : Future[Array[Float]] = {
1172+
flatArray.map(z => z)
1173+
}
1174+
1175+
/** Returns an asynchronous task to read this [[Tensor]] into a 2D [[scala.Array]]
1176+
*
1177+
* @group slow
1178+
*/
1179+
def read2DArray : Future[Array[Array[Float]]] = {
1180+
flatArray.map(z => make2DArray(z,shape))
1181+
}
1182+
1183+
/** Returns an asynchronous task to read this [[Tensor]] into a 3D [[scala.Array]]
1184+
*
1185+
* @group slow
1186+
*/
1187+
def read3DArray : Future[Array[Array[Array[Float]]]] = {
1188+
flatArray.map(z => make3DArray(z,shape))
1189+
}
1190+
1191+
/** Returns an asynchronous task to read this [[Tensor]] into a 4D [[scala.Array]]
1192+
*
1193+
* @group slow
1194+
*/
1195+
def read4DArray : Future[Array[Array[Array[Array[Float]]]]] = {
1196+
flatArray.map(z => make4DArray(z,shape))
1197+
}
1198+
1199+
/** Returns an asynchronous task to read this [[Tensor]] into a 5D [[scala.Array]]
1200+
*
1201+
* @group slow
1202+
*/
1203+
def read5DArray : Future[Array[Array[Array[Array[Array[Float]]]]]] = {
1204+
flatArray.map(z => make5DArray(z,shape))
1205+
}
1206+
1207+
/** Returns an asynchronous task to read this [[Tensor]] into a
1208+
* [[scala.Seq]]
1209+
*
1210+
* @group slow
1211+
*/
1212+
def read1DSeq : Future[Seq[Float]] = {
1213+
flatArray.map(z => z)
1214+
}
1215+
1216+
/** Returns an asynchronous task to read this [[Tensor]] into a 2D [[scala.Seq]]
1217+
*
1218+
* @group slow
1219+
*/
1220+
def read2DSeq : Future[Seq[Seq[Float]]] = {
1221+
flatArray.map(z => make2DSeq(z,shape))
1222+
}
1223+
1224+
/** Returns an asynchronous task to read this [[Tensor]] into a 3D [[scala.Seq]]
1225+
*
1226+
* @group slow
1227+
*/
1228+
def read3DSeq : Future[Seq[Seq[Seq[Float]]]] = {
1229+
flatArray.map(z => make3DSeq(z,shape))
1230+
}
1231+
1232+
/** Returns an asynchronous task to read this [[Tensor]] into a 4D [[scala.Seq]]
1233+
*
1234+
* @group slow
1235+
*/
1236+
def read4DSeq : Future[Seq[Seq[Seq[Seq[Float]]]]] = {
1237+
flatArray.map(z => make4DSeq(z,shape))
1238+
}
1239+
1240+
/** Returns an asynchronous task to read this [[Tensor]] into a 5D [[scala.Seq]]
1241+
*
1242+
* @group slow
1243+
*/
1244+
def read5DSeq : Future[Seq[Seq[Seq[Seq[Seq[Float]]]]]] = {
1245+
flatArray.map(z => make5DSeq(z,shape))
1246+
}
1247+
11201248
/**
11211249
* @group metadata
11221250
*/

Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,142 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
263263
}
264264
.run
265265
.toScalaFuture
266+
267+
"readScalar" in doTensors
268+
.flatMap { tensors =>
269+
Do.garbageCollected(tensors.Tensor(42.0f).readScalar).map {a=>
270+
a should be(42.0f)
271+
}
272+
}
273+
.run
274+
.toScalaFuture
275+
276+
"read1DArray" in doTensors
277+
.flatMap { tensors =>
278+
Do.garbageCollected(tensors.Tensor(Array[Float](1,2)).read1DArray).map {a=>
279+
a should be(Array[Float](1,2))
280+
}
281+
}
282+
.run
283+
.toScalaFuture
284+
285+
"read2DArray" in doTensors
286+
.flatMap { tensors =>
287+
import tensors._
288+
val array = Array(Array[Float](1, 2), Array[Float](3, 4), Array[Float](5,6))
289+
Do.garbageCollected(Tensor(array).read2DArray).map {a=>
290+
a(0) should be(Array[Float](1, 2))
291+
a(1) should be(Array[Float](3, 4))
292+
a(2) should be(Array[Float](5, 6))
293+
}
294+
}
295+
.run
296+
.toScalaFuture
297+
298+
"read3DArray" in doTensors
299+
.flatMap { tensors =>
300+
import tensors._
301+
val array = Array(Array(Array[Float](1, 2), Array[Float](3, 4), Array[Float](5,6)), Array(Array[Float](7, 8), Array[Float](9, 10), Array[Float](11,12)))
302+
Do.garbageCollected(Tensor(array).read3DArray).map { a =>
303+
a(0)(0) should be(Array[Float](1,2))
304+
a(0)(1) should be(Array[Float](3,4))
305+
a(0)(2) should be(Array[Float](5,6))
306+
a(1)(0) should be(Array[Float](7,8))
307+
a(1)(1) should be(Array[Float](9,10))
308+
a(1)(2) should be(Array[Float](11,12))
309+
}
310+
}
311+
.run
312+
.toScalaFuture
313+
314+
"read4DArray" in doTensors
315+
.flatMap { tensors =>
316+
import tensors._
317+
val array = Array(Array(Array(Array[Float](1, 2), Array[Float](3, 4), Array[Float](5,6)),
318+
Array(Array[Float](7, 8), Array[Float](9, 10), Array[Float](11,12))),
319+
Array(Array(Array[Float](13, 14), Array[Float](15, 16), Array[Float](17,18)),
320+
Array(Array[Float](19, 20), Array[Float](21, 22), Array[Float](23,24))))
321+
Do.garbageCollected(Tensor(array).read4DArray).map { a =>
322+
a(0)(0)(0) should be(Array[Float](1,2))
323+
a(0)(0)(1) should be(Array[Float](3,4))
324+
a(0)(0)(2) should be(Array[Float](5,6))
325+
a(0)(1)(0) should be(Array[Float](7,8))
326+
a(0)(1)(1) should be(Array[Float](9,10))
327+
a(0)(1)(2) should be(Array[Float](11,12))
328+
a(1)(0)(0) should be(Array[Float](13,14))
329+
a(1)(0)(1) should be(Array[Float](15,16))
330+
a(1)(0)(2) should be(Array[Float](17,18))
331+
a(1)(1)(0) should be(Array[Float](19,20))
332+
a(1)(1)(1) should be(Array[Float](21,22))
333+
a(1)(1)(2) should be(Array[Float](23,24))
334+
}
335+
}
336+
.run
337+
.toScalaFuture
338+
339+
"read1DSeq" in doTensors
340+
.flatMap { tensors =>
341+
Do.garbageCollected(tensors.Tensor(Seq[Float](1,2)).read1DSeq).map {a=>
342+
a should be(Seq[Float](1,2))
343+
}
344+
}
345+
.run
346+
.toScalaFuture
347+
348+
"read2DSeq" in doTensors
349+
.flatMap { tensors =>
350+
import tensors._
351+
val seq = Seq(Seq[Float](1, 2), Seq[Float](3, 4), Seq[Float](5,6))
352+
Do.garbageCollected(Tensor(seq).read2DSeq).map {a=>
353+
a(0) should be(Seq[Float](1, 2))
354+
a(1) should be(Seq[Float](3, 4))
355+
a(2) should be(Seq[Float](5, 6))
356+
}
357+
}
358+
.run
359+
.toScalaFuture
360+
361+
"read3DSeq" in doTensors
362+
.flatMap { tensors =>
363+
import tensors._
364+
val seq = Seq(Seq(Seq[Float](1, 2), Seq[Float](3, 4), Seq[Float](5,6)), Seq(Seq[Float](7, 8), Seq[Float](9, 10), Seq[Float](11,12)))
365+
Do.garbageCollected(Tensor(seq).read3DSeq).map { a =>
366+
a(0)(0) should be(Seq[Float](1,2))
367+
a(0)(1) should be(Seq[Float](3,4))
368+
a(0)(2) should be(Seq[Float](5,6))
369+
a(1)(0) should be(Seq[Float](7,8))
370+
a(1)(1) should be(Seq[Float](9,10))
371+
a(1)(2) should be(Seq[Float](11,12))
372+
}
373+
}
374+
.run
375+
.toScalaFuture
266376

377+
"read4DSeq" in doTensors
378+
.flatMap { tensors =>
379+
import tensors._
380+
val seq = Seq(Seq(Seq(Seq[Float](1, 2), Seq[Float](3, 4), Seq[Float](5,6)),
381+
Seq(Seq[Float](7, 8), Seq[Float](9, 10), Seq[Float](11,12))),
382+
Seq(Seq(Seq[Float](13, 14), Seq[Float](15, 16), Seq[Float](17,18)),
383+
Seq(Seq[Float](19, 20), Seq[Float](21, 22), Seq[Float](23,24))))
384+
Do.garbageCollected(Tensor(seq).read4DSeq).map { a =>
385+
a(0)(0)(0) should be(Seq[Float](1,2))
386+
a(0)(0)(1) should be(Seq[Float](3,4))
387+
a(0)(0)(2) should be(Seq[Float](5,6))
388+
a(0)(1)(0) should be(Seq[Float](7,8))
389+
a(0)(1)(1) should be(Seq[Float](9,10))
390+
a(0)(1)(2) should be(Seq[Float](11,12))
391+
a(1)(0)(0) should be(Seq[Float](13,14))
392+
a(1)(0)(1) should be(Seq[Float](15,16))
393+
a(1)(0)(2) should be(Seq[Float](17,18))
394+
a(1)(1)(0) should be(Seq[Float](19,20))
395+
a(1)(1)(1) should be(Seq[Float](21,22))
396+
a(1)(1)(2) should be(Seq[Float](23,24))
397+
}
398+
}
399+
.run
400+
.toScalaFuture
401+
267402
"random" in doTensors
268403
.map { tensors =>
269404
import tensors._

0 commit comments

Comments
 (0)