Skip to content

Commit 300e81b

Browse files
committed
tree: extend Visitor to also visit TableExpr
I'm not quite sure the reason, but tree.Visitor does not visit TableExpr during a walk of an AST. For hint injection, we need to visit TableExpr as well as Expr. This commit extends tree.Visitor with new interface tree.ExtendedVisitor which visits both Expr and TableExpr during walks. Informs: #153633 Release note: None
1 parent 7a39b97 commit 300e81b

File tree

1 file changed

+151
-27
lines changed

1 file changed

+151
-27
lines changed

pkg/sql/sem/tree/walk.go

Lines changed: 151 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,60 @@ import (
1414
"github.com/cockroachdb/errors"
1515
)
1616

17-
// Visitor defines methods that are called for nodes during an expression or statement walk.
17+
// Visitor defines methods that are called for Expr nodes during an expression
18+
// or statement walk.
1819
type Visitor interface {
19-
// VisitPre is called for each node before recursing into that subtree. Upon return, if recurse
20-
// is false, the visit will not recurse into the subtree (and VisitPost will not be called for
21-
// this node).
20+
// VisitPre is called for each Expr node before recursing into that
21+
// subtree. Upon return, if recurse is false, the visit will not recurse into
22+
// the subtree (and VisitPost will not be called for this Expr node).
2223
//
23-
// The returned Expr replaces the visited expression and can be used for rewriting expressions.
24-
// The function should NOT modify nodes in-place; it should make copies of nodes. The Walk
25-
// infrastructure will automatically make copies of parents as needed.
24+
// The returned Expr replaces the visited expression and can be used for
25+
// rewriting expressions. The function should NOT modify nodes in-place; it
26+
// should make copies of nodes. The Walk infrastructure will automatically
27+
// make copies of parents as needed.
28+
//
29+
// VisitPre visits Exprs but not TableExprs. For TableExprs, VisitTablePre
30+
// must be used.
2631
VisitPre(expr Expr) (recurse bool, newExpr Expr)
2732

28-
// VisitPost is called for each node after recursing into the subtree. The returned Expr
29-
// replaces the visited expression and can be used for rewriting expressions.
33+
// VisitPost is called for each Expr node after recursing into the
34+
// subtree. The returned Expr replaces the visited expression and can be used
35+
// for rewriting expressions.
36+
//
37+
// The returned Expr replaces the visited expression and can be used for
38+
// rewriting expressions. The function should NOT modify nodes in-place; it
39+
// should make and return copies of nodes. The Walk infrastructure will
40+
// automatically make copies of parents as needed.
3041
//
31-
// The returned Expr replaces the visited expression and can be used for rewriting expressions.
32-
// The function should NOT modify nodes in-place; it should make and return copies of nodes. The
33-
// Walk infrastructure will automatically make copies of parents as needed.
42+
// VisitPost visits Exprs but not TableExprs. For TableExprs, VisitTablePost
43+
// must be used.
3444
VisitPost(expr Expr) (newNode Expr)
3545
}
3646

