@@ -4,6 +4,9 @@ import SwiftSyntax
44import SwiftSyntaxBuilder
55import SwiftSyntaxMacros
66
7+ // Disable emitting 'MutableSpan' until it has landed
8+ let enableMutableSpan = false
9+
710// avoids depending on SwiftifyImport.swift
811// all instances are reparsed and reinstantiated by the macro anyways,
912// so linking is irrelevant
@@ -213,22 +216,26 @@ func replaceBaseType(_ type: TypeSyntax, _ base: TypeSyntax) -> TypeSyntax {
213216
214217// C++ type qualifiers, `const T` and `volatile T`, are encoded as fake generic
215218// types, `__cxxConst<T>` and `__cxxVolatile<T>` respectively. Remove those.
216- func dropQualifierGenerics( _ type: TypeSyntax ) -> TypeSyntax {
217- guard let identifier = type. as ( IdentifierTypeSyntax . self) else { return type }
218- guard let generic = identifier. genericArgumentClause else { return type }
219- guard let genericArg = generic. arguments. first else { return type }
220- guard case . type( let argType) = genericArg. argument else { return type }
219+ // Second return value is true if __cxxConst was stripped.
220+ func dropQualifierGenerics( _ type: TypeSyntax ) -> ( TypeSyntax , Bool ) {
221+ guard let identifier = type. as ( IdentifierTypeSyntax . self) else { return ( type, false ) }
222+ guard let generic = identifier. genericArgumentClause else { return ( type, false ) }
223+ guard let genericArg = generic. arguments. first else { return ( type, false ) }
224+ guard case . type( let argType) = genericArg. argument else { return ( type, false ) }
221225 switch identifier. name. text {
222- case " __cxxConst " , " __cxxVolatile " :
226+ case " __cxxConst " :
227+ let ( retType, _) = dropQualifierGenerics ( argType)
228+ return ( retType, true )
229+ case " __cxxVolatile " :
223230 return dropQualifierGenerics ( argType)
224231 default :
225- return type
232+ return ( type, false )
226233 }
227234}
228235
229236// The generated type names for template instantiations sometimes contain
230237// encoded qualifiers for disambiguation purposes. We need to remove those.
231- func dropCxxQualifiers( _ type: TypeSyntax ) -> TypeSyntax {
238+ func dropCxxQualifiers( _ type: TypeSyntax ) -> ( TypeSyntax , Bool ) {
232239 if let attributed = type. as ( AttributedTypeSyntax . self) {
233240 return dropCxxQualifiers ( attributed. baseType)
234241 }
@@ -272,12 +279,20 @@ func getUnqualifiedStdName(_ type: String) -> String? {
272279func getSafePointerName( mut: Mutability , generateSpan: Bool , isRaw: Bool ) -> TokenSyntax {
273280 switch ( mut, generateSpan, isRaw) {
274281 case ( . Immutable, true , true ) : return " RawSpan "
275- case ( . Mutable, true , true ) : return " MutableRawSpan "
282+ case ( . Mutable, true , true ) : return if enableMutableSpan {
283+ " MutableRawSpan "
284+ } else {
285+ " RawSpan "
286+ }
276287 case ( . Immutable, false , true ) : return " UnsafeRawBufferPointer "
277288 case ( . Mutable, false , true ) : return " UnsafeMutableRawBufferPointer "
278289
279290 case ( . Immutable, true , false ) : return " Span "
280- case ( . Mutable, true , false ) : return " MutableSpan "
291+ case ( . Mutable, true , false ) : return if enableMutableSpan {
292+ " MutableSpan "
293+ } else {
294+ " Span "
295+ }
281296 case ( . Immutable, false , false ) : return " UnsafeBufferPointer "
282297 case ( . Mutable, false , false ) : return " UnsafeMutableBufferPointer "
283298 }
@@ -317,6 +332,28 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool)
317332 return try replaceTypeName ( prev, token)
318333}
319334
335+ func isMutablePointerType( _ type: TypeSyntax ) -> Bool {
336+ if let optType = type. as ( OptionalTypeSyntax . self) {
337+ return isMutablePointerType ( optType. wrappedType)
338+ }
339+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
340+ return isMutablePointerType ( impOptType. wrappedType)
341+ }
342+ if let attrType = type. as ( AttributedTypeSyntax . self) {
343+ return isMutablePointerType ( attrType. baseType)
344+ }
345+ do {
346+ let name = try getTypeName ( type)
347+ let text = name. text
348+ guard let kind: Mutability = getPointerMutability ( text: text) else {
349+ return false
350+ }
351+ return kind == . Mutable
352+ } catch _ {
353+ return false
354+ }
355+ }
356+
320357protocol BoundsCheckedThunkBuilder {
321358 func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax
322359 func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ]
@@ -401,7 +438,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
401438 }
402439}
403440
404- struct CxxSpanThunkBuilder : ParamPointerBoundsThunkBuilder {
441+ struct CxxSpanThunkBuilder : SpanBoundsThunkBuilder , ParamBoundsThunkBuilder {
405442 public let base : BoundsCheckedThunkBuilder
406443 public let index : Int
407444 public let signature : FunctionSignatureSyntax
@@ -417,17 +454,7 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
417454 func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
418455 -> ( FunctionSignatureSyntax , Bool ) {
419456 var types = argTypes
420- let typeName = getUnattributedType ( oldType) . description
421- guard let desugaredType = typeMappings [ typeName] else {
422- throw DiagnosticError (
423- " unable to desugar type with name ' \( typeName) ' " , node: node)
424- }
425-
426- let parsedDesugaredType = TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredType) !) " )
427- let genericArg = TypeSyntax ( parsedDesugaredType. as ( IdentifierTypeSyntax . self) !
428- . genericArgumentClause!. arguments. first!. argument) !
429- types [ index] = replaceBaseType ( param. type,
430- TypeSyntax ( " Span< \( raw: dropCxxQualifiers ( genericArg) ) > " ) )
457+ types [ index] = try newType
431458 return try base. buildFunctionSignature ( types, returnType)
432459 }
433460
@@ -440,44 +467,100 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
440467 }
441468}
442469
443- struct CxxSpanReturnThunkBuilder : BoundsCheckedThunkBuilder {
470+ struct CxxSpanReturnThunkBuilder : SpanBoundsThunkBuilder {
444471 public let base : BoundsCheckedThunkBuilder
445472 public let signature : FunctionSignatureSyntax
446473 public let typeMappings : [ String : String ]
447474 public let node : SyntaxProtocol
448475
476+ var oldType : TypeSyntax {
477+ return signature. returnClause!. type
478+ }
479+
449480 func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ] {
450481 return try base. buildBoundsChecks ( )
451482 }
452483
453484 func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
454485 -> ( FunctionSignatureSyntax , Bool ) {
455486 assert ( returnType == nil )
456- let typeName = getUnattributedType ( signature. returnClause!. type) . description
457- guard let desugaredType = typeMappings [ typeName] else {
458- throw DiagnosticError (
459- " unable to desugar type with name ' \( typeName) ' " , node: node)
460- }
461- let parsedDesugaredType = TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredType) !) " )
462- let genericArg = TypeSyntax ( parsedDesugaredType. as ( IdentifierTypeSyntax . self) !
463- . genericArgumentClause!. arguments. first!. argument) !
464- let newType = replaceBaseType ( signature. returnClause!. type,
465- TypeSyntax ( " Span< \( raw: dropCxxQualifiers ( genericArg) ) > " ) )
466487 return try base. buildFunctionSignature ( argTypes, newType)
467488 }
468489
469490 func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
470491 let call = try base. buildFunctionCall ( pointerArgs)
471- return " _cxxOverrideLifetime(Span(_unsafeCxxSpan: \( call) ), copying: ()) "
492+ let ( _, isConst) = dropCxxQualifiers ( try genericArg)
493+ let cast = if isConst || !enableMutableSpan {
494+ " Span "
495+ } else {
496+ " MutableSpan "
497+ }
498+ return " _cxxOverrideLifetime( \( raw: cast) (_unsafeCxxSpan: \( call) ), copying: ()) "
472499 }
473500}
474501
475- protocol PointerBoundsThunkBuilder : BoundsCheckedThunkBuilder {
502+ protocol BoundsThunkBuilder : BoundsCheckedThunkBuilder {
476503 var oldType : TypeSyntax { get }
477504 var newType : TypeSyntax { get throws }
478- var nullable : Bool { get }
479505 var signature : FunctionSignatureSyntax { get }
480- var nonescaping : Bool { get }
506+ }
507+
508+ protocol SpanBoundsThunkBuilder : BoundsThunkBuilder {
509+ var typeMappings : [ String : String ] { get }
510+ var node : SyntaxProtocol { get }
511+ }
512+ extension SpanBoundsThunkBuilder {
513+ var desugaredType : TypeSyntax {
514+ get throws {
515+ let typeName = try getUnattributedType ( oldType) . description
516+ guard let desugaredTypeName = typeMappings [ typeName] else {
517+ throw DiagnosticError (
518+ " unable to desugar type with name ' \( typeName) ' " , node: node)
519+ }
520+ return TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredTypeName) !) " )
521+ }
522+ }
523+ var genericArg : TypeSyntax {
524+ get throws {
525+ guard let idType = try desugaredType. as ( IdentifierTypeSyntax . self) else {
526+ throw DiagnosticError (
527+ " unexpected non-identifier type ' \( try desugaredType) ', expected a std::span type " ,
528+ node: try desugaredType)
529+ }
530+ guard let genericArgumentClause = idType. genericArgumentClause else {
531+ throw DiagnosticError (
532+ " missing generic type argument clause expected after \( idType) " , node: idType)
533+ }
534+ guard let firstArg = genericArgumentClause. arguments. first else {
535+ throw DiagnosticError (
536+ " expected at least 1 generic type argument for std::span type ' \( idType) ', found ' \( genericArgumentClause) ' " ,
537+ node: genericArgumentClause. arguments)
538+ }
539+ guard let arg = TypeSyntax ( firstArg. argument) else {
540+ throw DiagnosticError (
541+ " invalid generic type argument ' \( firstArg. argument) ' " ,
542+ node: firstArg. argument)
543+ }
544+ return arg
545+ }
546+ }
547+ var newType : TypeSyntax {
548+ get throws {
549+ let ( strippedArg, isConst) = dropCxxQualifiers ( try genericArg)
550+ let mutablePrefix = if isConst || !enableMutableSpan {
551+ " "
552+ } else {
553+ " Mutable "
554+ }
555+ return replaceBaseType (
556+ oldType,
557+ TypeSyntax ( " \( raw: mutablePrefix) Span< \( raw: strippedArg) > " ) )
558+ }
559+ }
560+ }
561+
562+ protocol PointerBoundsThunkBuilder : BoundsThunkBuilder {
563+ var nullable : Bool { get }
481564 var isSizedBy : Bool { get }
482565 var generateSpan : Bool { get }
483566}
@@ -490,13 +573,12 @@ extension PointerBoundsThunkBuilder {
490573 }
491574}
492575
493- protocol ParamPointerBoundsThunkBuilder : PointerBoundsThunkBuilder {
576+ protocol ParamBoundsThunkBuilder : BoundsThunkBuilder {
494577 var index : Int { get }
578+ var nonescaping : Bool { get }
495579}
496580
497- extension ParamPointerBoundsThunkBuilder {
498- var generateSpan : Bool { nonescaping }
499-
581+ extension ParamBoundsThunkBuilder {
500582 var param : FunctionParameterSyntax {
501583 return getParam ( signature, index)
502584 }
@@ -518,7 +600,7 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
518600 public let isSizedBy : Bool
519601 public let dependencies : [ LifetimeDependence ]
520602
521- var generateSpan : Bool { !dependencies. isEmpty }
603+ var generateSpan : Bool { !dependencies. isEmpty && ( !isMutablePointerType ( oldType ) || enableMutableSpan ) }
522604
523605 var oldType : TypeSyntax {
524606 return signature. returnClause!. type
@@ -531,7 +613,7 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
531613 }
532614
533615 func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ] {
534- return [ ]
616+ return try base . buildBoundsChecks ( )
535617 }
536618
537619 func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
@@ -548,7 +630,8 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
548630 }
549631}
550632
551- struct CountedOrSizedPointerThunkBuilder : ParamPointerBoundsThunkBuilder {
633+
634+ struct CountedOrSizedPointerThunkBuilder : ParamBoundsThunkBuilder , PointerBoundsThunkBuilder {
552635 public let base : BoundsCheckedThunkBuilder
553636 public let index : Int
554637 public let countExpr : ExprSyntax
@@ -557,6 +640,8 @@ struct CountedOrSizedPointerThunkBuilder: ParamPointerBoundsThunkBuilder {
557640 public let isSizedBy : Bool
558641 public let skipTrivialCount : Bool
559642
643+ var generateSpan : Bool { nonescaping && ( !isMutablePointerType( oldType) || enableMutableSpan) }
644+
560645 func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
561646 -> ( FunctionSignatureSyntax , Bool ) {
562647 var types = argTypes
0 commit comments