@@ -45,19 +45,11 @@ protocol ParamInfo: CustomStringConvertible {
4545 ) -> BoundsCheckedThunkBuilder
4646}
4747
48- func getParamName( _ param: FunctionParameterSyntax , _ paramIndex: Int ) -> TokenSyntax {
49- let name = param. secondName ?? param. firstName
50- if name. trimmed. text == " _ " {
51- return " _param \( raw: paramIndex) "
52- }
53- return name
54- }
55-
5648func tryGetParamName( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
5749 switch expr {
5850 case . param( let i) :
5951 let funcParam = getParam ( funcDecl, i - 1 )
60- return getParamName ( funcParam, i - 1 )
52+ return funcParam. name
6153 case . `self`:
6254 return . keyword( . self )
6355 default : return nil
@@ -427,12 +419,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
427419 // filter out deleted parameters, i.e. ones where argTypes[i] _contains_ nil
428420 return type == nil || type! != nil
429421 } . map { ( i: Int , e: FunctionParameterSyntax ) in
430- let param = e. with ( \. type, ( argTypes [ i] ?? e. type) !)
431- let name = param. secondName ?? param. firstName
432- if name. trimmed. text == " _ " {
433- return param. with ( \. secondName, getParamName ( param, i) )
434- }
435- return param
422+ e. with ( \. type, ( argTypes [ i] ?? e. type) !)
436423 }
437424 if let last = newParams. popLast ( ) {
438425 newParams. append ( last. with ( \. trailingComma, nil ) )
@@ -450,9 +437,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
450437 let functionRef = DeclReferenceExprSyntax ( baseName: base. name)
451438 let args : [ ExprSyntax ] = base. signature. parameterClause. parameters. enumerated ( )
452439 . map { ( i: Int , param: FunctionParameterSyntax ) in
453- let name = getParamName ( param, i)
454- let declref = DeclReferenceExprSyntax ( baseName: name)
455- return pointerArgs [ i] ?? ExprSyntax ( declref)
440+ return pointerArgs [ i] ?? ExprSyntax ( " \( param. name) " )
456441 }
457442 let labels : [ TokenSyntax ? ] = base. signature. parameterClause. parameters. map { param in
458443 let firstName = param. firstName. trimmed
@@ -468,7 +453,8 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
468453 comma = . commaToken( )
469454 }
470455 let colon : TokenSyntax ? = label != nil ? . colonToken( ) : nil
471- return LabeledExprSyntax ( label: label, colon: colon, expression: arg, trailingComma: comma)
456+ // The compiler emits warnings if you unnecessarily escape labels in function calls
457+ return LabeledExprSyntax ( label: label? . withoutBackticks, colon: colon, expression: arg, trailingComma: comma)
472458 }
473459 let call = ExprSyntax (
474460 FunctionCallExprSyntax (
@@ -510,7 +496,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
510496 args [ index] = ExprSyntax ( " \( raw: typeName) ( \( raw: name) ) " )
511497 return try base. buildFunctionCall ( args)
512498 } else {
513- let unwrappedName = TokenSyntax ( " _ \( name) Ptr " )
499+ let unwrappedName = TokenSyntax ( " _ \( name. withoutBackticks ) Ptr " )
514500 args [ index] = ExprSyntax ( " \( raw: typeName) ( \( unwrappedName) ) " )
515501 let call = try base. buildFunctionCall ( args)
516502
@@ -663,7 +649,7 @@ extension ParamBoundsThunkBuilder {
663649 }
664650
665651 var name : TokenSyntax {
666- getParamName ( param, index )
652+ param. name
667653 }
668654}
669655
@@ -796,7 +782,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
796782 }
797783
798784 func buildUnwrapCall( _ argOverrides: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
799- let unwrappedName = TokenSyntax ( " _ \( name) Ptr " )
785+ let unwrappedName = TokenSyntax ( " _ \( name. withoutBackticks ) Ptr " ) . escapeIfNeeded
800786 var args = argOverrides
801787 let argExpr = ExprSyntax ( " \( unwrappedName) .baseAddress " )
802788 assert ( args [ index] == nil )
@@ -809,7 +795,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
809795 }
810796 }
811797 let call = try base. buildFunctionCall ( args)
812- let ptrRef = unwrapIfNullable ( ExprSyntax ( DeclReferenceExprSyntax ( baseName : name) ) )
798+ let ptrRef = unwrapIfNullable ( " \( name) " )
813799
814800 let funcName =
815801 switch ( isSizedBy, isMutablePointerType ( oldType) ) {
@@ -1004,7 +990,7 @@ func parseSwiftifyExpr(_ expr: ExprSyntax) throws -> SwiftifyExpr {
1004990}
1005991
1006992func parseCountedByEnum(
1007- _ enumConstructorExpr: FunctionCallExprSyntax , _ signature: FunctionSignatureSyntax
993+ _ enumConstructorExpr: FunctionCallExprSyntax , _ signature: FunctionSignatureSyntax , _ rewriter : CountExprRewriter
1008994) throws -> ParamInfo {
1009995 let argumentList = enumConstructorExpr. arguments
1010996 let pointerExprArg = try getArgumentByName ( argumentList, " pointer " )
@@ -1015,7 +1001,8 @@ func parseCountedByEnum(
10151001 " expected string literal for 'count' parameter, got \( countExprArg) " , node: countExprArg)
10161002 }
10171003 let unwrappedCountExpr = ExprSyntax ( stringLiteral: countExprStringLit. representedLiteralValue!)
1018- if let countVar = unwrappedCountExpr. as ( DeclReferenceExprSyntax . self) {
1004+ let rewrittenCountExpr = rewriter. visit ( unwrappedCountExpr)
1005+ if let countVar = rewrittenCountExpr. as ( DeclReferenceExprSyntax . self) {
10191006 // Perform this lookup here so we can override the position to point to the string literal
10201007 // instead of line 1, column 1
10211008 do {
@@ -1025,11 +1012,11 @@ func parseCountedByEnum(
10251012 }
10261013 }
10271014 return CountedBy (
1028- pointerIndex: pointerExpr, count: unwrappedCountExpr , sizedBy: false ,
1015+ pointerIndex: pointerExpr, count: rewrittenCountExpr , sizedBy: false ,
10291016 nonescaping: false , dependencies: [ ] , original: ExprSyntax ( enumConstructorExpr) )
10301017}
10311018
1032- func parseSizedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> ParamInfo {
1019+ func parseSizedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax , _ rewriter : CountExprRewriter ) throws -> ParamInfo {
10331020 let argumentList = enumConstructorExpr. arguments
10341021 let pointerExprArg = try getArgumentByName ( argumentList, " pointer " )
10351022 let pointerExpr : SwiftifyExpr = try parseSwiftifyExpr ( pointerExprArg)
@@ -1039,8 +1026,9 @@ func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> P
10391026 " expected string literal for 'size' parameter, got \( sizeExprArg) " , node: sizeExprArg)
10401027 }
10411028 let unwrappedCountExpr = ExprSyntax ( stringLiteral: sizeExprStringLit. representedLiteralValue!)
1029+ let rewrittenCountExpr = rewriter. visit ( unwrappedCountExpr)
10421030 return CountedBy (
1043- pointerIndex: pointerExpr, count: unwrappedCountExpr , sizedBy: true , nonescaping: false ,
1031+ pointerIndex: pointerExpr, count: rewrittenCountExpr , sizedBy: true , nonescaping: false ,
10441032 dependencies: [ ] , original: ExprSyntax ( enumConstructorExpr) )
10451033}
10461034
@@ -1177,7 +1165,7 @@ func parseCxxSpansInSignature(
11771165}
11781166
11791167func parseMacroParam(
1180- _ paramAST: LabeledExprSyntax , _ signature: FunctionSignatureSyntax ,
1168+ _ paramAST: LabeledExprSyntax , _ signature: FunctionSignatureSyntax , _ rewriter : CountExprRewriter ,
11811169 nonescapingPointers: inout Set < Int > ,
11821170 lifetimeDependencies: inout [ SwiftifyExpr : [ LifetimeDependence ] ]
11831171) throws -> ParamInfo ? {
@@ -1188,8 +1176,8 @@ func parseMacroParam(
11881176 }
11891177 let enumName = try parseEnumName ( paramExpr)
11901178 switch enumName {
1191- case " countedBy " : return try parseCountedByEnum ( enumConstructorExpr, signature)
1192- case " sizedBy " : return try parseSizedByEnum ( enumConstructorExpr)
1179+ case " countedBy " : return try parseCountedByEnum ( enumConstructorExpr, signature, rewriter )
1180+ case " sizedBy " : return try parseSizedByEnum ( enumConstructorExpr, rewriter )
11931181 case " endedBy " : return try parseEndedByEnum ( enumConstructorExpr)
11941182 case " nonescaping " :
11951183 let index = try parseNonEscaping ( enumConstructorExpr)
@@ -1438,7 +1426,7 @@ func paramLifetimeAttributes(
14381426 if !isMutableSpan( param. type) {
14391427 continue
14401428 }
1441- let paramName = param. secondName ?? param . firstName
1429+ let paramName = param. name
14421430 if containsLifetimeAttr ( oldAttrs, for: paramName) {
14431431 continue
14441432 }
@@ -1456,6 +1444,61 @@ func paramLifetimeAttributes(
14561444 return defaultLifetimes
14571445}
14581446
1447+ class CountExprRewriter : SyntaxRewriter {
1448+ public let nameMap : [ String : String ]
1449+
1450+ init ( _ renamedParams: [ String : String ] ) {
1451+ nameMap = renamedParams
1452+ }
1453+
1454+ override func visit( _ node: DeclReferenceExprSyntax ) -> ExprSyntax {
1455+ if let newName = nameMap [ node. baseName. trimmed. text] {
1456+ return ExprSyntax (
1457+ node. with (
1458+ \. baseName,
1459+ . identifier(
1460+ newName, leadingTrivia: node. baseName. leadingTrivia,
1461+ trailingTrivia: node. baseName. trailingTrivia) ) )
1462+ }
1463+ return escapeIfNeeded ( node)
1464+ }
1465+ }
1466+
1467+ func renameParameterNamesIfNeeded( _ funcDecl: FunctionDeclSyntax ) -> ( FunctionDeclSyntax , CountExprRewriter ) {
1468+ let params = funcDecl. signature. parameterClause. parameters
1469+ let funcName = funcDecl. name. withoutBackticks. trimmed. text
1470+ let shouldRename = params. contains ( where: { param in
1471+ let paramName = param. name. trimmed. text
1472+ return paramName == " _ " || paramName == funcName || " ` \( paramName) ` " == funcName
1473+ } )
1474+ var renamedParams : [ String : String ] = [ : ]
1475+ let newParams = params. enumerated ( ) . map { ( i, param) in
1476+ let secondName = if shouldRename {
1477+ // Including funcName in name prevents clash with function name.
1478+ // Renaming all parameters if one requires renaming guarantees that other parameters don't clash with the renamed one.
1479+ TokenSyntax ( " _ \( raw: funcName) _param \( raw: i) " )
1480+ } else {
1481+ param. secondName? . escapeIfNeeded
1482+ }
1483+ let firstName = param. firstName. escapeIfNeeded
1484+ let newParam = param. with ( \. secondName, secondName)
1485+ . with ( \. firstName, firstName)
1486+ let newName = newParam. name. trimmed. text
1487+ let oldName = param. name. trimmed. text
1488+ if newName != oldName {
1489+ renamedParams [ oldName] = newName
1490+ }
1491+ return newParam
1492+ }
1493+ let newDecl = if renamedParams. count > 0 {
1494+ funcDecl. with ( \. signature. parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
1495+ } else {
1496+ // Keeps source locations for diagnostics, in the common case where nothing was renamed
1497+ funcDecl
1498+ }
1499+ return ( newDecl, CountExprRewriter ( renamedParams) )
1500+ }
1501+
14591502/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
14601503/// Depends on bounds, escapability and lifetime information for each pointer.
14611504/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape,
@@ -1469,9 +1512,10 @@ public struct SwiftifyImportMacro: PeerMacro {
14691512 in context: some MacroExpansionContext
14701513 ) throws -> [ DeclSyntax ] {
14711514 do {
1472- guard let funcDecl = declaration. as ( FunctionDeclSyntax . self) else {
1515+ guard let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) else {
14731516 throw DiagnosticError ( " @_SwiftifyImport only works on functions " , node: declaration)
14741517 }
1518+ let ( funcDecl, rewriter) = renameParameterNamesIfNeeded ( origFuncDecl)
14751519
14761520 let argumentList = node. arguments!. as ( LabeledExprListSyntax . self) !
14771521 var arguments = [ LabeledExprSyntax] ( argumentList)
@@ -1487,7 +1531,7 @@ public struct SwiftifyImportMacro: PeerMacro {
14871531 var lifetimeDependencies : [ SwiftifyExpr : [ LifetimeDependence ] ] = [ : ]
14881532 var parsedArgs = try arguments. compactMap {
14891533 try parseMacroParam (
1490- $0, funcDecl. signature, nonescapingPointers: & nonescapingPointers,
1534+ $0, funcDecl. signature, rewriter , nonescapingPointers: & nonescapingPointers,
14911535 lifetimeDependencies: & lifetimeDependencies)
14921536 }
14931537 parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcDecl. signature, typeMappings) )
@@ -1627,3 +1671,33 @@ extension TypeSyntaxProtocol {
16271671 return false
16281672 }
16291673}
1674+
1675+ extension FunctionParameterSyntax {
1676+ var name : TokenSyntax {
1677+ self . secondName ?? self . firstName
1678+ }
1679+ }
1680+
1681+ extension TokenSyntax {
1682+ public var withoutBackticks : TokenSyntax {
1683+ return . identifier( self . identifier!. name)
1684+ }
1685+
1686+ public var escapeIfNeeded : TokenSyntax {
1687+ var parser = Parser ( " let \( self ) " )
1688+ let decl = DeclSyntax . parse ( from: & parser)
1689+ if !decl. hasError {
1690+ return self
1691+ } else {
1692+ return self . copyTrivia ( to: " ` \( raw: self . trimmed. text) ` " )
1693+ }
1694+ }
1695+
1696+ public func copyTrivia( to other: TokenSyntax ) -> TokenSyntax {
1697+ return . identifier( other. text, leadingTrivia: self . leadingTrivia, trailingTrivia: self . trailingTrivia)
1698+ }
1699+ }
1700+
1701+ func escapeIfNeeded( _ identifier: DeclReferenceExprSyntax ) -> ExprSyntax {
1702+ return " \( identifier. baseName. escapeIfNeeded) "
1703+ }
0 commit comments