47+
// ExtendedVisitor extends Visitor with methods that are called for TableExpr
48+
// nodes during an expression or statement walk.
49+
//
50+
// Unlike Visitor, which does not visit some parts of the AST for historical
51+
// reasons, ExtendedVisitor is intended to visit every part of the tree.
52+
type ExtendedVisitor interface {
53+
Visitor
54+
55+
// VisitTablePre is called for each TableExpr node before recursing into that
56+
// subtree. Upon return, if recurse if false, the visit will not recurse into
57+
// the subtree (and VisitTablePost will node be called for this TableExpr
58+
// node).
59+
//
60+
// VisitTablePre is identical to VisitPre but handles TableExpr nodes.
61+
VisitTablePre(expr TableExpr) (recurse bool, newExpr TableExpr)
62+
63+
// VisitTablePost is called for each TableExpr node after recursing into the
64+
// subtree. The returned TableExpr replaces the visited expression and can be
65+
// used for rewriting expressions.
66+
//
67+
// VisitTablePost is identical to VisitPost but handles TableExpr nodes.
68+
VisitTablePost(expr TableExpr) (newNode TableExpr)
69+
}
70+
3771
// Walk implements the Expr interface.
3872
func (expr *AndExpr) Walk(v Visitor) Expr {
3973
left, changedL := WalkExpr(v, expr.Left)
@@ -311,7 +345,17 @@ func (expr *FuncExpr) Walk(v Visitor) Expr {
311345
ret.Filter = e
312346
}
313347
}
314-
348+
if _, ok := v.(ExtendedVisitor); ok {
349+
if expr.WindowDef != nil {
350+
w, changed := walkWindowDef(v, expr.WindowDef)
351+
if changed {
352+
if ret == expr {
353+
ret = expr.copyNode()
354+
}
355+
ret.WindowDef = w
356+
}
357+
}
358+
}
315359
if expr.OrderBy != nil {
316360
order, changed := walkOrderBy(v, expr.OrderBy)
317361
if changed {
@@ -553,10 +597,20 @@ func (expr *ParenTableExpr) WalkTableExpr(v Visitor) TableExpr {
553597
func (expr *JoinTableExpr) WalkTableExpr(v Visitor) TableExpr {
554598
left, changedL := walkTableExpr(v, expr.Left)
555599
right, changedR := walkTableExpr(v, expr.Right)
556-
if changedL || changedR {
600+
cond, changedCond := expr.Cond, false
601+
if _, ok := v.(ExtendedVisitor); ok {
602+
if j, ok := expr.Cond.(*OnJoinCond); ok {
603+
if onExpr, changed := WalkExpr(v, j.Expr); changed {
604+
cond = &OnJoinCond{onExpr}
605+
changedCond = true
606+
}
607+
}
608+
}
609+
if changedL || changedR || changedCond {
557610
exprCopy := *expr
558611
exprCopy.Left = left
559612
exprCopy.Right = right
613+
exprCopy.Cond = cond
560614
return &exprCopy
561615
}
562616
return expr
@@ -846,8 +900,16 @@ func WalkExpr(v Visitor, expr Expr) (newExpr Expr, changed bool) {
846900
return newExpr, (reflect.ValueOf(expr) != reflect.ValueOf(newExpr))
847901
}
848902

849-
func walkTableExpr(v Visitor, expr TableExpr) (newExpr TableExpr, changed bool) {
850-
newExpr = expr.WalkTableExpr(v)
903+
func walkTableExpr(v Visitor, expr TableExpr) (TableExpr, bool) {
904+
if ev, ok := v.(ExtendedVisitor); ok {
905+
recurse, newExpr := ev.VisitTablePre(expr)
906+
if recurse {
907+
newExpr = newExpr.WalkTableExpr(v)
908+
newExpr = ev.VisitTablePost(newExpr)
909+
}
910+
return newExpr, (reflect.ValueOf(expr) != reflect.ValueOf(newExpr))
911+
}
912+
newExpr := expr.WalkTableExpr(v)
851913
return newExpr, (reflect.ValueOf(expr) != reflect.ValueOf(newExpr))
852914
}
853915

@@ -2082,9 +2144,9 @@ var _ walkableStmt = &ValuesClause{}
20822144
// by WalkExpr.
20832145
//
20842146
// NOTE: Beware that WalkStmt does not necessarily traverse all parts of a
2085-
// statement by itself. For example, it will not walk into Subquery nodes
2086-
// within a FROM clause or into a JoinCond. Walk's logic is pretty
2087-
// interdependent with the logic for constructing a query plan.
2147+
// statement by itself. For example, it will not walk into Subquery nodes within
2148+
// a FROM clause or into a JoinCond (unless using an ExtendedVisitor). Walk's
2149+
// logic is pretty interdependent with the logic for constructing a query plan.
20882150
func WalkStmt(v Visitor, stmt Statement) (newStmt Statement, changed bool) {
20892151
walkable, ok := stmt.(walkableStmt)
20902152
if !ok {
@@ -2114,14 +2176,14 @@ func (v *simpleVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) {
21142176

21152177
func (*simpleVisitor) VisitPost(expr Expr) Expr { return expr }
21162178

2117-
// SimpleVisitFn is a function that is run for every node in the VisitPre stage;
2118-
// see SimpleVisit.
2179+
// SimpleVisitFn is a function that is run for every Expr node in the VisitPre
2180+
// stage; see SimpleVisit.
21192181
type SimpleVisitFn func(expr Expr) (recurse bool, newExpr Expr, err error)
21202182

21212183
// SimpleVisit is a convenience wrapper for visitors that only have VisitPre
21222184
// code and don't return any results except an error. The given function is
2123-
// called in VisitPre for every node. The visitor stops as soon as an error is
2124-
// returned.
2185+
// called in VisitPre for every Expr node. The visitor stops as soon as an error
2186+
// is returned.
21252187
func SimpleVisit(expr Expr, preFn SimpleVisitFn) (Expr, error) {
21262188
v := simpleVisitor{fn: preFn}
21272189
newExpr, _ := WalkExpr(&v, expr)
@@ -2131,10 +2193,10 @@ func SimpleVisit(expr Expr, preFn SimpleVisitFn) (Expr, error) {
21312193
return newExpr, nil
21322194
}
21332195

2134-
// SimpleStmtVisit is a convenience wrapper for visitors that want to visit
2135-
// all part of a statement, only have VisitPre code and don't return
2136-
// any results except an error. The given function is called in VisitPre
2137-
// for every node. The visitor stops as soon as an error is returned.
2196+
// SimpleStmtVisit is a convenience wrapper for visitors that want to visit all
2197+
// part of a statement, only have VisitPre code, and don't return any results
2198+
// except an error. The given function is called in VisitPre for every Expr
2199+
// node. The visitor stops as soon as an error is returned.
21382200
func SimpleStmtVisit(stmt Statement, preFn SimpleVisitFn) (Statement, error) {
21392201
v := simpleVisitor{fn: preFn}
21402202
newStmt, changed := WalkStmt(&v, stmt)
@@ -2147,6 +2209,68 @@ func SimpleStmtVisit(stmt Statement, preFn SimpleVisitFn) (Statement, error) {
21472209
return stmt, nil
21482210
}
21492211

2212+
type extendedSimpleVisitor struct {
2213+
simpleVisitor
2214+
efn ExtendedSimpleVisitFn
2215+
}
2216+
2217+
var _ ExtendedVisitor = &extendedSimpleVisitor{}
2218+
2219+
func (ev *extendedSimpleVisitor) VisitTablePre(expr TableExpr) (recurse bool, newExpr TableExpr) {
2220+
if ev.err != nil {
2221+
return false, expr
2222+
}
2223+
recurse, newExpr, ev.err = ev.efn(expr)
2224+
if ev.err != nil {
2225+
return false, expr
2226+
}
2227+
return recurse, newExpr
2228+
}
2229+
2230+
func (ev *extendedSimpleVisitor) VisitTablePost(expr TableExpr) (newNode TableExpr) { return expr }
2231+
2232+
// ExtendedSimpleVisitFn is a function that is run for every TableExpr node in
2233+
// the VisitTablePre stage; see ExtendedSimpleVisit.
2234+
type ExtendedSimpleVisitFn func(expr TableExpr) (recurse bool, newExpr TableExpr, err error)
2235+
2236+
// ExtendedSimpleVisit is a convenience wrapper for visitors that only have
2237+
// VisitPre and VisitTablePre code, and don't return any results except an
2238+
// error. The given functions are called in VisitPre for every Expr node and
2239+
// VisitTablePre for every TableExpr node, respectively. The visitor stops as
2240+
// soon as an error is returned.
2241+
//
2242+
// ExtendedSimpleVisit is identical to SimpleVisit but also handles TableExpr
2243+
// nodes.
2244+
func ExtendedSimpleVisit(
2245+
expr Expr, preFn SimpleVisitFn, preTableFn ExtendedSimpleVisitFn,
2246+
) (Expr, error) {
2247+
ev := extendedSimpleVisitor{simpleVisitor{fn: preFn}, preTableFn}
2248+
newExpr, _ := WalkExpr(&ev, expr)
2249+
if ev.err != nil {
2250+
return nil, ev.err
2251+
}
2252+
return newExpr, nil
2253+
}
2254+
2255+
// ExtendedSimpleStmtVisit is a convenience wrapper for visitors that want to
2256+
// visit all part of a statement, only have VisitPre and VisitTablePre code, and
2257+
// don't return any results except an error. The given functions are called in
2258+
// VisitPre for every Expr node and VisitTablePre for every TableExpr node,
2259+
// respectively. The visitor stops as soon as an error is returned.
2260+
func ExtendedSimpleStmtVisit(
2261+
stmt Statement, preFn SimpleVisitFn, preTableFn ExtendedSimpleVisitFn,
2262+
) (Statement, error) {
2263+
ev := extendedSimpleVisitor{simpleVisitor{fn: preFn}, preTableFn}
2264+
newStmt, changed := WalkStmt(&ev, stmt)
2265+
if ev.err != nil {
2266+
return nil, ev.err
2267+
}
2268+
if changed {
2269+
return newStmt, nil
2270+
}
2271+
return stmt, nil
2272+
}
2273+
21502274
type debugVisitor struct {
21512275
buf bytes.Buffer
21522276
level int

0 commit comments

Comments
 (0)