Skip to content

Commit 9dbb13a

Browse files
Support open accesslevel
1 parent 25b3800 commit 9dbb13a

File tree

5 files changed

+49
-14
lines changed

5 files changed

+49
-14
lines changed

Sources/Spyable/Spyable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137
public macro Spyable(
138138
behindPreprocessorFlag: String? = nil,
139139
accessLevel: SpyAccessLevel? = nil,
140-
inheritedTypes: String? = nil
140+
inheritedType: String? = nil
141141
) =
142142
#externalMacro(
143143
module: "SpyableMacro",

Sources/SpyableMacro/Extractors/Extractor.swift

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ struct Extractor {
106106
let accessLevelText = memberAccess.declName.baseName.text
107107

108108
switch accessLevelText {
109+
case "open":
110+
return DeclModifierSyntax(name: .keyword(.open))
111+
109112
case "public":
110113
return DeclModifierSyntax(name: .keyword(.public))
111114

@@ -146,7 +149,7 @@ struct Extractor {
146149
protocolDeclSyntax.modifiers.first(where: \.name.isAccessLevelSupportedInProtocol)
147150
}
148151

149-
func extractInheritedTypes(
152+
func extractInheritedType(
150153
from attribute: AttributeSyntax,
151154
in context: some MacroExpansionContext
152155
) -> String? {
@@ -155,29 +158,30 @@ struct Extractor {
155158
return nil
156159
}
157160

158-
let inheritedTypesArgument = argumentList.first { argument in
159-
argument.label?.text == "inheritedTypes"
161+
let inheritedTypeArgument = argumentList.first { argument in
162+
argument.label?.text == "inheritedType"
160163
}
161164

162-
guard let inheritedTypesArgument else {
163-
// The `inheritedTypes` argument is missing.
165+
guard let inheritedTypeArgument else {
166+
// The `inheritedType` argument is missing.
164167
return nil
165168
}
166169

167-
let segments = inheritedTypesArgument.expression
170+
// Check if it's a string literal expression
171+
let segments = inheritedTypeArgument.expression
168172
.as(StringLiteralExprSyntax.self)?
169173
.segments
170174

171175
guard let segments,
172176
segments.count == 1,
173177
case let .stringSegment(literalSegment)? = segments.first
174178
else {
175-
// The `inheritedTypes` argument's value is not a static string literal.
179+
// The `inheritedType` argument's value is not a valid string literal.
176180
context.diagnose(
177181
Diagnostic(
178182
node: attribute,
179183
message: SpyableDiagnostic.behindPreprocessorFlagArgumentRequiresStaticStringLiteral,
180-
highlights: [Syntax(inheritedTypesArgument.expression)]
184+
highlights: [Syntax(inheritedTypeArgument.expression)]
181185
)
182186
)
183187
return nil

Sources/SpyableMacro/Factories/SpyFactory.swift

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,10 @@ struct SpyFactory {
9292
private let closureFactory = ClosureFactory()
9393
private let functionImplementationFactory = FunctionImplementationFactory()
9494

95-
func classDeclaration(for protocolDeclaration: ProtocolDeclSyntax) throws -> ClassDeclSyntax {
95+
func classDeclaration(
96+
for protocolDeclaration: ProtocolDeclSyntax,
97+
inheritedType: String? = nil
98+
) throws -> ClassDeclSyntax {
9699
let identifier = TokenSyntax.identifier(protocolDeclaration.name.text + "Spy")
97100

98101
let assosciatedtypeDeclarations = protocolDeclaration.memberBlock.members.compactMap {
@@ -117,6 +120,14 @@ struct SpyFactory {
117120
name: identifier,
118121
genericParameterClause: genericParameterClause,
119122
inheritanceClause: InheritanceClauseSyntax {
123+
// Add inherited type first if present
124+
if let inheritedType {
125+
InheritedTypeSyntax(
126+
type: TypeSyntax(stringLiteral: inheritedType)
127+
)
128+
}
129+
130+
// Add the main protocol
120131
InheritedTypeSyntax(
121132
type: TypeSyntax(stringLiteral: protocolDeclaration.name.text)
122133
)
@@ -125,7 +136,10 @@ struct SpyFactory {
125136
)
126137
},
127138
memberBlockBuilder: {
139+
let initOverrideKeyword: DeclModifierListSyntax = inheritedType != nil ? [DeclModifierSyntax(name: .keyword(.override))] : []
140+
128141
InitializerDeclSyntax(
142+
modifiers: initOverrideKeyword,
129143
signature: FunctionSignatureSyntax(
130144
parameterClause: FunctionParameterClauseSyntax(parameters: [])
131145
),

Sources/SpyableMacro/Macro/AccessLevelModifierRewriter.swift

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,19 @@ final class AccessLevelModifierRewriter: SyntaxRewriter {
1717
return node
1818
}
1919

20-
return DeclModifierListSyntax {
21-
newAccessLevel
20+
// Always preserve existing modifiers (like override, convenience, etc.)
21+
var modifiers = Array(node)
22+
23+
// Special case: if accessLevel is open and this is an initializer, use public instead
24+
if newAccessLevel.name.text == TokenSyntax.keyword(.open).text,
25+
let parent = node.parent,
26+
parent.is(InitializerDeclSyntax.self) {
27+
modifiers.append(DeclModifierSyntax(name: .keyword(.public)))
28+
} else {
29+
// Add the access level modifier for all other cases
30+
modifiers.append(newAccessLevel)
2231
}
32+
33+
return DeclModifierListSyntax(modifiers)
2334
}
2435
}

Sources/SpyableMacro/Macro/SpyableMacro.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,14 @@ public enum SpyableMacro: PeerMacro {
1313
// Extract the protocol declaration
1414
let protocolDeclaration = try extractor.extractProtocolDeclaration(from: declaration)
1515

16-
// Generate the initial spy class declaration
17-
var spyClassDeclaration = try spyFactory.classDeclaration(for: protocolDeclaration)
16+
// Extract inherited type from the attribute
17+
let inheritedType = extractor.extractInheritedType(from: node, in: context)
18+
19+
// Generate the initial spy class declaration with inherited type
20+
var spyClassDeclaration = try spyFactory.classDeclaration(
21+
for: protocolDeclaration,
22+
inheritedType: inheritedType
23+
)
1824

1925
// Apply access level modifiers if needed
2026
if let accessLevel = determineAccessLevel(

0 commit comments

Comments
 (0)