Skip to content

Commit 50f141d

Browse files
authored
feat: Add transitions for fix-sized arrays (#192)
Fixes #180
1 parent a6a73fc commit 50f141d

File tree

13 files changed

+225
-49
lines changed

13 files changed

+225
-49
lines changed

core/src/fr/hammons/slinc/SetSizeArray.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class SetSizeArray[A, B <: Int] private[slinc] (private val array: Array[A])
4444
value: A
4545
)(using 0 <= C =:= true, C < B =:= true): Unit = array(constValue[C]) = value
4646

47+
def zip[C](oArr: SetSizeArray[C, B]): SetSizeArray[(A, C), B] =
48+
new SetSizeArray[(A, C), B](array.zip(oArr.array))
49+
def foreach(fn: A => Unit) = array.foreach(fn)
50+
4751
object SetSizeArray:
4852
class SetSizeArrayBuilderUnsafe[B <: Int]:
4953
def apply[A](array: Array[A]): SetSizeArray[A, B] = new SetSizeArray(array)

core/src/fr/hammons/slinc/TypeDescriptor.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,20 @@ case class SetSizeArrayDescriptor(
239239
override val argumentTransition
240240
: (TransitionModule, ReadWriteModule, Allocator) ?=> ArgumentTransition[
241241
Inner
242-
] = ???
242+
] = arg =>
243+
val mem = summon[Allocator].allocate(this, 1)
244+
summon[ReadWriteModule].write(
245+
mem,
246+
Bytes(0),
247+
this,
248+
arg
249+
)
250+
mem.asAddress
243251

244252
override val returnTransition
245-
: (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] = ???
253+
: (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] =
254+
obj =>
255+
val mem = summon[TransitionModule].addressReturn(obj)
256+
summon[ReadWriteModule].read(mem, Bytes(0), this)
246257

247258
type Inner = SetSizeArray[contained.Inner, ?]

core/test/resources/native/test.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,12 @@ EXPORTED struct_issue_175 i175_test(struct_issue_175 a, char left) {
139139
}
140140
return a;
141141
}
142+
143+
EXPORTED int* i180_test(int my_array[5]) {
144+
int i = 0;
145+
while(i < 5) {
146+
my_array[i] = my_array[i] * 2;
147+
i++;
148+
}
149+
return my_array;
150+
}

core/test/src/fr/hammons/slinc/BindingSpec.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ trait BindingSpec(val slinc: Slinc) extends ScalaCheckSuite:
5959
left: CChar
6060
): I175_Struct
6161

