@@ -25,6 +25,7 @@ import ErrorReporting.errorTree
2525import dotty .tools .dotc .util .{SimpleIdentityMap , SimpleIdentitySet , EqHashMap , SourceFile , SourcePosition , SrcPos }
2626import dotty .tools .dotc .parsing .Parsers .Parser
2727import Nullables ._
28+ import transform .{PostTyper , Inlining }
2829
2930import collection .mutable
3031import reporting .trace
@@ -293,7 +294,7 @@ object Inliner {
293294 private enum ErrorKind :
294295 case Parser , Typer
295296
296- private def compileForErrors (tree : Tree , stopAfterParser : Boolean )(using Context ): List [(ErrorKind , Error )] =
297+ private def compileForErrors (tree : Tree )(using Context ): List [(ErrorKind , Error )] =
297298 assert(tree.symbol == defn.CompiletimeTesting_typeChecks || tree.symbol == defn.CompiletimeTesting_typeCheckErrors )
298299 def stripTyped (t : Tree ): Tree = t match {
299300 case Typed (t2, _) => stripTyped(t2)
@@ -311,17 +312,22 @@ object Inliner {
311312 ConstFold (underlyingCodeArg).tpe.widenTermRefExpr match {
312313 case ConstantType (Constant (code : String )) =>
313314 val source2 = SourceFile .virtual(" tasty-reflect" , code)
314- val ctx2 = ctx.fresh.setNewTyperState().setTyper(new Typer ).setSource(source2)
315- val tree2 = new Parser (source2)(using ctx2).block()
316- val res = collection.mutable.ListBuffer .empty[(ErrorKind , Error )]
317-
318- val parseErrors = ctx2.reporter.allErrors.toList
319- res ++= parseErrors.map(e => ErrorKind .Parser -> e)
320- if ! stopAfterParser || res.isEmpty then
321- ctx2.typer.typed(tree2)(using ctx2)
322- val typerErrors = ctx2.reporter.allErrors.filterNot(parseErrors.contains)
323- res ++= typerErrors.map(e => ErrorKind .Typer -> e)
324- res.toList
315+ inContext(ctx.fresh.setNewTyperState().setTyper(new Typer ).setSource(source2)) {
316+ val tree2 = new Parser (source2).block()
317+ if ctx.reporter.allErrors.nonEmpty then
318+ ctx.reporter.allErrors.map((ErrorKind .Parser , _))
319+ else
320+ val tree3 = ctx.typer.typed(tree2)
321+ ctx.base.postTyperPhase match
322+ case postTyper : PostTyper if ctx.reporter.allErrors.isEmpty =>
323+ val tree4 = atPhase(postTyper) { postTyper.newTransformer.transform(tree3) }
324+ ctx.base.inliningPhase match
325+ case inlining : Inlining if ctx.reporter.allErrors.isEmpty =>
326+ atPhase(inlining) { inlining.newTransformer.transform(tree4) }
327+ case _ =>
328+ case _ =>
329+ ctx.reporter.allErrors.map((ErrorKind .Typer , _))
330+ }
325331 case t =>
326332 report.error(em " argument to compileError must be a statically known String but was: $codeArg" , codeArg1.srcPos)
327333 Nil
@@ -346,12 +352,12 @@ object Inliner {
346352
347353 /** Expand call to scala.compiletime.testing.typeChecks */
348354 def typeChecks (tree : Tree )(using Context ): Tree =
349- val errors = compileForErrors(tree, true )
355+ val errors = compileForErrors(tree)
350356 Literal (Constant (errors.isEmpty)).withSpan(tree.span)
351357
352358 /** Expand call to scala.compiletime.testing.typeCheckErrors */
353359 def typeCheckErrors (tree : Tree )(using Context ): Tree =
354- val errors = compileForErrors(tree, false )
360+ val errors = compileForErrors(tree)
355361 packErrors(errors)
356362
357363 /** Expand call to scala.compiletime.codeOf */
0 commit comments