@@ -4,9 +4,6 @@ import SwiftSyntax
44import SwiftSyntaxBuilder
55import SwiftSyntaxMacros
66
7- // Disable emitting 'MutableSpan' until it has landed
8- let enableMutableSpan = false
9-
107// avoids depending on SwiftifyImport.swift
118// all instances are reparsed and reinstantiated by the macro anyways,
129// so linking is irrelevant
@@ -279,36 +276,49 @@ func getUnqualifiedStdName(_ type: String) -> String? {
279276func getSafePointerName( mut: Mutability , generateSpan: Bool , isRaw: Bool ) -> TokenSyntax {
280277 switch ( mut, generateSpan, isRaw) {
281278 case ( . Immutable, true , true ) : return " RawSpan "
282- case ( . Mutable, true , true ) : return if enableMutableSpan {
283- " MutableRawSpan "
284- } else {
285- " RawSpan "
286- }
279+ case ( . Mutable, true , true ) : return " MutableRawSpan "
287280 case ( . Immutable, false , true ) : return " UnsafeRawBufferPointer "
288281 case ( . Mutable, false , true ) : return " UnsafeMutableRawBufferPointer "
289282
290283 case ( . Immutable, true , false ) : return " Span "
291- case ( . Mutable, true , false ) : return if enableMutableSpan {
292- " MutableSpan "
293- } else {
294- " Span "
295- }
284+ case ( . Mutable, true , false ) : return " MutableSpan "
296285 case ( . Immutable, false , false ) : return " UnsafeBufferPointer "
297286 case ( . Mutable, false , false ) : return " UnsafeMutableBufferPointer "
298287 }
299288}
300289
301- func transformType( _ prev: TypeSyntax , _ generateSpan: Bool , _ isSizedBy: Bool ) throws -> TypeSyntax {
290+ func hasOwnershipSpecifier( _ attrType: AttributedTypeSyntax ) -> Bool {
291+ return attrType. specifiers. contains ( where: { e in
292+ guard let simpleSpec = e. as ( SimpleTypeSpecifierSyntax . self) else {
293+ return false
294+ }
295+ let specifierText = simpleSpec. specifier. text
296+ switch specifierText {
297+ case " borrowing " :
298+ return true
299+ case " inout " :
300+ return true
301+ case " consuming " :
302+ return true
303+ default :
304+ return false
305+ }
306+ } )
307+ }
308+
309+ func transformType( _ prev: TypeSyntax , _ generateSpan: Bool , _ isSizedBy: Bool , _ setMutableSpanInout: Bool ) throws -> TypeSyntax {
302310 if let optType = prev. as ( OptionalTypeSyntax . self) {
303311 return TypeSyntax (
304- optType. with ( \. wrappedType, try transformType ( optType. wrappedType, generateSpan, isSizedBy) ) )
312+ optType. with ( \. wrappedType, try transformType ( optType. wrappedType, generateSpan, isSizedBy, setMutableSpanInout ) ) )
305313 }
306314 if let impOptType = prev. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
307- return try transformType ( impOptType. wrappedType, generateSpan, isSizedBy)
315+ return try transformType ( impOptType. wrappedType, generateSpan, isSizedBy, setMutableSpanInout )
308316 }
309317 if let attrType = prev. as ( AttributedTypeSyntax . self) {
318+ // We insert 'inout' by default for MutableSpan, but it shouldn't override existing ownership
319+ let setMutableSpanInoutNext = setMutableSpanInout && !hasOwnershipSpecifier( attrType)
310320 return TypeSyntax (
311- attrType. with ( \. baseType, try transformType ( attrType. baseType, generateSpan, isSizedBy) ) )
321+ attrType. with ( \. baseType, try transformType ( attrType. baseType, generateSpan, isSizedBy, setMutableSpanInoutNext ) ) )
312322 }
313323 let name = try getTypeName ( prev)
314324 let text = name. text
@@ -326,10 +336,15 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool)
326336 + " - first type token is ' \( text) ' " , node: name)
327337 }
328338 let token = getSafePointerName ( mut: kind, generateSpan: generateSpan, isRaw: isSizedBy)
329- if isSizedBy {
330- return TypeSyntax ( IdentifierTypeSyntax ( name: token) )
339+ let mainType = if isSizedBy {
340+ TypeSyntax ( IdentifierTypeSyntax ( name: token) )
341+ } else {
342+ try replaceTypeName ( prev, token)
331343 }
332- return try replaceTypeName ( prev, token)
344+ if setMutableSpanInout && generateSpan && kind == . Mutable {
345+ return TypeSyntax ( " inout \( mainType) " )
346+ }
347+ return mainType
333348}
334349
335350func isMutablePointerType( _ type: TypeSyntax ) -> Bool {
@@ -431,10 +446,11 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
431446 let colon : TokenSyntax ? = label != nil ? . colonToken( ) : nil
432447 return LabeledExprSyntax ( label: label, colon: colon, expression: arg, trailingComma: comma)
433448 }
434- return ExprSyntax (
449+ let call = ExprSyntax (
435450 FunctionCallExprSyntax (
436451 calledExpression: functionRef, leftParen: . leftParenToken( ) ,
437452 arguments: LabeledExprListSyntax ( labeledArgs) , rightParen: . rightParenToken( ) ) )
453+ return " unsafe \( call) "
438454 }
439455}
440456
@@ -446,6 +462,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
446462 public let node : SyntaxProtocol
447463 public let nonescaping : Bool
448464 let isSizedBy : Bool = false
465+ let isParameter : Bool = true
449466
450467 func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ] {
451468 return try base. buildBoundsChecks ( )
@@ -462,8 +479,26 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
462479 var args = pointerArgs
463480 let typeName = getUnattributedType ( oldType) . description
464481 assert ( args [ index] == nil )
465- args [ index] = ExprSyntax ( " \( raw: typeName) ( \( raw: name) ) " )
466- return try base. buildFunctionCall ( args)
482+
483+ let ( _, isConst) = dropCxxQualifiers ( try genericArg)
484+ if isConst {
485+ args [ index] = ExprSyntax ( " \( raw: typeName) ( \( raw: name) ) " )
486+ return try base. buildFunctionCall ( args)
487+ } else {
488+ let unwrappedName = TokenSyntax ( " _ \( name) Ptr " )
489+ args [ index] = ExprSyntax ( " \( raw: typeName) ( \( unwrappedName) ) " )
490+ let call = try base. buildFunctionCall ( args)
491+
492+ // MutableSpan - unlike Span - cannot be bitcast to std::span due to being ~Copyable,
493+ // so unwrap it to an UnsafeMutableBufferPointer that we can cast
494+ let unwrappedCall = ExprSyntax (
495+ """
496+ unsafe \( name) .withUnsafeMutableBufferPointer { \( unwrappedName) in
497+ return \( call)
498+ }
499+ """ )
500+ return unwrappedCall
501+ }
467502 }
468503}
469504
@@ -472,6 +507,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
472507 public let signature : FunctionSignatureSyntax
473508 public let typeMappings : [ String : String ]
474509 public let node : SyntaxProtocol
510+ let isParameter : Bool = false
475511
476512 var oldType : TypeSyntax {
477513 return signature. returnClause!. type
@@ -490,12 +526,12 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
490526 func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
491527 let call = try base. buildFunctionCall ( pointerArgs)
492528 let ( _, isConst) = dropCxxQualifiers ( try genericArg)
493- let cast = if isConst || !enableMutableSpan {
529+ let cast = if isConst {
494530 " Span "
495531 } else {
496532 " MutableSpan "
497533 }
498- return " _cxxOverrideLifetime( \( raw: cast) (_unsafeCxxSpan: \( call) ), copying: ()) "
534+ return " unsafe _cxxOverrideLifetime(\( raw: cast) (_unsafeCxxSpan: \( call) ), copying: ()) "
499535 }
500536}
501537
@@ -508,11 +544,12 @@ protocol BoundsThunkBuilder: BoundsCheckedThunkBuilder {
508544protocol SpanBoundsThunkBuilder : BoundsThunkBuilder {
509545 var typeMappings : [ String : String ] { get }
510546 var node : SyntaxProtocol { get }
547+ var isParameter : Bool { get }
511548}
512549extension SpanBoundsThunkBuilder {
513550 var desugaredType : TypeSyntax {
514551 get throws {
515- let typeName = try getUnattributedType ( oldType) . description
552+ let typeName = getUnattributedType ( oldType) . description
516553 guard let desugaredTypeName = typeMappings [ typeName] else {
517554 throw DiagnosticError (
518555 " unable to desugar type with name ' \( typeName) ' " , node: node)
@@ -547,14 +584,18 @@ extension SpanBoundsThunkBuilder {
547584 var newType : TypeSyntax {
548585 get throws {
549586 let ( strippedArg, isConst) = dropCxxQualifiers ( try genericArg)
550- let mutablePrefix = if isConst || !enableMutableSpan {
587+ let mutablePrefix = if isConst {
551588 " "
552589 } else {
553590 " Mutable "
554591 }
555- return replaceBaseType (
592+ let mainType = replaceBaseType (
556593 oldType,
557594 TypeSyntax ( " \( raw: mutablePrefix) Span< \( raw: strippedArg) > " ) )
595+ if !isConst && isParameter {
596+ return TypeSyntax ( " inout \( mainType) " )
597+ }
598+ return mainType
558599 }
559600 }
560601}
@@ -563,13 +604,14 @@ protocol PointerBoundsThunkBuilder: BoundsThunkBuilder {
563604 var nullable : Bool { get }
564605 var isSizedBy : Bool { get }
565606 var generateSpan : Bool { get }
607+ var isParameter : Bool { get }
566608}
567609
568610extension PointerBoundsThunkBuilder {
569611 var nullable : Bool { return oldType. is ( OptionalTypeSyntax . self) }
570612
571613 var newType : TypeSyntax { get throws {
572- return try transformType ( oldType, generateSpan, isSizedBy) }
614+ return try transformType ( oldType, generateSpan, isSizedBy, isParameter ) }
573615 }
574616}
575617
@@ -599,8 +641,9 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
599641 public let nonescaping : Bool
600642 public let isSizedBy : Bool
601643 public let dependencies : [ LifetimeDependence ]
644+ let isParameter : Bool = false
602645
603- var generateSpan : Bool { !dependencies. isEmpty && ( !isMutablePointerType ( oldType ) || enableMutableSpan ) }
646+ var generateSpan : Bool { !dependencies. isEmpty }
604647
605648 var oldType : TypeSyntax {
606649 return signature. returnClause!. type
@@ -623,9 +666,25 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
623666 } else {
624667 " start "
625668 }
669+ var cast = try newType
670+ if nullable {
671+ if let optType = cast. as ( OptionalTypeSyntax . self) {
672+ cast = optType. wrappedType
673+ }
674+ return """
675+ { () in
676+ let _resultValue = \( call)
677+ if unsafe _resultValue == nil {
678+ return nil
679+ } else {
680+ return unsafe \( raw: try cast) ( \( raw: startLabel) : _resultValue!, count: Int( \( countExpr) ))
681+ }
682+ }()
683+ """
684+ }
626685 return
627686 """
628- \( raw: try newType ) ( \( raw: startLabel) : \( call) , count: Int( \( countExpr) ))
687+ unsafe \( raw: try cast ) ( \( raw: startLabel) : \( call) , count: Int( \( countExpr) ))
629688 """
630689 }
631690}
@@ -639,8 +698,9 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
639698 public let nonescaping : Bool
640699 public let isSizedBy : Bool
641700 public let skipTrivialCount : Bool
701+ let isParameter : Bool = true
642702
643- var generateSpan : Bool { nonescaping && ( !isMutablePointerType ( oldType ) || enableMutableSpan ) }
703+ var generateSpan : Bool { nonescaping }
644704
645705 func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
646706 -> ( FunctionSignatureSyntax , Bool ) {
@@ -702,11 +762,16 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
702762 let call = try base. buildFunctionCall ( args)
703763 let ptrRef = unwrapIfNullable ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: name) ) )
704764
705- let funcName = isSizedBy ? " withUnsafeBytes " : " withUnsafeBufferPointer "
765+ let funcName = switch ( isSizedBy, isMutablePointerType ( oldType) ) {
766+ case ( true , true ) : " withUnsafeMutableBytes "
767+ case ( true , false ) : " withUnsafeBytes "
768+ case ( false , true ) : " withUnsafeMutableBufferPointer "
769+ case ( false , false ) : " withUnsafeBufferPointer "
770+ }
706771 let unwrappedCall = ExprSyntax (
707772 """
708- \( ptrRef) . \( raw: funcName) { \( unwrappedName) in
709- return unsafe \( call)
773+ unsafe \( ptrRef) . \( raw: funcName) { \( unwrappedName) in
774+ return \( call)
710775 }
711776 """ )
712777 return unwrappedCall
@@ -766,11 +831,11 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
766831 nullArgs [ index] = ExprSyntax ( NilLiteralExprSyntax ( nilKeyword: . keyword( . nil ) ) )
767832 return ExprSyntax (
768833 """
769- if \( name) == nil {
770- unsafe \( try base. buildFunctionCall ( nullArgs) )
771- } else {
772- \( unwrappedCall)
773- }
834+ { () in return if \( name) == nil {
835+ \( try base. buildFunctionCall ( nullArgs) )
836+ } else {
837+ \( unwrappedCall)
838+ } }()
774839 """ )
775840 }
776841 return unwrappedCall
@@ -1161,7 +1226,7 @@ public struct SwiftifyImportMacro: PeerMacro {
11611226 }
11621227 }
11631228
1164- static func lifetimeAttributes ( _ funcDecl: FunctionDeclSyntax ,
1229+ static func getReturnLifetimeAttribute ( _ funcDecl: FunctionDeclSyntax ,
11651230 _ dependencies: [ SwiftifyExpr : [ LifetimeDependence ] ] ) -> [ AttributeListSyntax . Element ] {
11661231 let returnDependencies = dependencies [ . `return`, default: [ ] ]
11671232 if returnDependencies. isEmpty {
@@ -1190,6 +1255,66 @@ public struct SwiftifyImportMacro: PeerMacro {
11901255 rightParen: . rightParenToken( ) ) ) ]
11911256 }
11921257
1258+ static func isMutableSpan( _ type: TypeSyntax ) -> Bool {
1259+ if let optType = type. as ( OptionalTypeSyntax . self) {
1260+ return isMutableSpan ( optType. wrappedType)
1261+ }
1262+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
1263+ return isMutableSpan ( impOptType. wrappedType)
1264+ }
1265+ if let attrType = type. as ( AttributedTypeSyntax . self) {
1266+ return isMutableSpan ( attrType. baseType)
1267+ }
1268+ guard let identifierType = type. as ( IdentifierTypeSyntax . self) else {
1269+ return false
1270+ }
1271+ let name = identifierType. name. text
1272+ return name == " MutableSpan " || name == " MutableRawSpan "
1273+ }
1274+
1275+ static func containsLifetimeAttr( _ attrs: AttributeListSyntax , for paramName: TokenSyntax ) -> Bool {
1276+ for elem in attrs {
1277+ guard let attr = elem. as ( AttributeSyntax . self) else {
1278+ continue
1279+ }
1280+ if attr. attributeName != " lifetime " {
1281+ continue
1282+ }
1283+ guard let args = attr. arguments? . as ( LabeledExprListSyntax . self) else {
1284+ continue
1285+ }
1286+ for arg in args {
1287+ if arg. label == paramName {
1288+ return true
1289+ }
1290+ }
1291+ }
1292+ return false
1293+ }
1294+
1295+ // Mutable[Raw]Span parameters need explicit @lifetime annotations since they are inout
1296+ static func paramLifetimeAttributes( _ newSignature: FunctionSignatureSyntax , _ oldAttrs: AttributeListSyntax ) -> [ AttributeListSyntax . Element ] {
1297+ var defaultLifetimes : [ AttributeListSyntax . Element ] = [ ]
1298+ for param in newSignature. parameterClause. parameters {
1299+ if !isMutableSpan( param. type) {
1300+ continue
1301+ }
1302+ let paramName = param. secondName ?? param. firstName
1303+ if containsLifetimeAttr ( oldAttrs, for: paramName) {
1304+ continue
1305+ }
1306+ let expr = ExprSyntax ( " \( paramName) : copy \( paramName) " )
1307+
1308+ defaultLifetimes. append ( . attribute( AttributeSyntax (
1309+ atSign: . atSignToken( ) ,
1310+ attributeName: IdentifierTypeSyntax ( name: " lifetime " ) ,
1311+ leftParen: . leftParenToken( ) ,
1312+ arguments: . argumentList( LabeledExprListSyntax ( [ LabeledExprSyntax ( expression: expr) ] ) ) ,
1313+ rightParen: . rightParenToken( ) ) ) )
1314+ }
1315+ return defaultLifetimes
1316+ }
1317+
11931318 public static func expansion(
11941319 of node: AttributeSyntax ,
11951320 providingPeersOf declaration: some DeclSyntaxProtocol ,
@@ -1255,9 +1380,10 @@ public struct SwiftifyImportMacro: PeerMacro {
12551380 item: CodeBlockItemSyntax . Item (
12561381 ReturnStmtSyntax (
12571382 returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1258- expression: ExprSyntax ( " unsafe \( try builder. buildFunctionCall ( [ : ] ) ) " ) ) ) )
1383+ expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
12591384 let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
1260- let lifetimeAttrs = lifetimeAttributes ( funcDecl, lifetimeDependencies)
1385+ let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcDecl, lifetimeDependencies)
1386+ let lifetimeAttrs = returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcDecl. attributes)
12611387 let disfavoredOverload : [ AttributeListSyntax . Element ] = ( onlyReturnTypeChanged ? [
12621388 . attribute(
12631389 AttributeSyntax (
0 commit comments