62+
def i180_test(
63+
input: SetSizeArray[CInt, 5]
64+
): SetSizeArray[CInt, 5]
65+
6266
test("int_identity") {
6367
val test = FSet.instance[TestLib]
6468

@@ -186,3 +190,11 @@ trait BindingSpec(val slinc: Slinc) extends ScalaCheckSuite:
186190
union.set(double)
187191
val res = test.i175_test(I175_Struct(union), 0)
188192
assertEquals(res.union.get[CDouble], double / 2)
193+
194+
property("issue 180 - can send and receive set size arrays to C functions"):
195+
val test = FSet.instance[TestLib]
196+
forAll(Gen.listOfN(5, Arbitrary.arbitrary[CInt])): list =>
197+
val arr = SetSizeArray.fromArrayUnsafe[5](list.toArray)
198+
val retArr = test.i180_test(arr)
199+
200+
retArr.zip(arr.map(_ * 2)).foreach(assertEquals(_, _))

core/test/src/fr/hammons/slinc/TransferSpec.scala

Lines changed: 106 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import scala.concurrent.Await
1010
import scala.concurrent.duration.Duration
1111
import scala.concurrent.ExecutionContext.Implicits.global
1212
import scala.reflect.ClassTag
13+
import scala.util.chaining.*
1314

1415
trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
1516
ClassTag[ThreadException]
@@ -28,7 +29,7 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
2829

2930
case class F(u: CUnion[(CInt, CFloat)]) derives Struct
3031

31-
case class G(arr: SetSizeArray[CLong, 2]) derives Struct
32+
case class G(long: CLong, arr: SetSizeArray[CLong, 2]) derives Struct
3233

3334
test("can read and write jvm ints") {
3435
Scope.global {
@@ -162,20 +163,16 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
162163
}
163164
}
164165

165-
test("varargs can be sent and retrieved"):
166+
test("varargs can receive primitive types"):
166167
Scope.confined {
167-
val vaListForVaList = VarArgsBuilder(4).build
168168
val vaList = VarArgsBuilder(
169169
4.toByte,
170170
5.toShort,
171171
6,
172172
7.toLong,
173173
2f,
174174
3d,
175-
Null[Int],
176-
A(1, 2),
177-
CLong(4: Byte),
178-
vaListForVaList
175+
Null[Int]
179176
).build
180177

181178
assertEquals(vaList.get[Byte], 4.toByte, "byte assert")
@@ -185,24 +182,118 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
185182
assertEquals(vaList.get[Float], 2f, "float assert")
186183
assertEquals(vaList.get[Double], 3d, "double assert")
187184
assertEquals(
188-
vaList.get[Ptr[Int]].mem.asAddress,
189-
Null[Int].mem.asAddress,
185+
vaList.get[Ptr[Int]],
186+
Null[Int],
190187
"ptr assert"
191188
)
189+
}
190+
191+
test("varargs can receive complex types".ignore):
192+
Scope.confined {
193+
val vaListForVaList = VarArgsBuilder(4).build
194+
val vaList = VarArgsBuilder(
195+
A(1, 2),
196+
CLong(4),
197+
A(3, 4),
198+
SetSizeArray(1, 2, 3, 4),
199+
vaListForVaList,
200+
CUnion[(CInt, CFloat)].tap(_.set(5)),
201+
// Null[Int],
202+
A(3, 4)
203+
).build
204+
192205
assertEquals(vaList.get[A], A(1, 2), "struct assert")
193206
assertEquals(vaList.get[CLong], CLong(4: Byte), "alias assert")
194-
assertEquals(vaList.get[VarArgs].get[CInt], 4)
207+
assertEquals(vaList.get[A], A(3, 4))
208+
assertEquals(
209+
vaList.get[SetSizeArray[CInt, 4]].toSeq,
210+
Seq(1, 2, 3, 4),
211+
"set size array assert"
212+
)
213+
assertEquals(
214+
vaListForVaList.get[VarArgs].get[Int],
215+
4
216+
)
217+
assertEquals(
218+
vaList.get[CUnion[(CLongLong, CFloat)]].get[CLongLong],
219+
5L,
220+
"cunion assert"
221+
)
222+
// assertEquals(
223+
// vaList.get[Ptr[Int]],
224+
// Null[Int]
225+
// )
226+
assertEquals(
227+
vaList.get[A],
228+
A(3, 4),
229+
"struct assert 2"
230+
)
195231
}
196232

197-
test("varargs can be skipped"):
233+
test("varargs can skip primitive types"):
198234
Scope.confined {
199235
val vaList = VarArgsBuilder(
200-
4.toByte,
201-
2f
236+
4: Byte,
237+
5: Short,
238+
6,
239+
7L,
240+
2f,
241+
3d,
242+
Null[Int]
202243
).build
203244

245+
val vaList2 = vaList.copy()
246+
204247
vaList.skip[Byte]
205-
assertEquals(vaList.get[Float], 2f)
248+
assertEquals(vaList.get[Short], 5: Short)
249+
vaList.skip[Int]
250+
assertEquals(vaList.get[Long], 7L)
251+
vaList.skip[Float]
252+
assertEquals(vaList.get[Double], 3d)
253+
vaList.skip[Ptr[Int]]
254+
255+
assertEquals(vaList2.get[Byte], 4: Byte)
256+
vaList2.skip[Short]
257+
assertEquals(vaList2.get[Int], 6)
258+
vaList2.skip[Long]
259+
assertEquals(vaList2.get[Float], 2f)
260+
vaList2.skip[Double]
261+
assertEquals(vaList2.get[Ptr[Int]], Null[Int])
262+
}
263+
264+
test("varargs can skip complex types".ignore):
265+
Scope.confined {
266+
val vaListForVaList = VarArgsBuilder(4, 5, 6).build
267+
val vaList = VarArgsBuilder(
268+
A(1, 2),
269+
CLong(4),
270+
vaListForVaList,
271+
CUnion[(CInt, CFloat)].tap(_.set(5)),
272+
SetSizeArray(1, 2, 3, 4)
273+
).build
274+
275+
val vaList2 = vaList.copy()
276+
277+
assertEquals(vaList.get[A], A(1, 2), "struct assert")
278+
vaList.skip[CLong]
279+
val vaList3 = vaList.get[VarArgs]
280+
assertEquals(
281+
List(vaList3.get[Int], vaList3.get[Int], vaList3.get[Int]),
282+
List(4, 5, 6),
283+
"varargs assert"
284+
)
285+
vaList.skip[CUnion[(CInt, CFloat)]]
286+
assertEquals(
287+
vaList.get[SetSizeArray[Int, 4]].toSeq,
288+
Seq(1, 2, 3, 4),
289+
"set size array assert"
290+
)
291+
292+
vaList2.skip[A]
293+
assertEquals(vaList2.get[CLong], CLong(4))
294+
vaList2.skip[VarArgs]
295+
assertEquals(vaList2.get[CUnion[(CInt, CFloat)]].get[Int], 5)
296+
vaList2.skip[SetSizeArray[Int, 4]]
206297
}
207298

208299
test("varargs can be copied and reread"):
@@ -373,7 +464,7 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
373464
}
374465

