Skip to content

Commit c6a56cd

Browse files
authored
jextract: fix protocols that return java classes. (#479)
* fix protocols * remove comments * comments * cleanup
1 parent 7519b4c commit c6a56cd

File tree

9 files changed

+78
-22
lines changed

9 files changed

+78
-22
lines changed

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ let package = Package(
460460
.product(name: "SwiftSyntax", package: "swift-syntax"),
461461
.product(name: "SwiftSyntaxBuilder", package: "swift-syntax"),
462462
.product(name: "ArgumentParser", package: "swift-argument-parser"),
463+
.product(name: "OrderedCollections", package: "swift-collections"),
463464
"JavaTypes",
464465
"SwiftJavaShared",
465466
"SwiftJavaConfigurationShared",

Samples/SwiftJavaExtractJNISampleApp/Sources/MySwiftLibrary/ConcreteProtocolAB.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public class ConcreteProtocolAB: ProtocolA, ProtocolB {
2121
return "ConcreteProtocolAB"
2222
}
2323

24+
public func makeClass() -> MySwiftClass {
25+
return MySwiftClass(x: 10, y: 50)
26+
}
27+
2428
public init(constantA: Int64, constantB: Int64) {
2529
self.constantA = constantA
2630
self.constantB = constantB

Samples/SwiftJavaExtractJNISampleApp/Sources/MySwiftLibrary/ProtocolA.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ public protocol ProtocolA {
1717
var mutable: Int64 { get set }
1818

1919
func name() -> String
20+
func makeClass() -> MySwiftClass
2021
}
2122

2223
public func takeProtocol(_ proto1: any ProtocolA, _ proto2: some ProtocolA) -> Int64 {

Samples/SwiftJavaExtractJNISampleApp/src/test/java/com/example/swift/ProtocolCallbacksTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public String withString(String input) {
7474
public void withVoid() {}
7575

7676
@Override
77-
public MySwiftClass withObject(MySwiftClass input) {
77+
public MySwiftClass withObject(MySwiftClass input, SwiftArena swiftArena$) {
7878
return input;
7979
}
8080

@@ -84,7 +84,7 @@ public OptionalLong withOptionalInt64(OptionalLong input) {
8484
}
8585

8686
@Override
87-
public Optional<MySwiftClass> withOptionalObject(Optional<MySwiftClass> input) {
87+
public Optional<MySwiftClass> withOptionalObject(Optional<MySwiftClass> input, SwiftArena swiftArena$) {
8888
return input;
8989
}
9090
}

Samples/SwiftJavaExtractJNISampleApp/src/test/java/com/example/swift/ProtocolTest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ void protocolMethod() {
7373
}
7474
}
7575

76+
@Test
77+
void protocolClassMethod() {
78+
try (var arena = SwiftArena.ofConfined()) {
79+
ProtocolA proto1 = ConcreteProtocolAB.init(10, 5, arena);
80+
assertEquals(10, proto1.makeClass().getX());
81+
}
82+
}
83+
7684
static class JavaStorage implements Storage {
7785
StorageItem item;
7886

@@ -81,7 +89,7 @@ static class JavaStorage implements Storage {
8189
}
8290

8391
@Override
84-
public StorageItem load() {
92+
public StorageItem load(SwiftArena swiftArena$) {
8593
return item;
8694
}
8795

Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaBindingsPrinting.swift

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,17 +144,17 @@ extension JNISwift2JavaGenerator {
144144

145145
printer.printBraceBlock("public interface \(decl.swiftNominal.name)\(extendsString)") { printer in
146146
for initializer in decl.initializers {
147-
printFunctionDowncallMethods(&printer, initializer, skipMethodBody: true, skipArenas: true)
147+
printFunctionDowncallMethods(&printer, initializer, skipMethodBody: true)
148148
printer.println()
149149
}
150150

151151
for method in decl.methods {
152-
printFunctionDowncallMethods(&printer, method, skipMethodBody: true, skipArenas: true)
152+
printFunctionDowncallMethods(&printer, method, skipMethodBody: true)
153153
printer.println()
154154
}
155155

156156
for variable in decl.variables {
157-
printFunctionDowncallMethods(&printer, variable, skipMethodBody: true, skipArenas: true)
157+
printFunctionDowncallMethods(&printer, variable, skipMethodBody: true)
158158
printer.println()
159159
}
160160
}
@@ -420,16 +420,15 @@ extension JNISwift2JavaGenerator {
420420
printer.print("record _NativeParameters(\(nativeParameters.joined(separator: ", "))) {}")
421421
}
422422

423-
self.printJavaBindingWrapperMethod(&printer, translatedCase.getAsCaseFunction, skipMethodBody: false, skipArenas: false)
423+
self.printJavaBindingWrapperMethod(&printer, translatedCase.getAsCaseFunction, skipMethodBody: false)
424424
printer.println()
425425
}
426426
}
427427

428428
private func printFunctionDowncallMethods(
429429
_ printer: inout CodePrinter,
430430
_ decl: ImportedFunc,
431-
skipMethodBody: Bool = false,
432-
skipArenas: Bool = false
431+
skipMethodBody: Bool = false
433432
) {
434433
guard translatedDecl(for: decl) != nil else {
435434
// Failed to translate. Skip.
@@ -440,7 +439,7 @@ extension JNISwift2JavaGenerator {
440439

441440
printJavaBindingWrapperHelperClass(&printer, decl)
442441

443-
printJavaBindingWrapperMethod(&printer, decl, skipMethodBody: skipMethodBody, skipArenas: skipArenas)
442+
printJavaBindingWrapperMethod(&printer, decl, skipMethodBody: skipMethodBody)
444443
}
445444

446445
/// Print the helper type container for a user-facing Java API.
@@ -486,21 +485,19 @@ extension JNISwift2JavaGenerator {
486485
private func printJavaBindingWrapperMethod(
487486
_ printer: inout CodePrinter,
488487
_ decl: ImportedFunc,
489-
skipMethodBody: Bool,
490-
skipArenas: Bool
488+
skipMethodBody: Bool
491489
) {
492490
guard let translatedDecl = translatedDecl(for: decl) else {
493491
fatalError("Decl was not translated, \(decl)")
494492
}
495-
printJavaBindingWrapperMethod(&printer, translatedDecl, importedFunc: decl, skipMethodBody: skipMethodBody, skipArenas: skipArenas)
493+
printJavaBindingWrapperMethod(&printer, translatedDecl, importedFunc: decl, skipMethodBody: skipMethodBody)
496494
}
497495

498496
private func printJavaBindingWrapperMethod(
499497
_ printer: inout CodePrinter,
500498
_ translatedDecl: TranslatedFunctionDecl,
501499
importedFunc: ImportedFunc? = nil,
502-
skipMethodBody: Bool,
503-
skipArenas: Bool
500+
skipMethodBody: Bool
504501
) {
505502
var modifiers = ["public"]
506503
if translatedDecl.isStatic {
@@ -531,14 +528,28 @@ extension JNISwift2JavaGenerator {
531528
let parametersStr = parameters.joined(separator: ", ")
532529

533530
// Print default global arena variation
531+
// If we have enabled javaCallbacks we must emit default
532+
// arena methods for protocols, as this is what
533+
// Swift will call into, when you call a interface from Swift.
534+
let shouldGenerateGlobalArenaVariation: Bool
535+
let isParentProtocol = importedFunc?.parentType?.asNominalType?.isProtocol ?? false
536+
534537
if config.effectiveMemoryManagementMode.requiresGlobalArena && translatedSignature.requiresSwiftArena {
538+
shouldGenerateGlobalArenaVariation = true
539+
} else if isParentProtocol, translatedSignature.requiresSwiftArena, config.effectiveEnableJavaCallbacks {
540+
shouldGenerateGlobalArenaVariation = true
541+
} else {
542+
shouldGenerateGlobalArenaVariation = false
543+
}
544+
545+
if shouldGenerateGlobalArenaVariation {
535546
if let importedFunc {
536547
printDeclDocumentation(&printer, importedFunc)
537548
}
538549
var modifiers = modifiers
539550

540551
// If we are a protocol, we emit this as default method
541-
if importedFunc?.parentType?.asNominalTypeDeclaration?.kind == .protocol {
552+
if isParentProtocol {
542553
modifiers.insert("default", at: 1)
543554
}
544555

@@ -555,7 +566,7 @@ extension JNISwift2JavaGenerator {
555566
printer.println()
556567
}
557568

558-
if translatedSignature.requiresSwiftArena, !skipArenas {
569+
if translatedSignature.requiresSwiftArena {
559570
parameters.append("SwiftArena swiftArena$")
560571
}
561572
if let importedFunc {

Sources/SwiftJavaConfigurationShared/Configuration.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ public struct Configuration: Codable {
6565
asyncFuncMode ?? .default
6666
}
6767

68-
public var enableJavaCallbacks: Bool? // FIXME: default it to false, but that plays not nice with Codable
68+
public var enableJavaCallbacks: Bool?
69+
public var effectiveEnableJavaCallbacks: Bool {
70+
enableJavaCallbacks ?? false
71+
}
6972

7073
public var generatedJavaSourcesListFileOutput: String?
7174

Tests/JExtractSwiftTests/JNI/JNIProtocolTests.swift

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ struct JNIProtocolTests {
3232
3333
public protocol B {}
3434
35-
public class SomeClass: SomeProtocol {}
35+
public class SomeClass: SomeProtocol {
36+
public func makeClass() -> SomeClass {}
37+
}
3638
3739
public func takeProtocol(x: some SomeProtocol, y: any SomeProtocol)
3840
public func takeGeneric<S: SomeProtocol>(s: S)
@@ -61,7 +63,29 @@ struct JNIProtocolTests {
6163
...
6264
public void method();
6365
...
64-
public SomeClass withObject(SomeClass c);
66+
public SomeClass withObject(SomeClass c, SwiftArena swiftArena$);
67+
...
68+
}
69+
"""
70+
])
71+
}
72+
73+
@Test
74+
func emitsDefault() throws {
75+
try assertOutput(
76+
input: source,
77+
config: config,
78+
.jni, .java,
79+
detectChunkByInitialLines: 1,
80+
expectedChunks: [
81+
"""
82+
public interface SomeProtocol {
83+
...
84+
public default SomeClass withObject(SomeClass c) {
85+
return withObject(c, SwiftMemoryManagement.GLOBAL_SWIFT_JAVA_ARENA);
86+
}
87+
...
88+
public SomeClass withObject(SomeClass c, SwiftArena swiftArena$);
6589
...
6690
}
6791
"""
@@ -78,7 +102,11 @@ struct JNIProtocolTests {
78102
expectedChunks: [
79103
"""
80104
public final class SomeClass implements JNISwiftInstance, SomeProtocol {
81-
"""
105+
...
106+
public SomeClass makeClass(SwiftArena swiftArena$) {
107+
...
108+
}
109+
""",
82110
])
83111
}
84112

Tests/JExtractSwiftTests/MemoryManagementModeTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ struct MemoryManagementModeTests {
9999
}
100100
""",
101101
"""
102-
public MyClass f();
102+
public MyClass f(SwiftArena swiftArena$);
103103
"""
104104
]
105105
)

0 commit comments

Comments
 (0)