@@ -50,62 +50,177 @@ struct CapturedValueInfo {
5050 return
5151 }
5252
53- // Potentially get the name of the type comprising the current lexical
54- // context (i.e. whatever `Self` is.)
55- lazy var lexicalContext = context. lexicalContext
56- lazy var typeNameOfLexicalContext = {
57- let lexicalContext = lexicalContext. drop { !$0. isProtocol ( ( any DeclGroupSyntax ) . self) }
58- return context. type ( ofLexicalContext: lexicalContext)
59- } ( )
53+ if let ( expr, type) = Self . _inferExpressionAndType ( of: capture, in: context) {
54+ self . expression = expr
55+ self . type = type
56+ } else {
57+ // Not enough contextual information to derive the type here.
58+ context. diagnose ( . typeOfCaptureIsAmbiguous( capture) )
59+ }
60+ }
6061
62+ /// Infer the captured expression and the type of a closure capture list item.
63+ ///
64+ /// - Parameters:
65+ /// - capture: The closure capture list item to inspect.
66+ /// - context: The macro context in which the expression is being parsed.
67+ ///
68+ /// - Returns: A tuple containing the expression and type of `capture`, or
69+ /// `nil` if they could not be inferred.
70+ private static func _inferExpressionAndType( of capture: ClosureCaptureSyntax , in context: some MacroExpansionContext ) -> ( ExprSyntax , TypeSyntax ) ? {
6171 if let initializer = capture. initializer {
6272 // Found an initializer clause. Extract the expression it captures.
63- self . expression = removeParentheses ( from: initializer. value) ?? initializer. value
73+ let finder = _ExprTypeFinder ( in: context)
74+ finder. walk ( initializer. value)
75+ if let inferredType = finder. inferredType {
76+ return ( initializer. value, inferredType)
77+ }
78+ } else if capture. name. tokenKind == . keyword( . self ) ,
79+ let typeNameOfLexicalContext = Self . _inferSelf ( from: context) {
80+ // Capturing self.
81+ return ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: . keyword( . self ) ) ) , typeNameOfLexicalContext)
82+ } else if let parameterType = Self . _findTypeOfParameter ( named: capture. name, in: context. lexicalContext) {
83+ return ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: capture. name. trimmed) ) , parameterType)
84+ }
85+
86+ return nil
87+ }
88+
89+ private final class _ExprTypeFinder < C> : SyntaxAnyVisitor where C: MacroExpansionContext {
90+ var context : C
91+
92+ /// The type that was inferred from the visited syntax tree, if any.
93+ ///
94+ /// This type has not been fixed up yet. Use ``inferredType`` for the final
95+ /// derived type.
96+ private var _inferredType : TypeSyntax ?
97+
98+ /// Whether or not the inferred type has been made optional by e.g. `try?`.
99+ private var _needsOptionalApplied = false
100+
101+ /// The type that was inferred from the visited syntax tree, if any.
102+ var inferredType : TypeSyntax ? {
103+ _inferredType. flatMap { inferredType in
104+ if inferredType. isSome || inferredType. isAny {
105+ // `some` and `any` types are not concrete and cannot be inferred.
106+ nil
107+ } else if _needsOptionalApplied {
108+ TypeSyntax ( OptionalTypeSyntax ( wrappedType: inferredType. trimmed) )
109+ } else {
110+ inferredType
111+ }
112+ }
113+ }
114+
115+ init ( in context: C ) {
116+ self . context = context
117+ super. init ( viewMode: . sourceAccurate)
118+ }
119+
120+ override func visitAny( _ node: Syntax ) -> SyntaxVisitorContinueKind {
121+ if inferredType != nil {
122+ // Another part of the syntax tree has already provided a type. Stop.
123+ return . skipChildren
124+ }
64125
65- // Find the 'as' clause so we can determine the type of the captured value.
66- if let asExpr = self . expression. as ( AsExprSyntax . self) {
67- self . type = if asExpr. questionOrExclamationMark? . tokenKind == . postfixQuestionMark {
126+ switch node. kind {
127+ case . asExpr:
128+ let asExpr = node. cast ( AsExprSyntax . self)
129+ if let type = asExpr. type. as ( IdentifierTypeSyntax . self) , type. name. tokenKind == . keyword( . Self) {
130+ // `Self` should resolve to the lexical context's type.
131+ _inferredType = CapturedValueInfo . _inferSelf ( from: context)
132+ } else if asExpr. questionOrExclamationMark? . tokenKind == . postfixQuestionMark {
68133 // If the caller is using as?, make the type optional.
69- TypeSyntax ( OptionalTypeSyntax ( wrappedType: asExpr. type. trimmed) )
134+ _inferredType = TypeSyntax ( OptionalTypeSyntax ( wrappedType: asExpr. type. trimmed) )
70135 } else {
71- asExpr. type
136+ _inferredType = asExpr. type
72137 }
73- } else if let selfExpr = self . expression. as ( DeclReferenceExprSyntax . self) ,
74- selfExpr. baseName. tokenKind == . keyword( . self ) ,
75- selfExpr. argumentNames == nil ,
76- let typeNameOfLexicalContext {
77- // Copying self.
78- self . type = typeNameOfLexicalContext
79- } else {
80- // Handle literals. Any other types are ambiguous.
81- switch self . expression. kind {
82- case . integerLiteralExpr:
83- self . type = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " IntegerLiteralType " ) ) )
84- case . floatLiteralExpr:
85- self . type = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " FloatLiteralType " ) ) )
86- case . booleanLiteralExpr:
87- self . type = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " BooleanLiteralType " ) ) )
88- case . stringLiteralExpr, . simpleStringLiteralExpr:
89- self . type = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " StringLiteralType " ) ) )
90- default :
91- context. diagnose ( . typeOfCaptureIsAmbiguous( capture, initializedWith: initializer) )
138+ return . skipChildren
139+
140+ case . awaitExpr, . unsafeExpr:
141+ // These effect keywords do not affect the type of the expression.
142+ return . visitChildren
143+
144+ case . tryExpr:
145+ let tryExpr = node. cast ( TryExprSyntax . self)
146+ if tryExpr. questionOrExclamationMark? . tokenKind == . postfixQuestionMark {
147+ // The resulting type from the inner expression will be optionalized.
148+ _needsOptionalApplied = true
92149 }
93- }
150+ return . visitChildren
94151
95- } else if capture. name. tokenKind == . keyword( . self ) ,
96- let typeNameOfLexicalContext {
97- // Capturing self.
98- self . expression = " self "
99- self . type = typeNameOfLexicalContext
100- } else if let parameterType = Self . _findTypeOfParameter ( named: capture. name, in: lexicalContext) {
101- self . expression = ExprSyntax ( DeclReferenceExprSyntax ( baseName: capture. name. trimmed) )
102- self . type = parameterType
103- } else {
104- // Not enough contextual information to derive the type here.
105- context. diagnose ( . typeOfCaptureIsAmbiguous( capture) )
152+ case . tupleExpr:
153+ // If the tuple contains exactly one element, it's just parentheses
154+ // around that expression.
155+ let tupleExpr = node. cast ( TupleExprSyntax . self)
156+ if tupleExpr. elements. count == 1 {
157+ return . visitChildren
158+ }
159+
160+ // Otherwise, we need to try to compose the type as a tuple type from
161+ // the types of all elements in the tuple expression. Note that tuples
162+ // do not conform to Sendable or Codable, so our current use of this
163+ // code in exit tests will still diagnose an error, but the error ("must
164+ // conform") will be more useful than "couldn't infer".
165+ let elements = tupleExpr. elements. compactMap { element in
166+ let finder = Self ( in: context)
167+ finder. walk ( element. expression)
168+ return finder. inferredType. map { type in
169+ TupleTypeElementSyntax ( firstName: element. label? . trimmed, type: type. trimmed)
170+ }
171+ }
172+ if elements. count == tupleExpr. elements. count {
173+ _inferredType = TypeSyntax (
174+ TupleTypeSyntax ( elements: TupleTypeElementListSyntax { elements } )
175+ )
176+ }
177+ return . skipChildren
178+
179+ case . declReferenceExpr:
180+ // If the reference is to `self` without any arguments, its type can be
181+ // inferred from the lexical context.
182+ let expr = node. cast ( DeclReferenceExprSyntax . self)
183+ if expr. baseName. tokenKind == . keyword( . self ) , expr. argumentNames == nil {
184+ _inferredType = CapturedValueInfo . _inferSelf ( from: context)
185+ }
186+ return . skipChildren
187+
188+ case . integerLiteralExpr:
189+ _inferredType = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " IntegerLiteralType " ) ) )
190+ return . skipChildren
191+
192+ case . floatLiteralExpr:
193+ _inferredType = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " FloatLiteralType " ) ) )
194+ return . skipChildren
195+
196+ case . booleanLiteralExpr:
197+ _inferredType = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " BooleanLiteralType " ) ) )
198+ return . skipChildren
199+
200+ case . stringLiteralExpr, . simpleStringLiteralExpr:
201+ _inferredType = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " StringLiteralType " ) ) )
202+ return . skipChildren
203+
204+ default :
205+ // We don't know how to infer a type from this syntax node, so do not
206+ // proceed further.
207+ return . skipChildren
208+ }
106209 }
107210 }
108211
212+ /// Get the type of `self` inferred from the given context.
213+ ///
214+ /// - Parameters:
215+ /// - context: The macro context in which the expression is being parsed.
216+ ///
217+ /// - Returns: The type in `lexicalContext` corresponding to `Self`, or `nil`
218+ /// if it could not be determined.
219+ private static func _inferSelf( from context: some MacroExpansionContext ) -> TypeSyntax ? {
220+ let lexicalContext = context. lexicalContext. drop { !$0. isProtocol ( ( any DeclGroupSyntax ) . self) }
221+ return context. type ( ofLexicalContext: lexicalContext)
222+ }
223+
109224 /// Find a function or closure parameter in the given lexical context with a
110225 /// given name and return its type.
111226 ///
0 commit comments