@@ -3312,20 +3312,298 @@ trait Reflection { reflection =>
33123312 // UTILS //
33133313 // /////////////
33143314
3315- /** TASTy Reflect tree accumulator */
3316- trait TreeAccumulator [X ] extends reflect.TreeAccumulator [X ] {
3317- val reflect : reflection.type = reflection
3318- }
3315+ /** TASTy Reflect tree accumulator.
3316+ *
3317+ * Usage:
3318+ * ```
3319+ * class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3320+ * extends scala.tasty.reflect.TreeAccumulator[X] {
3321+ * import reflect._
3322+ * def foldTree(x: X, tree: Tree)(using ctx: Context): X = ...
3323+ * }
3324+ * ```
3325+ */
3326+ trait TreeAccumulator [X ]:
3327+
3328+ // Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
3329+ def foldTree (x : X , tree : Tree )(using ctx : Context ): X
3330+
3331+ def foldTrees (x : X , trees : Iterable [Tree ])(using ctx : Context ): X = trees.foldLeft(x)(foldTree)
3332+
3333+ def foldOverTree (x : X , tree : Tree )(using ctx : Context ): X = {
3334+ def localCtx (definition : Definition ): Context = definition.symbol.localContext
3335+ tree match {
3336+ case Ident (_) =>
3337+ x
3338+ case Select (qualifier, _) =>
3339+ foldTree(x, qualifier)
3340+ case This (qual) =>
3341+ x
3342+ case Super (qual, _) =>
3343+ foldTree(x, qual)
3344+ case Apply (fun, args) =>
3345+ foldTrees(foldTree(x, fun), args)
3346+ case TypeApply (fun, args) =>
3347+ foldTrees(foldTree(x, fun), args)
3348+ case Literal (const) =>
3349+ x
3350+ case New (tpt) =>
3351+ foldTree(x, tpt)
3352+ case Typed (expr, tpt) =>
3353+ foldTree(foldTree(x, expr), tpt)
3354+ case NamedArg (_, arg) =>
3355+ foldTree(x, arg)
3356+ case Assign (lhs, rhs) =>
3357+ foldTree(foldTree(x, lhs), rhs)
3358+ case Block (stats, expr) =>
3359+ foldTree(foldTrees(x, stats), expr)
3360+ case If (cond, thenp, elsep) =>
3361+ foldTree(foldTree(foldTree(x, cond), thenp), elsep)
3362+ case While (cond, body) =>
3363+ foldTree(foldTree(x, cond), body)
3364+ case Closure (meth, tpt) =>
3365+ foldTree(x, meth)
3366+ case Match (selector, cases) =>
3367+ foldTrees(foldTree(x, selector), cases)
3368+ case Return (expr, _) =>
3369+ foldTree(x, expr)
3370+ case Try (block, handler, finalizer) =>
3371+ foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
3372+ case Repeated (elems, elemtpt) =>
3373+ foldTrees(foldTree(x, elemtpt), elems)
3374+ case Inlined (call, bindings, expansion) =>
3375+ foldTree(foldTrees(x, bindings), expansion)
3376+ case vdef @ ValDef (_, tpt, rhs) =>
3377+ val ctx = localCtx(vdef)
3378+ given Context = ctx
3379+ foldTrees(foldTree(x, tpt), rhs)
3380+ case ddef @ DefDef (_, tparams, vparamss, tpt, rhs) =>
3381+ val ctx = localCtx(ddef)
3382+ given Context = ctx
3383+ foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
3384+ case tdef @ TypeDef (_, rhs) =>
3385+ val ctx = localCtx(tdef)
3386+ given Context = ctx
3387+ foldTree(x, rhs)
3388+ case cdef @ ClassDef (_, constr, parents, derived, self, body) =>
3389+ val ctx = localCtx(cdef)
3390+ given Context = ctx
3391+ foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
3392+ case Import (expr, _) =>
3393+ foldTree(x, expr)
3394+ case clause @ PackageClause (pid, stats) =>
3395+ foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext)
3396+ case Inferred () => x
3397+ case TypeIdent (_) => x
3398+ case TypeSelect (qualifier, _) => foldTree(x, qualifier)
3399+ case Projection (qualifier, _) => foldTree(x, qualifier)
3400+ case Singleton (ref) => foldTree(x, ref)
3401+ case Refined (tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
3402+ case Applied (tpt, args) => foldTrees(foldTree(x, tpt), args)
3403+ case ByName (result) => foldTree(x, result)
3404+ case Annotated (arg, annot) => foldTree(foldTree(x, arg), annot)
3405+ case LambdaTypeTree (typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
3406+ case TypeBind (_, tbt) => foldTree(x, tbt)
3407+ case TypeBlock (typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
3408+ case MatchTypeTree (boundopt, selector, cases) =>
3409+ foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
3410+ case WildcardTypeTree () => x
3411+ case TypeBoundsTree (lo, hi) => foldTree(foldTree(x, lo), hi)
3412+ case CaseDef (pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
3413+ case TypeCaseDef (pat, body) => foldTree(foldTree(x, pat), body)
3414+ case Bind (_, body) => foldTree(x, body)
3415+ case Unapply (fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
3416+ case Alternatives (patterns) => foldTrees(x, patterns)
3417+ }
3418+ }
3419+ end TreeAccumulator
33193420
3320- /** TASTy Reflect tree traverser */
3321- trait TreeTraverser extends reflect.TreeTraverser {
3322- val reflect : reflection.type = reflection
3323- }
33243421
3325- /** TASTy Reflect tree map */
3326- trait TreeMap extends reflect.TreeMap {
3327- val reflect : reflection.type = reflection
3328- }
3422+ /** TASTy Reflect tree traverser.
3423+ *
3424+ * Usage:
3425+ * ```
3426+ * class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3427+ * extends scala.tasty.reflect.TreeTraverser {
3428+ * import reflect._
3429+ * override def traverseTree(tree: Tree)(using ctx: Context): Unit = ...
3430+ * }
3431+ * ```
3432+ */
3433+ trait TreeTraverser extends TreeAccumulator [Unit ]:
3434+
3435+ def traverseTree (tree : Tree )(using ctx : Context ): Unit = traverseTreeChildren(tree)
3436+
3437+ def foldTree (x : Unit , tree : Tree )(using ctx : Context ): Unit = traverseTree(tree)
3438+
3439+ protected def traverseTreeChildren (tree : Tree )(using ctx : Context ): Unit = foldOverTree((), tree)
3440+
3441+ end TreeTraverser
3442+
3443+ /** TASTy Reflect tree map.
3444+ *
3445+ * Usage:
3446+ * ```
3447+ * import qctx.reflect._
3448+ * class MyTreeMap extends TreeMap {
3449+ * override def transformTree(tree: Tree)(using ctx: Context): Tree = ...
3450+ * }
3451+ * ```
3452+ */
3453+ trait TreeMap :
3454+
3455+ def transformTree (tree : Tree )(using ctx : Context ): Tree = {
3456+ tree match {
3457+ case tree : PackageClause =>
3458+ PackageClause .copy(tree)(transformTerm(tree.pid).asInstanceOf [Ref ], transformTrees(tree.stats)(using tree.symbol.localContext))
3459+ case tree : Import =>
3460+ Import .copy(tree)(transformTerm(tree.expr), tree.selectors)
3461+ case tree : Statement =>
3462+ transformStatement(tree)
3463+ case tree : TypeTree => transformTypeTree(tree)
3464+ case tree : TypeBoundsTree => tree // TODO traverse tree
3465+ case tree : WildcardTypeTree => tree // TODO traverse tree
3466+ case tree : CaseDef =>
3467+ transformCaseDef(tree)
3468+ case tree : TypeCaseDef =>
3469+ transformTypeCaseDef(tree)
3470+ case pattern : Bind =>
3471+ Bind .copy(pattern)(pattern.name, pattern.pattern)
3472+ case pattern : Unapply =>
3473+ Unapply .copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
3474+ case pattern : Alternatives =>
3475+ Alternatives .copy(pattern)(transformTrees(pattern.patterns))
3476+ }
3477+ }
3478+
3479+ def transformStatement (tree : Statement )(using ctx : Context ): Statement = {
3480+ def localCtx (definition : Definition ): Context = definition.symbol.localContext
3481+ tree match {
3482+ case tree : Term =>
3483+ transformTerm(tree)
3484+ case tree : ValDef =>
3485+ val ctx = localCtx(tree)
3486+ given Context = ctx
3487+ val tpt1 = transformTypeTree(tree.tpt)
3488+ val rhs1 = tree.rhs.map(x => transformTerm(x))
3489+ ValDef .copy(tree)(tree.name, tpt1, rhs1)
3490+ case tree : DefDef =>
3491+ val ctx = localCtx(tree)
3492+ given Context = ctx
3493+ DefDef .copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
3494+ case tree : TypeDef =>
3495+ val ctx = localCtx(tree)
3496+ given Context = ctx
3497+ TypeDef .copy(tree)(tree.name, transformTree(tree.rhs))
3498+ case tree : ClassDef =>
3499+ ClassDef .copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
3500+ case tree : Import =>
3501+ Import .copy(tree)(transformTerm(tree.expr), tree.selectors)
3502+ }
3503+ }
3504+
3505+ def transformTerm (tree : Term )(using ctx : Context ): Term = {
3506+ tree match {
3507+ case Ident (name) =>
3508+ tree
3509+ case Select (qualifier, name) =>
3510+ Select .copy(tree)(transformTerm(qualifier), name)
3511+ case This (qual) =>
3512+ tree
3513+ case Super (qual, mix) =>
3514+ Super .copy(tree)(transformTerm(qual), mix)
3515+ case Apply (fun, args) =>
3516+ Apply .copy(tree)(transformTerm(fun), transformTerms(args))
3517+ case TypeApply (fun, args) =>
3518+ TypeApply .copy(tree)(transformTerm(fun), transformTypeTrees(args))
3519+ case Literal (const) =>
3520+ tree
3521+ case New (tpt) =>
3522+ New .copy(tree)(transformTypeTree(tpt))
3523+ case Typed (expr, tpt) =>
3524+ Typed .copy(tree)(transformTerm(expr), transformTypeTree(tpt))
3525+ case tree : NamedArg =>
3526+ NamedArg .copy(tree)(tree.name, transformTerm(tree.value))
3527+ case Assign (lhs, rhs) =>
3528+ Assign .copy(tree)(transformTerm(lhs), transformTerm(rhs))
3529+ case Block (stats, expr) =>
3530+ Block .copy(tree)(transformStats(stats), transformTerm(expr))
3531+ case If (cond, thenp, elsep) =>
3532+ If .copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
3533+ case Closure (meth, tpt) =>
3534+ Closure .copy(tree)(transformTerm(meth), tpt)
3535+ case Match (selector, cases) =>
3536+ Match .copy(tree)(transformTerm(selector), transformCaseDefs(cases))
3537+ case Return (expr, from) =>
3538+ Return .copy(tree)(transformTerm(expr), from)
3539+ case While (cond, body) =>
3540+ While .copy(tree)(transformTerm(cond), transformTerm(body))
3541+ case Try (block, cases, finalizer) =>
3542+ Try .copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
3543+ case Repeated (elems, elemtpt) =>
3544+ Repeated .copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
3545+ case Inlined (call, bindings, expansion) =>
3546+ Inlined .copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/* ()call.symbol.localContext)*/ )
3547+ }
3548+ }
3549+
3550+ def transformTypeTree (tree : TypeTree )(using ctx : Context ): TypeTree = tree match {
3551+ case Inferred () => tree
3552+ case tree : TypeIdent => tree
3553+ case tree : TypeSelect =>
3554+ TypeSelect .copy(tree)(tree.qualifier, tree.name)
3555+ case tree : Projection =>
3556+ Projection .copy(tree)(tree.qualifier, tree.name)
3557+ case tree : Annotated =>
3558+ Annotated .copy(tree)(tree.arg, tree.annotation)
3559+ case tree : Singleton =>
3560+ Singleton .copy(tree)(transformTerm(tree.ref))
3561+ case tree : Refined =>
3562+ Refined .copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf [List [Definition ]])
3563+ case tree : Applied =>
3564+ Applied .copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
3565+ case tree : MatchTypeTree =>
3566+ MatchTypeTree .copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
3567+ case tree : ByName =>
3568+ ByName .copy(tree)(transformTypeTree(tree.result))
3569+ case tree : LambdaTypeTree =>
3570+ LambdaTypeTree .copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))
3571+ case tree : TypeBind =>
3572+ TypeBind .copy(tree)(tree.name, tree.body)
3573+ case tree : TypeBlock =>
3574+ TypeBlock .copy(tree)(tree.aliases, tree.tpt)
3575+ }
3576+
3577+ def transformCaseDef (tree : CaseDef )(using ctx : Context ): CaseDef = {
3578+ CaseDef .copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
3579+ }
3580+
3581+ def transformTypeCaseDef (tree : TypeCaseDef )(using ctx : Context ): TypeCaseDef = {
3582+ TypeCaseDef .copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
3583+ }
3584+
3585+ def transformStats (trees : List [Statement ])(using ctx : Context ): List [Statement ] =
3586+ trees mapConserve (transformStatement(_))
3587+
3588+ def transformTrees (trees : List [Tree ])(using ctx : Context ): List [Tree ] =
3589+ trees mapConserve (transformTree(_))
3590+
3591+ def transformTerms (trees : List [Term ])(using ctx : Context ): List [Term ] =
3592+ trees mapConserve (transformTerm(_))
3593+
3594+ def transformTypeTrees (trees : List [TypeTree ])(using ctx : Context ): List [TypeTree ] =
3595+ trees mapConserve (transformTypeTree(_))
3596+
3597+ def transformCaseDefs (trees : List [CaseDef ])(using ctx : Context ): List [CaseDef ] =
3598+ trees mapConserve (transformCaseDef(_))
3599+
3600+ def transformTypeCaseDefs (trees : List [TypeCaseDef ])(using ctx : Context ): List [TypeCaseDef ] =
3601+ trees mapConserve (transformTypeCaseDef(_))
3602+
3603+ def transformSubTrees [Tr <: Tree ](trees : List [Tr ])(using ctx : Context ): List [Tr ] =
3604+ transformTrees(trees).asInstanceOf [List [Tr ]]
3605+
3606+ end TreeMap
33293607
33303608 // TODO: extract from Reflection
33313609
0 commit comments