Commit 78fcc93

[SPARK-44571][SQL] Merge subplans with one row result
### What changes were proposed in this pull request?

This PR renames the `MergeScalarSubqueries` rule to `MergeSubplans` and extends its plan merging capabilities to non-grouping aggregate subplans, which are very similar to scalar subqueries in that they return a one-row result.

Consider the following query that joins 2 non-grouping aggregates:
```
Join Inner
:- Aggregate [min(a) AS min(a)]
:  +- Relation [a, b, c]
+- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
   +- Relation [a, b, c]
```
With the improved rule the plan is optimized to:
```
WithCTE
:- CTERelationDef 0
:  +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue]
:     +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
:        +- Relation [a, b, c]
+- Join Inner
   :- Project [scalar-subquery [].min(a) AS min(a)]
   :  :  +- CTERelationRef 0
   :  +- OneRowRelation
   +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)]
      :  :- CTERelationRef 0
      :  +- CTERelationRef 0
      +- OneRowRelation
```
so that `Relation` is scanned only once.

Please note that the above plan, where the 2 aggregates form a "join group", could also be rewritten as a single aggregate, without the need to introduce a CTE and keep the join. But there are more complex cases where the proposed CTE-based approach is the only viable option. Such cases include when the aggregates reside in different parts of the plan, possibly even in different subquery expressions. E.g. the following query:
```
Join Inner
:- Project [scalar-subquery [] AS scalarsubquery()]
:  :  +- Aggregate [min(a) AS min(a)]
:  :     +- Relation [a, b, c]
:  +- OneRowRelation
+- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
   +- Relation [a, b, c]
```
can be optimized to:
```
WithCTE
:- CTERelationDef 0
:  +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue]
:     +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
:        +- Relation [a, b, c]
+- Join Inner
   :- Project [scalar-subquery [].min(a) AS scalarsubquery()]
   :  :  +- CTERelationRef 0
   :  +- OneRowRelation
   +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)]
      :  :- CTERelationRef 0
      :  +- CTERelationRef 0
      +- OneRowRelation
```

### Why are the changes needed?

To improve the plan merging logic and further reduce redundant IO. Please also note that TPCDS q28 and q88 contain non-grouping aggregates, but this PR can't deal with them yet. Those queries will improve once [SPARK-40193](https://issues.apache.org/jira/browse/SPARK-40193) / #37630 lands in Spark.

### Does this PR introduce _any_ user-facing change?

Yes.

### How was this patch tested?

Existing and new UTs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53019 from peter-toth/SPARK-44571-merge-subplans-with-one-row-result.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Peter Toth <peter.toth@gmail.com>
1 parent e8f0a67 commit 78fcc93
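For illustration, here is a minimal, hedged sketch of the kind of query that produces the first plan shape described above. The local session, the table name `t`, and its dummy data are assumptions made for this example and are not part of the PR:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical reproduction: a join of two non-grouping aggregates over the same relation,
// the shape that the MergeSubplans rule is expected to merge behind a single CTE.
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("merge-subplans-sketch")
  .getOrCreate()
import spark.implicits._

Seq((1, 2L, 3.0), (4, 5L, 6.0)).toDF("a", "b", "c").createOrReplaceTempView("t")

val df = spark.sql(
  """
    |SELECT *
    |FROM (SELECT min(a) FROM t)
    |CROSS JOIN (SELECT sum(b), avg(c) FROM t)
    |""".stripMargin)

// Prints the parsed, analyzed, optimized and physical plans; with this change the optimized
// plan is expected to contain a WithCTE with one CTERelationDef over the merged aggregate.
df.explain(true)

spark.stop()
```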

8 files changed (+623, -268 lines)
Lines changed: 134 additions & 45 deletions
@@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LogicalPlan, Project, Subquery, WithCTE}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, LeafNode, LogicalPlan, OneRowRelation, Project, Subquery, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, NO_GROUPING_AGGREGATE_REFERENCE, SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DataType
 
@@ -100,7 +100,7 @@ import org.apache.spark.sql.types.DataType
  *   :  +- ReusedSubquery Subquery scalar-subquery#242, [id=#125]
  *   +- *(1) Scan OneRowRelation[]
  */
-object MergeScalarSubqueries extends Rule[LogicalPlan] {
+object MergeSubplans extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
     plan match {
       // Subquery reuse needs to be enabled for this optimization.
@@ -117,26 +117,24 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
   }
 
   private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
-    // Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert references in place of
-    // `ScalarSubquery`s.
+    // Collect subplans by level into `PlanMerger`s and insert references in place of them.
     val planMergers = ArrayBuffer.empty[PlanMerger]
-    val planWithReferences = insertReferences(plan, planMergers)._1
+    val planWithReferences = insertReferences(plan, true, planMergers)._1
 
     // Traverse level by level and convert merged plans to `CTERelationDef`s and keep non-merged
     // ones. While traversing replace references in plans back to `CTERelationRef`s or to original
-    // `ScalarSubquery`s. This is safe as a subquery plan at a level can reference only lower level
-    // other subqueries.
-    val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
+    // plans. This is safe as a subplan at a level can reference only lower level other subplans.
+    val subplansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
     planMergers.foreach { planMerger =>
       val mergedPlans = planMerger.mergedPlans()
-      subqueryPlansByLevel += mergedPlans.map { mergedPlan =>
-        val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) {
+      subplansByLevel += mergedPlans.map { mergedPlan =>
+        val planWithoutReferences = if (subplansByLevel.isEmpty) {
          // Level 0 plans can't contain references
          mergedPlan.plan
        } else {
-          removeReferences(mergedPlan.plan, subqueryPlansByLevel)
+          removeReferences(mergedPlan.plan, subplansByLevel)
        }
-        if (mergedPlan.merged && mergedPlan.plan.output.size > 1) {
+        if (mergedPlan.merged) {
          CTERelationDef(
            Project(
              Seq(Alias(
@@ -151,38 +149,42 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
       }
     }
 
-    // Replace references back to `CTERelationRef`s or to original `ScalarSubquery`s in the main
-    // plan.
-    val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel)
+    // Replace references back to `CTERelationRef`s or to original subplans.
+    val newPlan = removeReferences(planWithReferences, subplansByLevel)
 
     // Add `CTERelationDef`s to the plan.
-    val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte })
-    if (subqueryCTEs.nonEmpty) {
-      WithCTE(newPlan, subqueryCTEs.toSeq)
+    val subplanCTEs = subplansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte })
+    if (subplanCTEs.nonEmpty) {
+      WithCTE(newPlan, subplanCTEs.toSeq)
     } else {
       newPlan
     }
   }
 
-  // First traversal inserts `ScalarSubqueryReference`s to the plan and tries to merge subquery
-  // plans by each level.
+  // First traversal inserts `ScalarSubqueryReference`s and `NonGroupingAggregateReference`s to the
+  // plan and tries to merge subplans by each level. Levels are separated either by scalar
+  // subqueries or by non-grouping aggregate nodes; only nodes at the same level can be merged.
   private def insertReferences(
       plan: LogicalPlan,
+      root: Boolean,
       planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = {
-    // The level of a subquery plan is maximum level of its inner subqueries + 1 or 0 if it has no
-    // inner subqueries.
-    var maxLevel = 0
-    val planWithReferences =
-      plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
+    if (!plan.containsAnyPattern(AGGREGATE, SCALAR_SUBQUERY)) {
+      return (plan, 0)
+    }
+
+    // Calculate the level propagated from subquery plans, which is the maximum level of the
+    // subqueries of the node + 1, or 0 if the node has no subqueries.
+    var levelFromSubqueries = 0
+    val nodeSubqueriesWithReferences =
+      plan.transformExpressionsWithPruning(_.containsPattern(SCALAR_SUBQUERY)) {
        case s: ScalarSubquery if !s.isCorrelated && s.deterministic =>
-          val (planWithReferences, level) = insertReferences(s.plan, planMergers)
+          val (planWithReferences, level) = insertReferences(s.plan, true, planMergers)
 
-          while (level >= planMergers.size) planMergers += new PlanMerger()
          // The subquery could contain a hint that is not propagated once we merge it, but as a
          // non-correlated scalar subquery won't be turned into a Join the loss of hints is fine.
-          val mergeResult = planMergers(level).merge(planWithReferences)
+          val mergeResult = getPlanMerger(planMergers, level).merge(planWithReferences, true)
 
-          maxLevel = maxLevel.max(level + 1)
+          levelFromSubqueries = levelFromSubqueries.max(level + 1)
 
          val mergedOutput = mergeResult.outputMap(planWithReferences.output.head)
          val outputIndex =
@@ -195,26 +197,96 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
             s.exprId)
         case o => o
       }
-    (planWithReferences, maxLevel)
+
+    // Calculate the level of the node, which is the maximum of the above calculated level
+    // propagated from subqueries and the level propagated from child nodes.
+    val (planWithReferences, level) = nodeSubqueriesWithReferences match {
+      case a: Aggregate if !root && a.groupingExpressions.isEmpty =>
+        val (childWithReferences, levelFromChild) = insertReferences(a.child, false, planMergers)
+        val aggregateWithReferences = a.withNewChildren(Seq(childWithReferences))
+
+        // Level is the maximum of the level from subqueries and the level from the child.
+        val level = levelFromChild.max(levelFromSubqueries)
+
+        val mergeResult = getPlanMerger(planMergers, level).merge(aggregateWithReferences, false)
+
+        val mergedOutput = aggregateWithReferences.output.map(mergeResult.outputMap)
+        val outputIndices =
+          mergedOutput.map(a => mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == a.exprId))
+        val aggregateReference = NonGroupingAggregateReference(
+          level,
+          mergeResult.mergedPlanIndex,
+          outputIndices,
+          a.output
+        )
+
+        // This is a non-grouping aggregate node, so propagate the level of the node + 1 to its
+        // parent.
+        (aggregateReference, level + 1)
+      case o =>
+        val (newChildren, levels) = o.children.map(insertReferences(_, false, planMergers)).unzip
+        // Level is the maximum of the level from subqueries and the levels from the children.
+        (o.withNewChildren(newChildren), (levelFromSubqueries +: levels).max)
+    }
+
+    (planWithReferences, level)
+  }
+
+  private def getPlanMerger(planMergers: ArrayBuffer[PlanMerger], level: Int) = {
+    while (level >= planMergers.size) planMergers += new PlanMerger()
+    planMergers(level)
   }
 
-  // Second traversal replaces `ScalarSubqueryReference`s to either
-  // `GetStructField(ScalarSubquery(CTERelationRef to the merged plan)` if the plan is merged from
-  // multiple subqueries or `ScalarSubquery(original plan)` if it isn't.
+  // Second traversal replaces:
+  // - a `ScalarSubqueryReference` either with
+  //   `GetStructField(ScalarSubquery(CTERelationRef to the merged plan), merged output index)` if
+  //   the plan is merged from multiple subqueries, or with `ScalarSubquery(original plan)` if it
+  //   isn't.
+  // - a `NonGroupingAggregateReference` either with
+  //   ```
+  //   Project(
+  //     Seq(
+  //       GetStructField(
+  //         ScalarSubquery(CTERelationRef to the merged plan),
+  //         merged output index 1),
+  //       GetStructField(
+  //         ScalarSubquery(CTERelationRef to the merged plan),
+  //         merged output index 2),
+  //       ...),
+  //     OneRowRelation)
+  //   ```
+  //   if the plan is merged from multiple subqueries, or with the original plan if it isn't.
   private def removeReferences(
       plan: LogicalPlan,
-      subqueryPlansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = {
-    plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) {
-      case ssr: ScalarSubqueryReference =>
-        subqueryPlansByLevel(ssr.level)(ssr.mergedPlanIndex) match {
+      subplansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = {
+    plan.transformUpWithPruning(
+      _.containsAnyPattern(NO_GROUPING_AGGREGATE_REFERENCE, SCALAR_SUBQUERY_REFERENCE)) {
+      case ngar: NonGroupingAggregateReference =>
+        subplansByLevel(ngar.level)(ngar.mergedPlanIndex) match {
          case cte: CTERelationDef =>
-            GetStructField(
-              ScalarSubquery(
-                CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming),
-                exprId = ssr.exprId),
-              ssr.outputIndex)
-          case o => ScalarSubquery(o, exprId = ssr.exprId)
+            val projectList = ngar.outputIndices.zip(ngar.output).map { case (i, a) =>
+              Alias(
+                GetStructField(
+                  ScalarSubquery(
+                    CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming)),
+                  i),
+                a.name)(a.exprId)
+            }
+            Project(projectList, OneRowRelation())
+          case o => o
        }
+      case o => o.transformExpressionsUpWithPruning(_.containsPattern(SCALAR_SUBQUERY_REFERENCE)) {
+        case ssr: ScalarSubqueryReference =>
+          subplansByLevel(ssr.level)(ssr.mergedPlanIndex) match {
+            case cte: CTERelationDef =>
+              GetStructField(
+                ScalarSubquery(
+                  CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming),
+                  exprId = ssr.exprId),
+                ssr.outputIndex)
+            case o => ScalarSubquery(o, exprId = ssr.exprId)
+          }
+      }
     }
   }
 }
@@ -233,9 +305,26 @@ case class ScalarSubqueryReference(
     level: Int,
     mergedPlanIndex: Int,
     outputIndex: Int,
-    dataType: DataType,
+    override val dataType: DataType,
     exprId: ExprId) extends LeafExpression with Unevaluable {
   override def nullable: Boolean = true
 
   final override val nodePatterns: Seq[TreePattern] = Seq(SCALAR_SUBQUERY_REFERENCE)
 }
+
+/**
+ * Temporary reference to a non-grouping aggregate which is added to a `PlanMerger`.
+ *
+ * @param level The level of the replaced aggregate. It defines the `PlanMerger` instance into
+ *              which the aggregate is merged.
+ * @param mergedPlanIndex The index of the merged plan in the `PlanMerger`.
+ * @param outputIndices The indices of the original output attributes in the merged plan's output.
+ * @param output The output of the original aggregate.
+ */
+case class NonGroupingAggregateReference(
+    level: Int,
+    mergedPlanIndex: Int,
+    outputIndices: Seq[Int],
+    override val output: Seq[Attribute]) extends LeafNode {
+  final override val nodePatterns: Seq[TreePattern] = Seq(NO_GROUPING_AGGREGATE_REFERENCE)
+}
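For context, a hedged way to observe the effect of the rewritten rule from user code. The `spark` session and the temp view `t(a, b, c)` from the earlier sketch are assumptions, and the exact plan shape can differ by version and configuration:

```scala
import org.apache.spark.sql.catalyst.plans.logical.WithCTE

// After MergeSubplans, the two single-row aggregates below are expected to share one
// CTERelationDef instead of scanning `t` twice.
val df = spark.sql(
  "SELECT * FROM (SELECT min(a) FROM t) CROSS JOIN (SELECT sum(b), avg(c) FROM t)")
val optimized = df.queryExecution.optimizedPlan

// WithCTE(child, cteDefs) carries the merged definitions; count them.
val cteDefs = optimized.collect { case WithCTE(_, defs) => defs }.flatten
println(s"CTE definitions introduced by plan merging: ${cteDefs.length}")
```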

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala

Lines changed: 9 additions & 4 deletions
@@ -58,8 +58,8 @@ case class MergedPlan(plan: LogicalPlan, merged: Boolean)
  * 2. Merge a new plan with a cached plan by combining their outputs
  *
  * The merging process preserves semantic equivalence while combining outputs from multiple
- * plans into a single plan. This is primarily used by [[MergeScalarSubqueries]] to deduplicate
- * scalar subquery execution.
+ * plans into a single plan. This is primarily used by [[MergeSubplans]] to deduplicate subplan
+ * execution.
  *
  * Supported plan types for merging:
  * - [[Project]]: Merges project lists
@@ -88,16 +88,21 @@ class PlanMerger {
    * 3. If no merge is possible, add as a new cache entry
    *
    * @param plan The logical plan to merge or cache.
+   * @param subqueryPlan Whether the logical plan is a subquery plan.
    * @return A [[MergeResult]] containing:
    *         - The merged/cached plan to use
    *         - Its index in the cache
    *         - An attribute mapping for rewriting expressions
    */
-  def merge(plan: LogicalPlan): MergeResult = {
+  def merge(plan: LogicalPlan, subqueryPlan: Boolean): MergeResult = {
    cache.zipWithIndex.collectFirst(Function.unlift {
      case (mp, i) =>
        checkIdenticalPlans(plan, mp.plan).map { outputMap =>
-          val newMergePlan = MergedPlan(mp.plan, true)
+          // Identical subquery expression plans are not marked as `merged` as the
+          // `ReusedSubqueryExec` rule can handle them without extracting the plans to CTEs.
+          // But when a non-subquery subplan is identical to a cached plan, we need to mark the
+          // plan as `merged` and so extract it to a CTE later.
+          val newMergePlan = MergedPlan(mp.plan, cache(i).merged || !subqueryPlan)
          cache(i) = newMergePlan
          MergeResult(newMergePlan, i, outputMap)
        }.orElse {
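To illustrate the `subqueryPlan` distinction above, a sketch under the same assumptions as the earlier examples (session `spark`, temp view `t`); the exact operator names in the printed plans are not guaranteed here:

```scala
// Two identical scalar subqueries: PlanMerger caches the plan but does not mark it `merged`,
// because the physical planner can reuse the second subquery result without a CTE.
spark.sql("SELECT (SELECT max(a) FROM t) AS x, (SELECT max(a) FROM t) AS y").explain()

// Two identical non-grouping aggregates joined together: there is no subquery expression to
// reuse, so the shared plan is marked `merged` and extracted into a CTE by MergeSubplans.
spark.sql("SELECT * FROM (SELECT max(a) FROM t) CROSS JOIN (SELECT max(a) FROM t)").explain(true)
```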

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala

Lines changed: 1 addition & 0 deletions
@@ -150,6 +150,7 @@ object TreePattern extends Enumeration {
   val LOCAL_RELATION: Value = Value
   val LOGICAL_QUERY_STAGE: Value = Value
   val NATURAL_LIKE_JOIN: Value = Value
+  val NO_GROUPING_AGGREGATE_REFERENCE: Value = Value
   val OFFSET: Value = Value
   val OUTER_JOIN: Value = Value
   val PARAMETERIZED_QUERY: Value = Value
