@@ -3313,20 +3313,298 @@ trait Reflection { reflection =>
33133313 // UTILS //
33143314 // /////////////
33153315
3316- /** TASTy Reflect tree accumulator */
3317- trait TreeAccumulator [X ] extends reflect.TreeAccumulator [X ] {
3318- val reflect : reflection.type = reflection
3319- }
3316+ /** TASTy Reflect tree accumulator.
3317+ *
3318+ * Usage:
3319+ * ```
3320+ * class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3321+ * extends scala.tasty.reflect.TreeAccumulator[X] {
3322+ * import reflect._
3323+ * def foldTree(x: X, tree: Tree)(using ctx: Context): X = ...
3324+ * }
3325+ * ```
3326+ */
3327+ trait TreeAccumulator [X ]:
3328+
3329+ // Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
3330+ def foldTree (x : X , tree : Tree )(using ctx : Context ): X
3331+
3332+ def foldTrees (x : X , trees : Iterable [Tree ])(using ctx : Context ): X = trees.foldLeft(x)(foldTree)
3333+
3334+ def foldOverTree (x : X , tree : Tree )(using ctx : Context ): X = {
3335+ def localCtx (definition : Definition ): Context = definition.symbol.localContext
3336+ tree match {
3337+ case Ident (_) =>
3338+ x
3339+ case Select (qualifier, _) =>
3340+ foldTree(x, qualifier)
3341+ case This (qual) =>
3342+ x
3343+ case Super (qual, _) =>
3344+ foldTree(x, qual)
3345+ case Apply (fun, args) =>
3346+ foldTrees(foldTree(x, fun), args)
3347+ case TypeApply (fun, args) =>
3348+ foldTrees(foldTree(x, fun), args)
3349+ case Literal (const) =>
3350+ x
3351+ case New (tpt) =>
3352+ foldTree(x, tpt)
3353+ case Typed (expr, tpt) =>
3354+ foldTree(foldTree(x, expr), tpt)
3355+ case NamedArg (_, arg) =>
3356+ foldTree(x, arg)
3357+ case Assign (lhs, rhs) =>
3358+ foldTree(foldTree(x, lhs), rhs)
3359+ case Block (stats, expr) =>
3360+ foldTree(foldTrees(x, stats), expr)
3361+ case If (cond, thenp, elsep) =>
3362+ foldTree(foldTree(foldTree(x, cond), thenp), elsep)
3363+ case While (cond, body) =>
3364+ foldTree(foldTree(x, cond), body)
3365+ case Closure (meth, tpt) =>
3366+ foldTree(x, meth)
3367+ case Match (selector, cases) =>
3368+ foldTrees(foldTree(x, selector), cases)
3369+ case Return (expr, _) =>
3370+ foldTree(x, expr)
3371+ case Try (block, handler, finalizer) =>
3372+ foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
3373+ case Repeated (elems, elemtpt) =>
3374+ foldTrees(foldTree(x, elemtpt), elems)
3375+ case Inlined (call, bindings, expansion) =>
3376+ foldTree(foldTrees(x, bindings), expansion)
3377+ case vdef @ ValDef (_, tpt, rhs) =>
3378+ val ctx = localCtx(vdef)
3379+ given Context = ctx
3380+ foldTrees(foldTree(x, tpt), rhs)
3381+ case ddef @ DefDef (_, tparams, vparamss, tpt, rhs) =>
3382+ val ctx = localCtx(ddef)
3383+ given Context = ctx
3384+ foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
3385+ case tdef @ TypeDef (_, rhs) =>
3386+ val ctx = localCtx(tdef)
3387+ given Context = ctx
3388+ foldTree(x, rhs)
3389+ case cdef @ ClassDef (_, constr, parents, derived, self, body) =>
3390+ val ctx = localCtx(cdef)
3391+ given Context = ctx
3392+ foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
3393+ case Import (expr, _) =>
3394+ foldTree(x, expr)
3395+ case clause @ PackageClause (pid, stats) =>
3396+ foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext)
3397+ case Inferred () => x
3398+ case TypeIdent (_) => x
3399+ case TypeSelect (qualifier, _) => foldTree(x, qualifier)
3400+ case Projection (qualifier, _) => foldTree(x, qualifier)
3401+ case Singleton (ref) => foldTree(x, ref)
3402+ case Refined (tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
3403+ case Applied (tpt, args) => foldTrees(foldTree(x, tpt), args)
3404+ case ByName (result) => foldTree(x, result)
3405+ case Annotated (arg, annot) => foldTree(foldTree(x, arg), annot)
3406+ case LambdaTypeTree (typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
3407+ case TypeBind (_, tbt) => foldTree(x, tbt)
3408+ case TypeBlock (typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
3409+ case MatchTypeTree (boundopt, selector, cases) =>
3410+ foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
3411+ case WildcardTypeTree () => x
3412+ case TypeBoundsTree (lo, hi) => foldTree(foldTree(x, lo), hi)
3413+ case CaseDef (pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
3414+ case TypeCaseDef (pat, body) => foldTree(foldTree(x, pat), body)
3415+ case Bind (_, body) => foldTree(x, body)
3416+ case Unapply (fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
3417+ case Alternatives (patterns) => foldTrees(x, patterns)
3418+ }
3419+ }
3420+ end TreeAccumulator
33203421
3321- /** TASTy Reflect tree traverser */
3322- trait TreeTraverser extends reflect.TreeTraverser {
3323- val reflect : reflection.type = reflection
3324- }
33253422
3326- /** TASTy Reflect tree map */
3327- trait TreeMap extends reflect.TreeMap {
3328- val reflect : reflection.type = reflection
3329- }
3423+ /** TASTy Reflect tree traverser.
3424+ *
3425+ * Usage:
3426+ * ```
3427+ * class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)
3428+ * extends scala.tasty.reflect.TreeTraverser {
3429+ * import reflect._
3430+ * override def traverseTree(tree: Tree)(using ctx: Context): Unit = ...
3431+ * }
3432+ * ```
3433+ */
3434+ trait TreeTraverser extends TreeAccumulator [Unit ]:
3435+
3436+ def traverseTree (tree : Tree )(using ctx : Context ): Unit = traverseTreeChildren(tree)
3437+
3438+ def foldTree (x : Unit , tree : Tree )(using ctx : Context ): Unit = traverseTree(tree)
3439+
3440+ protected def traverseTreeChildren (tree : Tree )(using ctx : Context ): Unit = foldOverTree((), tree)
3441+
3442+ end TreeTraverser
3443+
3444+ /** TASTy Reflect tree map.
3445+ *
3446+ * Usage:
3447+ * ```
3448+ * import qctx.reflect._
3449+ * class MyTreeMap extends TreeMap {
3450+ * override def transformTree(tree: Tree)(using ctx: Context): Tree = ...
3451+ * }
3452+ * ```
3453+ */
3454+ trait TreeMap :
3455+
3456+ def transformTree (tree : Tree )(using ctx : Context ): Tree = {
3457+ tree match {
3458+ case tree : PackageClause =>
3459+ PackageClause .copy(tree)(transformTerm(tree.pid).asInstanceOf [Ref ], transformTrees(tree.stats)(using tree.symbol.localContext))
3460+ case tree : Import =>
3461+ Import .copy(tree)(transformTerm(tree.expr), tree.selectors)
3462+ case tree : Statement =>
3463+ transformStatement(tree)
3464+ case tree : TypeTree => transformTypeTree(tree)
3465+ case tree : TypeBoundsTree => tree // TODO traverse tree
3466+ case tree : WildcardTypeTree => tree // TODO traverse tree
3467+ case tree : CaseDef =>
3468+ transformCaseDef(tree)
3469+ case tree : TypeCaseDef =>
3470+ transformTypeCaseDef(tree)
3471+ case pattern : Bind =>
3472+ Bind .copy(pattern)(pattern.name, pattern.pattern)
3473+ case pattern : Unapply =>
3474+ Unapply .copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
3475+ case pattern : Alternatives =>
3476+ Alternatives .copy(pattern)(transformTrees(pattern.patterns))
3477+ }
3478+ }
3479+
3480+ def transformStatement (tree : Statement )(using ctx : Context ): Statement = {
3481+ def localCtx (definition : Definition ): Context = definition.symbol.localContext
3482+ tree match {
3483+ case tree : Term =>
3484+ transformTerm(tree)
3485+ case tree : ValDef =>
3486+ val ctx = localCtx(tree)
3487+ given Context = ctx
3488+ val tpt1 = transformTypeTree(tree.tpt)
3489+ val rhs1 = tree.rhs.map(x => transformTerm(x))
3490+ ValDef .copy(tree)(tree.name, tpt1, rhs1)
3491+ case tree : DefDef =>
3492+ val ctx = localCtx(tree)
3493+ given Context = ctx
3494+ DefDef .copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
3495+ case tree : TypeDef =>
3496+ val ctx = localCtx(tree)
3497+ given Context = ctx
3498+ TypeDef .copy(tree)(tree.name, transformTree(tree.rhs))
3499+ case tree : ClassDef =>
3500+ ClassDef .copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
3501+ case tree : Import =>
3502+ Import .copy(tree)(transformTerm(tree.expr), tree.selectors)
3503+ }
3504+ }
3505+
3506+ def transformTerm (tree : Term )(using ctx : Context ): Term = {
3507+ tree match {
3508+ case Ident (name) =>
3509+ tree
3510+ case Select (qualifier, name) =>
3511+ Select .copy(tree)(transformTerm(qualifier), name)
3512+ case This (qual) =>
3513+ tree
3514+ case Super (qual, mix) =>
3515+ Super .copy(tree)(transformTerm(qual), mix)
3516+ case Apply (fun, args) =>
3517+ Apply .copy(tree)(transformTerm(fun), transformTerms(args))
3518+ case TypeApply (fun, args) =>
3519+ TypeApply .copy(tree)(transformTerm(fun), transformTypeTrees(args))
3520+ case Literal (const) =>
3521+ tree
3522+ case New (tpt) =>
3523+ New .copy(tree)(transformTypeTree(tpt))
3524+ case Typed (expr, tpt) =>
3525+ Typed .copy(tree)(transformTerm(expr), transformTypeTree(tpt))
3526+ case tree : NamedArg =>
3527+ NamedArg .copy(tree)(tree.name, transformTerm(tree.value))
3528+ case Assign (lhs, rhs) =>
3529+ Assign .copy(tree)(transformTerm(lhs), transformTerm(rhs))
3530+ case Block (stats, expr) =>
3531+ Block .copy(tree)(transformStats(stats), transformTerm(expr))
3532+ case If (cond, thenp, elsep) =>
3533+ If .copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
3534+ case Closure (meth, tpt) =>
3535+ Closure .copy(tree)(transformTerm(meth), tpt)
3536+ case Match (selector, cases) =>
3537+ Match .copy(tree)(transformTerm(selector), transformCaseDefs(cases))
3538+ case Return (expr, from) =>
3539+ Return .copy(tree)(transformTerm(expr), from)
3540+ case While (cond, body) =>
3541+ While .copy(tree)(transformTerm(cond), transformTerm(body))
3542+ case Try (block, cases, finalizer) =>
3543+ Try .copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
3544+ case Repeated (elems, elemtpt) =>
3545+ Repeated .copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
3546+ case Inlined (call, bindings, expansion) =>
3547+ Inlined .copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/* ()call.symbol.localContext)*/ )
3548+ }
3549+ }
3550+
3551+ def transformTypeTree (tree : TypeTree )(using ctx : Context ): TypeTree = tree match {
3552+ case Inferred () => tree
3553+ case tree : TypeIdent => tree
3554+ case tree : TypeSelect =>
3555+ TypeSelect .copy(tree)(tree.qualifier, tree.name)
3556+ case tree : Projection =>
3557+ Projection .copy(tree)(tree.qualifier, tree.name)
3558+ case tree : Annotated =>
3559+ Annotated .copy(tree)(tree.arg, tree.annotation)
3560+ case tree : Singleton =>
3561+ Singleton .copy(tree)(transformTerm(tree.ref))
3562+ case tree : Refined =>
3563+ Refined .copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf [List [Definition ]])
3564+ case tree : Applied =>
3565+ Applied .copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
3566+ case tree : MatchTypeTree =>
3567+ MatchTypeTree .copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
3568+ case tree : ByName =>
3569+ ByName .copy(tree)(transformTypeTree(tree.result))
3570+ case tree : LambdaTypeTree =>
3571+ LambdaTypeTree .copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))
3572+ case tree : TypeBind =>
3573+ TypeBind .copy(tree)(tree.name, tree.body)
3574+ case tree : TypeBlock =>
3575+ TypeBlock .copy(tree)(tree.aliases, tree.tpt)
3576+ }
3577+
3578+ def transformCaseDef (tree : CaseDef )(using ctx : Context ): CaseDef = {
3579+ CaseDef .copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
3580+ }
3581+
3582+ def transformTypeCaseDef (tree : TypeCaseDef )(using ctx : Context ): TypeCaseDef = {
3583+ TypeCaseDef .copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
3584+ }
3585+
3586+ def transformStats (trees : List [Statement ])(using ctx : Context ): List [Statement ] =
3587+ trees mapConserve (transformStatement(_))
3588+
3589+ def transformTrees (trees : List [Tree ])(using ctx : Context ): List [Tree ] =
3590+ trees mapConserve (transformTree(_))
3591+
3592+ def transformTerms (trees : List [Term ])(using ctx : Context ): List [Term ] =
3593+ trees mapConserve (transformTerm(_))
3594+
3595+ def transformTypeTrees (trees : List [TypeTree ])(using ctx : Context ): List [TypeTree ] =
3596+ trees mapConserve (transformTypeTree(_))
3597+
3598+ def transformCaseDefs (trees : List [CaseDef ])(using ctx : Context ): List [CaseDef ] =
3599+ trees mapConserve (transformCaseDef(_))
3600+
3601+ def transformTypeCaseDefs (trees : List [TypeCaseDef ])(using ctx : Context ): List [TypeCaseDef ] =
3602+ trees mapConserve (transformTypeCaseDef(_))
3603+
3604+ def transformSubTrees [Tr <: Tree ](trees : List [Tr ])(using ctx : Context ): List [Tr ] =
3605+ transformTrees(trees).asInstanceOf [List [Tr ]]
3606+
3607+ end TreeMap
33303608
33313609 // TODO: extract from Reflection
33323610
0 commit comments