@@ -7,7 +7,7 @@ import SwiftSyntaxMacros
77// avoids depending on SwiftifyImport.swift
88// all instances are reparsed and reinstantiated by the macro anyways,
99// so linking is irrelevant
10- enum SwiftifyExpr {
10+ enum SwiftifyExpr : Hashable {
1111 case param( _ index: Int )
1212 case `return`
1313}
@@ -21,11 +21,21 @@ extension SwiftifyExpr: CustomStringConvertible {
2121 }
2222}
2323
24+ enum DependenceType {
25+ case borrow, copy
26+ }
27+
28+ struct LifetimeDependence {
29+ let dependsOn : Int
30+ let type : DependenceType
31+ }
32+
2433protocol ParamInfo : CustomStringConvertible {
2534 var description : String { get }
2635 var original : SyntaxProtocol { get }
2736 var pointerIndex : SwiftifyExpr { get }
2837 var nonescaping : Bool { get set }
38+ var dependencies : [ LifetimeDependence ] { get set }
2939
3040 func getBoundsCheckedThunkBuilder(
3141 _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax ,
@@ -55,8 +65,9 @@ func getSwiftifyExprType(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -
5565struct CxxSpan : ParamInfo {
5666 var pointerIndex : SwiftifyExpr
5767 var nonescaping : Bool
58- var original : SyntaxProtocol
68+ var dependencies : [ LifetimeDependence ]
5969 var typeMappings : [ String : String ]
70+ var original : SyntaxProtocol
6071
6172 var description : String {
6273 return " std::span(pointer: \( pointerIndex) , nonescaping: \( nonescaping) ) "
@@ -71,9 +82,8 @@ struct CxxSpan: ParamInfo {
7182 return CxxSpanThunkBuilder ( base: base, index: i - 1 , signature: funcDecl. signature,
7283 typeMappings: typeMappings, node: original, nonescaping: nonescaping)
7384 case . return:
74- // TODO: actually implement std::span in return position
75- return CxxSpanThunkBuilder ( base: base, index: - 1 , signature: funcDecl. signature,
76- typeMappings: typeMappings, node: original, nonescaping: nonescaping)
85+ return CxxSpanReturnThunkBuilder ( base: base, signature: funcDecl. signature,
86+ typeMappings: typeMappings, node: original)
7787 }
7888 }
7989}
@@ -83,6 +93,7 @@ struct CountedBy: ParamInfo {
8393 var count : ExprSyntax
8494 var sizedBy : Bool
8595 var nonescaping : Bool
96+ var dependencies : [ LifetimeDependence ]
8697 var original : SyntaxProtocol
8798
8899 var description : String {
@@ -156,6 +167,8 @@ func getTypeName(_ type: TypeSyntax) throws -> TokenSyntax {
156167 return memberType. name
157168 case . identifierType:
158169 return type. as ( IdentifierTypeSyntax . self) !. name
170+ case . attributedType:
171+ return try getTypeName ( type. as ( AttributedTypeSyntax . self) !. baseType)
159172 default :
160173 throw DiagnosticError ( " expected pointer type, got \( type) with kind \( type. kind) " , node: type)
161174 }
@@ -169,6 +182,13 @@ func replaceTypeName(_ type: TypeSyntax, _ name: TokenSyntax) -> TypeSyntax {
169182 return TypeSyntax ( idType. with ( \. name, name) )
170183}
171184
185+ func replaceBaseType( _ type: TypeSyntax , _ base: TypeSyntax ) -> TypeSyntax {
186+ if let attributedType = type. as ( AttributedTypeSyntax . self) {
187+ return TypeSyntax ( attributedType. with ( \. baseType, base) )
188+ }
189+ return base
190+ }
191+
172192func getPointerMutability( text: String ) -> Mutability ? {
173193 switch text {
174194 case " UnsafePointer " : return . Immutable
@@ -352,7 +372,7 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
352372 let parsedDesugaredType = TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredType) !) " )
353373 let genericArg = TypeSyntax ( parsedDesugaredType. as ( IdentifierTypeSyntax . self) !
354374 . genericArgumentClause!. arguments. first!. argument) !
355- types [ index] = TypeSyntax ( " Span< \( raw: try getTypeName ( genericArg) . text) > " )
375+ types [ index] = replaceBaseType ( param . type , TypeSyntax ( " Span< \( raw: try getTypeName ( genericArg) . text) > " ) )
356376 return try base. buildFunctionSignature ( types, returnType)
357377 }
358378
@@ -365,6 +385,38 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
365385 }
366386}
367387
388+ struct CxxSpanReturnThunkBuilder : BoundsCheckedThunkBuilder {
389+ public let base : BoundsCheckedThunkBuilder
390+ public let signature : FunctionSignatureSyntax
391+ public let typeMappings : [ String : String ]
392+ public let node : SyntaxProtocol
393+
394+ func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ] {
395+ return [ ]
396+ }
397+
398+ func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
399+ -> FunctionSignatureSyntax {
400+ assert ( returnType == nil )
401+ let typeName = try getTypeName ( signature. returnClause!. type) . text
402+ guard let desugaredType = typeMappings [ typeName] else {
403+ throw DiagnosticError (
404+ " unable to desugar type with name ' \( typeName) ' " , node: node)
405+ }
406+ let parsedDesugaredType = TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredType) !) " )
407+ let genericArg = TypeSyntax ( parsedDesugaredType. as ( IdentifierTypeSyntax . self) !
408+ . genericArgumentClause!. arguments. first!. argument) !
409+ let newType = replaceBaseType ( signature. returnClause!. type,
410+ TypeSyntax ( " Span< \( raw: try getTypeName ( genericArg) . text) > " ) )
411+ return try base. buildFunctionSignature ( argTypes, newType)
412+ }
413+
414+ func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
415+ let call = try base. buildFunctionCall ( pointerArgs)
416+ return " Span(_unsafeCxxSpan: \( call) ) "
417+ }
418+ }
419+
368420protocol PointerBoundsThunkBuilder : BoundsCheckedThunkBuilder {
369421 var oldType : TypeSyntax { get }
370422 var newType : TypeSyntax { get throws }
@@ -723,7 +775,7 @@ public struct SwiftifyImportMacro: PeerMacro {
723775 }
724776 return CountedBy (
725777 pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: false ,
726- nonescaping: false , original: ExprSyntax ( enumConstructorExpr) )
778+ nonescaping: false , dependencies : [ ] , original: ExprSyntax ( enumConstructorExpr) )
727779 }
728780
729781 static func parseSizedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> ParamInfo {
@@ -738,7 +790,7 @@ public struct SwiftifyImportMacro: PeerMacro {
738790 let unwrappedCountExpr = ExprSyntax ( stringLiteral: sizeExprStringLit. representedLiteralValue!)
739791 return CountedBy (
740792 pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: true , nonescaping: false ,
741- original: ExprSyntax ( enumConstructorExpr) )
793+ dependencies : [ ] , original: ExprSyntax ( enumConstructorExpr) )
742794 }
743795
744796 static func parseEndedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> ParamInfo {
@@ -758,6 +810,24 @@ public struct SwiftifyImportMacro: PeerMacro {
758810 return pointerParamIndex
759811 }
760812
813+ static func parseLifetimeDependence( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> ( SwiftifyExpr , LifetimeDependence ) {
814+ let argumentList = enumConstructorExpr. arguments
815+ let pointer : SwiftifyExpr = try parseSwiftifyExpr ( try getArgumentByName ( argumentList, " pointer " ) )
816+ let dependsOn : Int = try getIntLiteralValue ( try getArgumentByName ( argumentList, " dependsOn " ) )
817+ let type = try getArgumentByName ( argumentList, " type " )
818+ let depType : DependenceType
819+ switch try parseEnumName ( type) {
820+ case " borrow " :
821+ depType = DependenceType . borrow
822+ case " copy " :
823+ depType = DependenceType . copy
824+ default :
825+ throw DiagnosticError ( " expected '.copy' or '.borrow', got ' \( type) ' " , node: type)
826+ }
827+ let dependence = LifetimeDependence ( dependsOn: dependsOn, type: depType)
828+ return ( pointer, dependence)
829+ }
830+
761831 static func parseTypeMappingParam( _ paramAST: LabeledExprSyntax ? ) throws -> [ String : String ] ? {
762832 guard let unwrappedParamAST = paramAST else {
763833 return nil
@@ -786,31 +856,38 @@ public struct SwiftifyImportMacro: PeerMacro {
786856 return dict
787857 }
788858
789- static func parseCxxSpanParams (
859+ static func parseCxxSpansInSignature (
790860 _ signature: FunctionSignatureSyntax ,
791861 _ typeMappings: [ String : String ] ?
792862 ) throws -> [ ParamInfo ] {
793863 guard let typeMappings else {
794864 return [ ]
795865 }
796866 var result : [ ParamInfo ] = [ ]
797- for (idx , param ) in signature . parameterClause . parameters . enumerated ( ) {
798- let typeName = try getTypeName ( param . type) . text;
867+ let process = { type , expr , orig in
868+ let typeName = try getTypeName ( type) . text;
799869 if let desugaredType = typeMappings [ typeName] {
800870 if let unqualifiedDesugaredType = getUnqualifiedStdName ( desugaredType) {
801871 if unqualifiedDesugaredType. starts ( with: " span< " ) {
802- result. append ( CxxSpan ( pointerIndex: . param ( idx + 1 ) , nonescaping: false ,
803- original : param , typeMappings: typeMappings) )
872+ result. append ( CxxSpan ( pointerIndex: expr , nonescaping: false ,
873+ dependencies : [ ] , typeMappings: typeMappings, original : orig ) )
804874 }
805875 }
806876 }
807877 }
878+ for (idx, param) in signature. parameterClause. parameters. enumerated ( ) {
879+ try process ( param. type, . param( idx + 1 ) , param)
880+ }
881+ if let retClause = signature. returnClause {
882+ try process ( retClause. type, . `return`, retClause)
883+ }
808884 return result
809885 }
810886
811887 static func parseMacroParam(
812888 _ paramAST: LabeledExprSyntax , _ signature: FunctionSignatureSyntax ,
813- nonescapingPointers: inout Set < Int >
889+ nonescapingPointers: inout Set < Int > ,
890+ lifetimeDependencies: inout [ SwiftifyExpr : [ LifetimeDependence ] ]
814891 ) throws -> ParamInfo ? {
815892 let paramExpr = paramAST. expression
816893 guard let enumConstructorExpr = paramExpr. as ( FunctionCallExprSyntax . self) else {
@@ -826,9 +903,23 @@ public struct SwiftifyImportMacro: PeerMacro {
826903 let index = try parseNonEscaping ( enumConstructorExpr)
827904 nonescapingPointers. insert ( index)
828905 return nil
906+ case " lifetimeDependence " :
907+ let ( expr, dependence) = try parseLifetimeDependence ( enumConstructorExpr)
908+ lifetimeDependencies [ expr, default: [ ] ] . append ( dependence)
909+ // We assume pointers annotated with lifetimebound do not escape.
910+ if dependence. type == DependenceType . copy {
911+ nonescapingPointers. insert ( dependence. dependsOn)
912+ }
913+ // The escaping is controlled when a parameter is the target of a lifetimebound.
914+ // So we want to do the transformation to Swift's Span.
915+ let idx = paramOrReturnIndex ( expr)
916+ if idx != - 1 {
917+ nonescapingPointers. insert ( idx)
918+ }
919+ return nil
829920 default :
830921 throw DiagnosticError (
831- " expected 'countedBy', 'sizedBy', 'endedBy' or 'nonescaping ', got ' \( enumName) ' " ,
922+ " expected 'countedBy', 'sizedBy', 'endedBy', 'nonescaping' or 'lifetimeDependence ', got ' \( enumName) ' " ,
832923 node: enumConstructorExpr)
833924 }
834925 }
@@ -898,11 +989,48 @@ public struct SwiftifyImportMacro: PeerMacro {
898989 }
899990
900991 static func setNonescapingPointers( _ args: inout [ ParamInfo ] , _ nonescapingPointers: Set < Int > ) {
992+ if args. isEmpty {
993+ return
994+ }
901995 for i in 0 ... args. count - 1 where nonescapingPointers. contains ( paramOrReturnIndex ( args [ i] . pointerIndex) ) {
902996 args [ i] . nonescaping = true
903997 }
904998 }
905999
1000+ static func setLifetimeDependencies( _ args: inout [ ParamInfo ] , _ lifetimeDependencies: [ SwiftifyExpr : [ LifetimeDependence ] ] ) {
1001+ if args. isEmpty {
1002+ return
1003+ }
1004+ for i in 0 ... args. count - 1 where lifetimeDependencies. keys. contains ( args [ i] . pointerIndex) {
1005+ args [ i] . dependencies = lifetimeDependencies [ args [ i] . pointerIndex] !
1006+ }
1007+ }
1008+
1009+ static func lifetimeAttributes( _ funcDecl: FunctionDeclSyntax ,
1010+ _ dependencies: [ SwiftifyExpr : [ LifetimeDependence ] ] ) -> [ AttributeListSyntax . Element ] {
1011+ let returnDependencies = dependencies [ . `return`, default: [ ] ]
1012+ if returnDependencies. isEmpty {
1013+ return [ ]
1014+ }
1015+ var args : [ LabeledExprSyntax ] = [ ]
1016+ for dependence in returnDependencies {
1017+ if ( dependence. type == . borrow) {
1018+ args. append ( LabeledExprSyntax ( expression:
1019+ DeclReferenceExprSyntax ( baseName: TokenSyntax ( " borrow " ) ) ) )
1020+ }
1021+ args. append ( LabeledExprSyntax ( expression:
1022+ DeclReferenceExprSyntax ( baseName: TokenSyntax ( tryGetParamName ( funcDecl, . param( dependence. dependsOn) ) ) !) ,
1023+ trailingComma: . commaToken( ) ) )
1024+ }
1025+ args [ args. count - 1 ] = args [ args. count - 1 ] . with ( \. trailingComma, nil )
1026+ return [ . attribute( AttributeSyntax (
1027+ atSign: . atSignToken( ) ,
1028+ attributeName: IdentifierTypeSyntax ( name: " lifetime " ) ,
1029+ leftParen: . leftParenToken( ) ,
1030+ arguments: . argumentList( LabeledExprListSyntax ( args) ) ,
1031+ rightParen: . rightParenToken( ) ) ) ]
1032+ }
1033+
9061034 public static func expansion(
9071035 of node: AttributeSyntax ,
9081036 providingPeersOf declaration: some DeclSyntaxProtocol ,
@@ -920,13 +1048,21 @@ public struct SwiftifyImportMacro: PeerMacro {
9201048 arguments = arguments. dropLast ( )
9211049 }
9221050 var nonescapingPointers = Set < Int > ( )
1051+ var lifetimeDependencies : [ SwiftifyExpr : [ LifetimeDependence ] ] = [ : ]
9231052 var parsedArgs = try arguments. compactMap {
924- try parseMacroParam ( $0, funcDecl. signature, nonescapingPointers: & nonescapingPointers)
1053+ try parseMacroParam ( $0, funcDecl. signature, nonescapingPointers: & nonescapingPointers,
1054+ lifetimeDependencies: & lifetimeDependencies)
9251055 }
926- parsedArgs. append ( contentsOf: try parseCxxSpanParams ( funcDecl. signature, typeMappings) )
1056+ parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcDecl. signature, typeMappings) )
9271057 setNonescapingPointers ( & parsedArgs, nonescapingPointers)
1058+ setLifetimeDependencies ( & parsedArgs, lifetimeDependencies)
1059+ // We only transform non-escaping spans.
9281060 parsedArgs = parsedArgs. filter {
929- !( $0 is CxxSpan ) || ( $0 as! CxxSpan ) . nonescaping
1061+ if let cxxSpanArg = $0 as? CxxSpan {
1062+ return cxxSpanArg. nonescaping || cxxSpanArg. pointerIndex == . return
1063+ } else {
1064+ return true
1065+ }
9301066 }
9311067 try checkArgs ( parsedArgs, funcDecl)
9321068 let baseBuilder = FunctionCallBuilder ( funcDecl)
@@ -951,6 +1087,7 @@ public struct SwiftifyImportMacro: PeerMacro {
9511087 returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
9521088 expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
9531089 let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
1090+ let lifetimeAttrs = lifetimeAttributes ( funcDecl, lifetimeDependencies)
9541091 let newFunc =
9551092 funcDecl
9561093 . with ( \. signature, newSignature)
@@ -970,7 +1107,8 @@ public struct SwiftifyImportMacro: PeerMacro {
9701107 AttributeSyntax (
9711108 atSign: . atSignToken( ) ,
9721109 attributeName: IdentifierTypeSyntax ( name: " _alwaysEmitIntoClient " ) ) )
973- ] )
1110+ ]
1111+ + lifetimeAttrs)
9741112 return [ DeclSyntax ( newFunc) ]
9751113 } catch let error as DiagnosticError {
9761114 context. diagnose (
0 commit comments