375466
test("can copy G to native memory and back"):
376-
val g = G(SetSizeArray(CLong(1), CLong(2)))
467+
val g = G(CLong(5), SetSizeArray(CLong(1), CLong(2)))
377468

378469
Scope.confined {
379470
val ptr = Ptr.copy(g)

j17/src/fr/hammons/slinc/Allocator17.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,15 @@ object Allocator17:
122122
case ms: MemorySegment => ms
123123
case _ => throw Error("base of mem was not J17 MemorySegment!!")
124124
)
125+
case (ssad: SetSizeArrayDescriptor, s: SetSizeArray[?, ?]) =>
126+
LinkageModule17.tempScope(alloc ?=>
127+
builder.vargFromAddress(
128+
C_POINTER,
129+
transitionModule17
130+
.methodArgument(ssad, s, alloc)
131+
.asInstanceOf[Addressable]
132+
)
133+
)
125134
case (a, d) =>
126135
throw Error(
127136
s"Unsupported type descriptor/data pairing for VarArgs: $a - $d"

j17/src/fr/hammons/slinc/VarArgs17.scala

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package fr.hammons.slinc
33
import jdk.incubator.foreign.CLinker.VaList
44
import jdk.incubator.foreign.CLinker.{C_INT, C_LONG_LONG, C_DOUBLE, C_POINTER}
55
import jdk.incubator.foreign.SegmentAllocator
6+
import jdk.incubator.foreign.GroupLayout
67
import fr.hammons.slinc.modules.{
78
LinkageModule17,
89
descriptorModule17,
@@ -20,7 +21,8 @@ class VarArgs17(args: VaList) extends VarArgs:
2021
case LongDescriptor => Long.box(args.vargAsLong(C_LONG_LONG))
2122
case FloatDescriptor => Float.box(args.vargAsDouble(C_DOUBLE).toFloat)
2223
case DoubleDescriptor => Double.box(args.vargAsDouble(C_DOUBLE))
23-
case PtrDescriptor => args.vargAsAddress(C_POINTER).nn
24+
case PtrDescriptor | _: SetSizeArrayDescriptor | VaListDescriptor =>
25+
args.vargAsAddress(C_POINTER).nn
2426
case sd: StructDescriptor =>
2527
LinkageModule17.tempScope(alloc ?=>
2628
args
@@ -30,26 +32,34 @@ class VarArgs17(args: VaList) extends VarArgs:
3032
)
3133
.nn
3234
)
33-
case AliasDescriptor(real) => get(real)
34-
case VaListDescriptor => args.vargAsAddress(C_POINTER).nn
35-
case CUnionDescriptor(possibleTypes) => get(possibleTypes.maxBy(_.size))
35+
case AliasDescriptor(real) => get(real)
36+
case cud: CUnionDescriptor =>
37+
LinkageModule17.tempScope(alloc ?=>
38+
args
39+
.vargAsSegment(
40+
descriptorModule17.toMemoryLayout(cud).asInstanceOf[GroupLayout],
41+
alloc.base.asInstanceOf[SegmentAllocator]
42+
)
43+
.nn
44+
)
3645
def get[A](using d: DescriptorOf[A]): A =
3746
transitionModule17.methodReturn[A](d.descriptor, get(d.descriptor))
3847

3948
private def skip(td: TypeDescriptor): Unit =
4049
td match
41-
case ByteDescriptor => args.skip(C_INT)
42-
case ShortDescriptor => args.skip(C_INT)
43-
case IntDescriptor => args.skip(C_INT)
44-
case LongDescriptor => args.skip(C_LONG_LONG)
45-
case FloatDescriptor => args.skip(C_DOUBLE)
46-
case DoubleDescriptor => args.skip(C_DOUBLE)
47-
case PtrDescriptor => args.skip(C_POINTER)
50+
case ByteDescriptor => args.skip(C_INT)
51+
case ShortDescriptor => args.skip(C_INT)
52+
case IntDescriptor => args.skip(C_INT)
53+
case LongDescriptor => args.skip(C_LONG_LONG)
54+
case FloatDescriptor => args.skip(C_DOUBLE)
55+
case DoubleDescriptor => args.skip(C_DOUBLE)
56+
case PtrDescriptor | _: SetSizeArrayDescriptor => args.skip(C_POINTER)
4857
case sd: StructDescriptor =>
4958
args.skip(descriptorModule17.toGroupLayout(sd))
50-
case AliasDescriptor(real) => skip(real)
51-
case VaListDescriptor => args.skip(C_POINTER)
52-
case CUnionDescriptor(possibleTypes) => skip(possibleTypes.maxBy(_.size))
59+
case AliasDescriptor(real) => skip(real)
60+
case VaListDescriptor => args.skip(C_POINTER)
61+
case cud: CUnionDescriptor =>
62+
args.skip(descriptorModule17.toMemoryLayout(cud))
5363

5464
def skip[A](using dO: DescriptorOf[A]): Unit = skip(dO.descriptor)
5565

j17/src/fr/hammons/slinc/modules/DescriptorModule17.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ import jdk.incubator.foreign.{
88
MemorySegment,
99
GroupLayout,
1010
CLinker,
11-
ValueLayout
11+
ValueLayout,
12+
SequenceLayout
1213
}, CLinker.C_POINTER
1314
import scala.collection.concurrent.TrieMap
1415
import fr.hammons.slinc.types.{arch, os, OS, Arch}
@@ -26,11 +27,10 @@ given descriptorModule17: DescriptorModule with
2627
case FloatDescriptor => classOf[Float]
2728
case DoubleDescriptor => classOf[Double]
2829
case PtrDescriptor => classOf[MemoryAddress]
29-
case _: StructDescriptor | _: CUnionDescriptor |
30-
_: SetSizeArrayDescriptor =>
30+
case _: StructDescriptor | _: CUnionDescriptor =>
3131
classOf[MemorySegment]
32-
case VaListDescriptor => classOf[MemoryAddress]
33-
case ad: AliasDescriptor[?] => toCarrierType(ad.real)
32+
case VaListDescriptor | _: SetSizeArrayDescriptor => classOf[MemoryAddress]
33+
case ad: AliasDescriptor[?] => toCarrierType(ad.real)
3434

3535
def genLayoutList(
3636
layouts: Seq[MemoryLayout],
@@ -123,6 +123,11 @@ given descriptorModule17: DescriptorModule with
123123
case CUnionDescriptor(possibleTypes) =>
124124
MemoryLayout.unionLayout(possibleTypes.map(toMemoryLayout).toSeq*).nn
125125

126+
def toDowncallLayout(td: TypeDescriptor): MemoryLayout = toMemoryLayout(
127+
td
128+
) match
129+
case _: SequenceLayout => C_POINTER.nn
130+
case o => o
126131
def toMemoryLayout(smd: StructMemberDescriptor): MemoryLayout =
127132
toMemoryLayout(smd.descriptor).withName(smd.name).nn
128133

j17/src/fr/hammons/slinc/modules/LinkageModule17.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ object LinkageModule17 extends LinkageModule:
2626
varArgs.view.map(_.use[DescriptorOf](d ?=> _ => d.descriptor))
2727
val fdConstructor = descriptor.returnDescriptor match
2828
case None => FunctionDescriptor.ofVoid(_*)
29-
case Some(value) => FunctionDescriptor.of(toMemoryLayout(value), _*)
29+
case Some(value) => FunctionDescriptor.of(toDowncallLayout(value), _*)
3030

3131
val fd = fdConstructor(
3232
descriptor.inputDescriptors.view
33-
.map(toMemoryLayout)
33+
.map(toDowncallLayout)
3434
.concat(variadicDescriptors.map(toMemoryLayout).map(CLinker.asVarArg))
3535
.toSeq
3636
)

j19/src/fr/hammons/slinc/Allocator19.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,16 @@ class Allocator19(
112112
case ms: MemorySegment => ms
113113
case _ => throw Error("Illegal datatype")
114114
)
115+
116+
case (ssad: SetSizeArrayDescriptor, s: SetSizeArray[?, ?]) =>
117+
LinkageModule19.tempScope(alloc ?=>
118+
builder.addVarg(
119+
ValueLayout.ADDRESS,
120+
transitionModule19
121+
.methodArgument(ssad, s, alloc)
122+
.asInstanceOf[Addressable]
123+
)
124+
)
115125
case (td, d) =>
116126
throw Error(s"Unsupported datatype for $td - $d")
117127

0 commit comments

Comments
 (0)