@@ -91,6 +91,12 @@ func removeParentheses(from expr: ExprSyntax) -> ExprSyntax? {
9191
9292// MARK: - Inserting expression context callouts
9393
94+ /// The maximum value of `_rewriteDepth` allowed by `_rewrite()` before it will
95+ /// start bailing early.
96+ private let _maximumRewriteDepth = {
97+ Int . max // disable rewrite-limiting (need to evaluate possible heuristics)
98+ } ( )
99+
94100/// A type that inserts calls to an `__ExpectationContext` instance into an
95101/// expression's syntax tree.
96102private final class _ContextInserter < C, M> : SyntaxRewriter where C: MacroExpansionContext , M: FreestandingMacroExpansionSyntax {
@@ -123,6 +129,12 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
123129 super. init ( )
124130 }
125131
132+ /// The number of calls to `_rewrite()` made along the current node hierarchy.
133+ ///
134+ /// This value is incremented with each call to `_rewrite()` and managed by
135+ /// `_visitChild()`.
136+ private var _rewriteDepth = 0
137+
126138 /// Rewrite a given syntax node by inserting a call to the expression context
127139 /// (or rather, its `callAsFunction(_:_:)` member).
128140 ///
@@ -137,14 +149,27 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
137149 ///
138150 /// - Returns: A rewritten copy of `node` that calls into the expression
139151 /// context when it is evaluated at runtime.
140- private func _rewrite( _ node: some ExprSyntaxProtocol , originalWas originalNode: some ExprSyntaxProtocol , calling functionName: TokenSyntax ? = nil , passing additionalArguments: [ Argument ] = [ ] ) -> ExprSyntax {
141- guard rewrittenNodes. insert ( Syntax ( originalNode) ) . inserted else {
142- // If this node has already been rewritten, we don't need to rewrite it
143- // again. (Currently, this can only happen when expanding binary operators
144- // which need a bit of extra help.)
145- return ExprSyntax ( node)
152+ private func _rewrite(
153+ _ node: @autoclosure ( ) -> some ExprSyntaxProtocol ,
154+ originalWas originalNode: @autoclosure ( ) -> some ExprSyntaxProtocol ,
155+ calling functionName: @autoclosure ( ) -> TokenSyntax ? = nil ,
156+ passing additionalArguments: @autoclosure ( ) -> [ Argument ] = [ ]
157+ ) -> ExprSyntax {
158+ _rewriteDepth += 1
159+ if _rewriteDepth > _maximumRewriteDepth {
160+ // At least 2 ancestors of this node have already been rewritten, so do
161+ // not recursively rewrite further. This is necessary to limit the added
162+ // exponentional complexity we're throwing at the type checker.
163+ return ExprSyntax ( originalNode ( ) )
146164 }
147165
166+ // We're going to rewrite the node, so we'll evaluate the arguments now.
167+ let node = node ( )
168+ let originalNode = originalNode ( )
169+ let functionName = functionName ( )
170+ let additionalArguments = additionalArguments ( )
171+ rewrittenNodes. insert ( Syntax ( originalNode) )
172+
148173 let calledExpr : ExprSyntax = if let functionName {
149174 ExprSyntax ( MemberAccessExprSyntax ( base: expressionContextNameExpr, name: functionName) )
150175 } else {
@@ -200,6 +225,43 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
200225 _rewrite ( node, originalWas: node, calling: functionName, passing: additionalArguments)
201226 }
202227
228+ /// Visit `node` as a child of another previously-visited node.
229+ ///
230+ /// - Parameters:
231+ /// - node: The node to visit.
232+ ///
233+ /// - Returns: `node`, or a modified copy thereof if `node` or a child node
234+ /// was rewritten.
235+ ///
236+ /// Use this function instead of calling `visit(_:)` or `rewrite(_:detach:)`
237+ /// recursively.
238+ ///
239+ /// This overload simply visits `node` and is used for nodes that cannot be
240+ /// rewritten directly (because they are not expressions.)
241+ @_disfavoredOverload
242+ private func _visitChild< S> ( _ node: S ) -> S where S: SyntaxProtocol {
243+ rewrite ( node, detach: true ) . cast ( S . self)
244+ }
245+
246+ /// Visit `node` as a child of another previously-visited node.
247+ ///
248+ /// - Parameters:
249+ /// - node: The node to visit.
250+ ///
251+ /// - Returns: `node`, or a modified copy thereof if `node` or a child node
252+ /// was rewritten.
253+ ///
254+ /// Use this function instead of calling `visit(_:)` or `rewrite(_:detach:)`
255+ /// recursively.
256+ private func _visitChild( _ node: some ExprSyntaxProtocol ) -> ExprSyntax {
257+ let oldRewriteDepth = _rewriteDepth
258+ defer {
259+ _rewriteDepth = oldRewriteDepth
260+ }
261+
262+ return rewrite ( node, detach: true ) . cast ( ExprSyntax . self)
263+ }
264+
203265 /// Whether or not the parent node of the given node is capable of containing
204266 /// a rewritten `DeclReferenceExprSyntax` instance.
205267 ///
@@ -281,7 +343,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
281343 return _rewrite (
282344 TupleExprSyntax {
283345 for element in node. elements {
284- visit ( element) . trimmed
346+ _visitChild ( element) . trimmed
285347 }
286348 } ,
287349 originalWas: node
@@ -302,28 +364,28 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
302364 // expressions can be directly extracted out.
303365 if _isParentOfDeclReferenceExprValidForRewriting ( node) {
304366 return _rewrite (
305- node. with ( \. base, node. base. map ( visit ) ) ,
367+ node. with ( \. base, node. base. map ( _visitChild ) ) ,
306368 originalWas: node
307369 )
308370 }
309371
310- return ExprSyntax ( node. with ( \. base, node. base. map ( visit ) ) )
372+ return ExprSyntax ( node. with ( \. base, node. base. map ( _visitChild ) ) )
311373 }
312374
313375 override func visit( _ node: FunctionCallExprSyntax ) -> ExprSyntax {
314376 _rewrite (
315377 node
316- . with ( \. calledExpression, visit ( node. calledExpression) )
317- . with ( \. arguments, visit ( node. arguments) ) ,
378+ . with ( \. calledExpression, _visitChild ( node. calledExpression) )
379+ . with ( \. arguments, _visitChild ( node. arguments) ) ,
318380 originalWas: node
319381 )
320382 }
321383
322384 override func visit( _ node: SubscriptCallExprSyntax ) -> ExprSyntax {
323385 _rewrite (
324386 node
325- . with ( \. calledExpression, visit ( node. calledExpression) )
326- . with ( \. arguments, visit ( node. arguments) ) ,
387+ . with ( \. calledExpression, _visitChild ( node. calledExpression) )
388+ . with ( \. arguments, _visitChild ( node. arguments) ) ,
327389 originalWas: node
328390 )
329391 }
@@ -355,7 +417,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
355417
356418 return _rewrite (
357419 node
358- . with ( \. expression, visit ( node. expression) ) ,
420+ . with ( \. expression, _visitChild ( node. expression) ) ,
359421 originalWas: node
360422 )
361423 }
@@ -377,18 +439,18 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
377439 originalWas: node,
378440 calling: . identifier( " __cmp " ) ,
379441 passing: [
380- Argument ( expression: visit ( node. leftOperand) ) ,
442+ Argument ( expression: _visitChild ( node. leftOperand) ) ,
381443 Argument ( expression: node. leftOperand. expressionID ( rootedAt: effectiveRootNode) ) ,
382- Argument ( expression: visit ( node. rightOperand) ) ,
444+ Argument ( expression: _visitChild ( node. rightOperand) ) ,
383445 Argument ( expression: node. rightOperand. expressionID ( rootedAt: effectiveRootNode) )
384446 ]
385447 )
386448 }
387449
388450 return _rewrite (
389451 node
390- . with ( \. leftOperand, visit ( node. leftOperand) )
391- . with ( \. rightOperand, visit ( node. rightOperand) ) ,
452+ . with ( \. leftOperand, _visitChild ( node. leftOperand) )
453+ . with ( \. rightOperand, _visitChild ( node. rightOperand) ) ,
392454 originalWas: node
393455 )
394456 }
@@ -399,12 +461,11 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
399461 // `inout`, so it should be sufficient to capture it in a `defer` statement
400462 // that runs after the expression is evaluated.
401463
402- let teardownItem = CodeBlockItemSyntax (
403- item: . expr(
404- _rewrite ( node. expression, calling: . identifier( " __inoutAfter " ) )
405- )
406- )
407- teardownItems. append ( teardownItem)
464+ let rewrittenExpr = _rewrite ( node. expression, calling: . identifier( " __inoutAfter " ) )
465+ if rewrittenExpr != ExprSyntax ( node. expression) {
466+ let teardownItem = CodeBlockItemSyntax ( item: . expr( rewrittenExpr) )
467+ teardownItems. append ( teardownItem)
468+ }
408469
409470 // The argument should not be expanded in-place as we can't return an
410471 // argument passed `inout` and expect it to remain semantically correct.
@@ -427,7 +488,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
427488 rewrittenNodes. insert ( Syntax ( type) )
428489
429490 return _rewrite (
430- visit ( valueExpr) . trimmed,
491+ _visitChild ( valueExpr) . trimmed,
431492 originalWas: originalNode,
432493 calling: . identifier( " __ \( isAsKeyword) " ) ,
433494 passing: [
@@ -503,7 +564,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
503564 node. with (
504565 \. elements, ArrayElementListSyntax {
505566 for element in node. elements {
506- ArrayElementSyntax ( expression: visit ( element. expression) . trimmed)
567+ ArrayElementSyntax ( expression: _visitChild ( element. expression) . trimmed)
507568 }
508569 }
509570 ) ,
@@ -520,7 +581,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
520581 \. content, . elements(
521582 DictionaryElementListSyntax {
522583 for element in elements {
523- DictionaryElementSyntax ( key: visit ( element. key) . trimmed, value: visit ( element. value) . trimmed)
584+ DictionaryElementSyntax ( key: _visitChild ( element. key) . trimmed, value: _visitChild ( element. value) . trimmed)
524585 }
525586 }
526587 )
@@ -570,7 +631,7 @@ extension ConditionMacro {
570631 _diagnoseTrivialBooleanValue ( from: ExprSyntax ( node) , for: macro, in: context)
571632
572633 let contextInserter = _ContextInserter ( in: context, for: macro, rootedAt: Syntax ( effectiveRootNode) , expressionContextName: expressionContextName)
573- var expandedExpr = contextInserter. rewrite ( node) . cast ( ExprSyntax . self)
634+ var expandedExpr = contextInserter. rewrite ( node, detach : true ) . cast ( ExprSyntax . self)
574635 let rewrittenNodes = contextInserter. rewrittenNodes
575636
576637 // Insert additional effect keywords/thunks as needed.
@@ -606,7 +667,7 @@ extension ConditionMacro {
606667 var captureList : ClosureCaptureClauseSyntax ?
607668 do {
608669 let dollarIDReplacer = _DollarIdentifierReplacer ( )
609- codeBlockItems = dollarIDReplacer. rewrite ( codeBlockItems) . cast ( CodeBlockItemListSyntax . self)
670+ codeBlockItems = dollarIDReplacer. rewrite ( codeBlockItems, detach : true ) . cast ( CodeBlockItemListSyntax . self)
610671 if !dollarIDReplacer. dollarIdentifierTokenKinds. isEmpty {
611672 let dollarIdentifierTokens = dollarIDReplacer. dollarIdentifierTokenKinds. map { tokenKind in
612673 TokenSyntax ( tokenKind, presence: . present)
0 commit comments