Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,78 +27,126 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

/**
* This rule tries to merge multiple non-correlated [[ScalarSubquery]]s to compute multiple scalar
* values once.
* This rule tries to merge multiple subplans that have one row result. This can be either the plan
* tree of a [[ScalarSubquery]] expression or the plan tree starting at a non-grouping [[Aggregate]]
* node.
*
* The process is the following:
* - While traversing through the plan each [[ScalarSubquery]] plan is tried to merge into already
* seen subquery plans using `PlanMerger`s.
* - While traversing through the plan each one row returning subplan is tried to merge into already
* seen one row returning subplans using `PlanMerger`s.
* During this first traversal each [[ScalarSubquery]] expression is replaced to a temporal
* [[ScalarSubqueryReference]] pointing to its possible merged version stored in `PlanMerger`s.
* `PlanMerger`s keep track of whether a plan is a result of merging 2 or more plans, or is an
* original unmerged plan. [[ScalarSubqueryReference]]s contain all the required information to
* either restore the original [[ScalarSubquery]] or create a reference to a merged CTE.
* - Once the first traversal is complete and all possible merging have been done a second traversal
* removes the [[ScalarSubqueryReference]]s to either restore the original [[ScalarSubquery]] or
* to replace the original to a modified one that references a CTE with a merged plan.
* [[ScalarSubqueryReference]] and each non-grouping [[Aggregate]] node is replaced to a temporal
* [[NonGroupingAggregateReference]] pointing to its possible merged version in `PlanMerger`s.
* `PlanMerger`s keep track of whether a plan is a result of merging 2 or more subplans, or is an
* original unmerged plan.
* [[ScalarSubqueryReference]]s and [[NonGroupingAggregateReference]]s contain all the required
* information to either restore the original subplan or create a reference to a merged CTE.
* - Once the first traversal is complete and all possible merging have been done, a second
* traversal removes the references to either restore the original subplans or to replace the
* original to a modified ones that reference a CTE with a merged plan.
* A modified [[ScalarSubquery]] is constructed like:
* `GetStructField(ScalarSubquery(CTERelationRef(...)), outputIndex)` where `outputIndex` is the
* index of the output attribute (of the CTE) that corresponds to the output of the original
* subquery.
* `GetStructField(ScalarSubquery(CTERelationRef to the merged plan), merged output index)`
* while a modified [[Aggregate]] is constructed like:
* ```
* Project(
* Seq(
* GetStructField(
* ScalarSubquery(CTERelationRef to the merged plan),
* merged output index 1),
* GetStructField(
* ScalarSubquery(CTERelationRef to the merged plan),
* merged output index 2),
* ...),
* OneRowRelation)
* ```
* where `merged output index`s are the index of the output attributes (of the CTE) that
* correspond to the output of the original node.
* - If there are merged subqueries in `PlanMerger`s then a `WithCTE` node is built from these
* queries. The `CTERelationDef` nodes contain the merged subquery in the following form:
* `Project(Seq(CreateNamedStruct(name1, attribute1, ...) AS mergedValue), mergedSubqueryPlan)`.
* The definitions are flagged that they host a subquery, that can return maximum one row.
* queries. The `CTERelationDef` nodes contain the merged subplans in the following form:
* `Project(Seq(CreateNamedStruct(name 1, attribute 1, ...) AS mergedValue), mergedSubplan)`.
*
* Eg. the following query:
* Here are a few examples:
*
* SELECT
* (SELECT avg(a) FROM t),
* (SELECT sum(b) FROM t)
*
* is optimized from:
*
* == Optimized Logical Plan ==
* Project [scalar-subquery#242 [] AS scalarsubquery()#253,
* scalar-subquery#243 [] AS scalarsubquery()#254L]
* : :- Aggregate [avg(a#244) AS avg(a)#247]
* : : +- Project [a#244]
* : : +- Relation default.t[a#244,b#245] parquet
* : +- Aggregate [sum(a#251) AS sum(a)#250L]
* : +- Project [a#251]
* : +- Relation default.t[a#251,b#252] parquet
* 1. a query with 2 subqueries:
* ```
* Project [scalar-subquery [] AS scalarsubquery(), scalar-subquery [] AS scalarsubquery()]
* : :- Aggregate [min(a) AS min(a)]
* : : +- Relation [a, b, c]
* : +- Aggregate [sum(b) AS sum(b)]
* : +- Relation [a, b, c]
* +- OneRowRelation
* ```
* is optimized to:
* ```
* WithCTE
* :- CTERelationDef 0
* : +- Project [named_struct(min(a), min(a), sum(b), sum(b)) AS mergedValue]
* : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b)]
* : +- Relation [a, b, c]
* +- Project [scalar-subquery [].min(a) AS scalarsubquery(),
* scalar-subquery [].sum(b) AS scalarsubquery()]
* : :- CTERelationRef 0
* : +- CTERelationRef 0
* +- OneRowRelation
* ```
*
* to:
* 2. a query with 2 non-grouping aggregates:
* ```
* Join Inner
* :- Aggregate [min(a) AS min(a)]
* : +- Relation [a, b, c]
* +- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
* +- Relation [a, b, c]
* ```
* is optimized to:
* ```
* WithCTE
* :- CTERelationDef 0
* : +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue]
* : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
* : +- Relation [a, b, c]
* +- Join Inner
* :- Project [scalar-subquery [].min(a) AS min(a)]
* : : +- CTERelationRef 0
* : +- OneRowRelation
* +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)]
* : :- CTERelationRef 0
* : +- CTERelationRef 0
* +- OneRowRelation
* ```
*
* == Optimized Logical Plan ==
* Project [scalar-subquery#242 [].avg(a) AS scalarsubquery()#253,
* scalar-subquery#243 [].sum(a) AS scalarsubquery()#254L]
* : :- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
* : : +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L]
* : : +- Project [a#244]
* : : +- Relation default.t[a#244,b#245] parquet
* : +- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
* : +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L]
* : +- Project [a#244]
* : +- Relation default.t[a#244,b#245] parquet
* +- OneRowRelation
* 3. a query with a subquery and a non-grouping aggregate:
* ```
* Join Inner
* :- Project [scalar-subquery [] AS scalarsubquery()]
* : : +- Aggregate [min(a) AS min(a)]
* : : +- Relation [a, b, c]
* : +- OneRowRelation
* +- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
* +- Relation [a, b, c]
* ```
* is optimized to:
* ```
* WithCTE
* :- CTERelationDef 0
* : +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue]
* : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
* : +- Relation [a, b, c]
* +- Join Inner
* :- Project [scalar-subquery [].min(a) AS scalarsubquery()]
* : : +- CTERelationRef 0
* : +- OneRowRelation
* +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)]
* : :- CTERelationRef 0
* : +- CTERelationRef 0
* +- OneRowRelation
* ```
*
* == Physical Plan ==
* *(1) Project [Subquery scalar-subquery#242, [id=#125].avg(a) AS scalarsubquery()#253,
* ReusedSubquery
* Subquery scalar-subquery#242, [id=#125].sum(a) AS scalarsubquery()#254L]
* : :- Subquery scalar-subquery#242, [id=#125]
* : : +- *(2) Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
* : : +- *(2) HashAggregate(keys=[], functions=[avg(a#244), sum(a#244)],
* output=[avg(a)#247, sum(a)#250L])
* : : +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#120]
* : : +- *(1) HashAggregate(keys=[], functions=[partial_avg(a#244), partial_sum(a#244)],
* output=[sum#262, count#263L, sum#264L])
* : : +- *(1) ColumnarToRow
* : : +- FileScan parquet default.t[a#244] ...
* : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125]
* +- *(1) Scan OneRowRelation[]
* Please note that in the above examples the aggregations are part of a "join group", which could
* be rewritten as one aggregate without the need to introduce a CTE and keeping the join. But there
* are more complex cases when this CTE based approach is the only viable option. Such cases include
* when the aggregates reside at different parts of plan, maybe even in different subquery
* expressions.
*/
object MergeSubplans extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
Expand All @@ -123,7 +171,7 @@ object MergeSubplans extends Rule[LogicalPlan] {

// Traverse level by level and convert merged plans to `CTERelationDef`s and keep non-merged
// ones. While traversing replace references in plans back to `CTERelationRef`s or to original
// plans. This is safe as a subplan at a level can reference only lower level ot other subplans.
// plans. This is safe as a subplan at a level can reference only lower level subplans.
val subplansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
planMergers.foreach { planMerger =>
val mergedPlans = planMerger.mergedPlans()
Expand Down Expand Up @@ -162,8 +210,9 @@ object MergeSubplans extends Rule[LogicalPlan] {
}

// First traversal inserts `ScalarSubqueryReference`s and `NoGroupingAggregateReference`s to the
// plan and tries to merge subplans by each level. Levels are separated eiter by scalar subqueries
// or by non-grouping aggregate nodes. Nodes with the same level make sense to try merging.
// plan and tries to merge subplans by each level. Levels are separated either by scalar
// subqueries or by non-grouping aggregate nodes. Nodes with the same level make sense to try
// merging.
private def insertReferences(
plan: LogicalPlan,
root: Boolean,
Expand Down Expand Up @@ -224,9 +273,10 @@ object MergeSubplans extends Rule[LogicalPlan] {
// parent
(aggregateReference, level + 1)
case o =>
val (newChildren, levels) = o.children.map(insertReferences(_, false, planMergers)).unzip
val (newChildren, levelsFromChildren) =
o.children.map(insertReferences(_, false, planMergers)).unzip
// Level is the maximum of the level from subqueries and the level from the children.
(o.withNewChildren(newChildren), (levelFromSubqueries +: levels).max)
(o.withNewChildren(newChildren), (levelFromSubqueries +: levelsFromChildren).max)
}

(planWithReferences, level)
Expand Down