
Commit 9349733

chirag-s-db authored and huangxiaopingRD committed
[SPARK-53322][SQL] Select a KeyGroupedShuffleSpec only when join key positions can be fully pushed down
### What changes were proposed in this pull request?

When a KeyGroupedShuffleSpec is used to shuffle another child of a JOIN, we must be able to push down JOIN keys or partition values to ensure that both children have matching partitioning. If one child reports a KeyGroupedPartitioning but we can't push down these values (for example, if the child was a key-grouped scan that was checkpointed), then this information cannot be pushed down to the child scan and we should avoid using this shuffle spec to shuffle other children.

### Why are the changes needed?

Prevents a potential correctness issue when key-grouped partitioning is used on a checkpointed RDD.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

See test changes.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#53098 from chirag-s-db/checkpoint-pushdown.

Lead-authored-by: Chirag Singh <chirag.singh@databricks.com>
Co-authored-by: Chirag Singh <137233133+chirag-s-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
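For illustration only (not part of the commit): a minimal sketch of the join shape this change guards against. The table names, partitioning, and catalog (`testcat.ns.items` partitioned by `identity(id)`, `testcat.ns.purchases`) are taken from the test changes below; the checkpoint directory and the standalone driver code are assumed.

```scala
import org.apache.spark.sql.functions.col

// Sketch only: assumes a SparkSession `spark` and the DSv2 tables created in
// KeyGroupedPartitioningSuite (testcat.ns.items partitioned by identity(id)).
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

// Checkpointing replaces the key-grouped DSv2 scan with an RDDScanExec that still reports
// KeyGroupedPartitioning, but join-key positions / partition values can no longer be
// pushed down into it.
val items = spark.read.table("testcat.ns.items").checkpoint().as("i")
val purchases = spark.read.table("testcat.ns.purchases").as("p")

// Before this fix, EnsureRequirements could select the checkpointed side's
// KeyGroupedShuffleSpec to shuffle `purchases`, even though the SPJ parameters could not
// be pushed into the checkpointed scan; with the fix, such a spec is skipped.
val joined = items.join(purchases, col("id") === col("item_id"))
joined.explain()
```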
1 parent 29a9825 commit 9349733

File tree

3 files changed: +218 −33 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 31 additions & 0 deletions
@@ -140,6 +140,13 @@ case class EnsureRequirements(
     // Choose all the specs that can be used to shuffle other children
     val candidateSpecs = specs
       .filter(_._2.canCreatePartitioning)
+      .filter {
+        // To choose a KeyGroupedShuffleSpec, we must be able to push down SPJ parameters into
+        // the scan (for join key positions). If these parameters can't be pushed down, this
+        // spec can't be used to shuffle other children.
+        case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx))
+        case _ => true
+      }
       .filter(p => !shouldConsiderMinParallelism ||
         children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions)
     val bestSpecOpt = if (candidateSpecs.isEmpty) {

@@ -402,6 +409,24 @@
     }
   }

+  /**
+   * Whether SPJ params can be pushed down to the leaf nodes of a physical plan. For a plan to be
+   * eligible for SPJ parameter pushdown, all leaf nodes must be a KeyGroupedPartitioning-aware
+   * scan.
+   *
+   * Notably, if the leaf of `plan` is an [[RDDScanExec]] created by checkpointing a DSv2 scan, the
+   * reported partitioning will be a [[KeyGroupedPartitioning]], but this plan will _not_ be
+   * eligible for SPJ parameter pushdown (as the partitioning is static and can't be easily
+   * re-grouped or padded with empty partitions according to the partition values on the other side
+   * of the join).
+   */
+  private def canPushDownSPJParamsToScan(plan: SparkPlan): Boolean = {
+    plan.collectLeaves().forall {
+      case _: KeyGroupedPartitionedScan[_] => true
+      case _ => false
+    }
+  }
+
   /**
    * Checks whether two children, `left` and `right`, of a join operator have compatible
    * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.

@@ -413,6 +438,12 @@
       left: SparkPlan,
       right: SparkPlan,
       requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
+    // If SPJ params can't be pushed down to either the left or right side, it's unsafe to do an
+    // SPJ.
+    if (!canPushDownSPJParamsToScan(left) || !canPushDownSPJParamsToScan(right)) {
+      return None
+    }
+
     parent match {
       case smj: SortMergeJoinExec =>
         checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution)
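As a side note (not part of the patch): one way to see the distinction the new `canPushDownSPJParamsToScan` helper keys on is to compare the physical leaves of a plain DSv2 read with those of a checkpointed one. The snippet below is a hypothetical inspection, assuming the suite's `testcat.ns.items` table and a configured checkpoint directory.

```scala
// Hypothetical inspection snippet: the plain read's leaf is a DSv2 scan node
// (KeyGroupedPartitioning-aware in the sense of the helper above), while the
// checkpointed plan's leaf is an RDDScanExec, so the forall check rejects it.
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")
val direct = spark.read.table("testcat.ns.items")
val checkpointed = spark.read.table("testcat.ns.items").checkpoint()

direct.queryExecution.executedPlan.collectLeaves().foreach(l => println(l.nodeName))
checkpointed.queryExecution.executedPlan.collectLeaves().foreach(l => println(l.nodeName))
```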

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala

Lines changed: 145 additions & 0 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.functions.{col, max}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf._
 import org.apache.spark.sql.types._

@@ -2626,4 +2627,148 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       assert(scans.forall(_.inputRDD.partitions.length == 2))
     }
   }
+
+  test("SPARK-53322: checkpointed scans avoid shuffles for aggregates") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+      val scanDF = spark.read.table(s"testcat.ns.$items").checkpoint()
+      val df = scanDF.groupBy("id").agg(max("price").as("res")).select("res")
+      checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))
+
+      val shuffles = collectAllShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.isEmpty,
+        "should not contain shuffle when not grouping by partition values")
+    }
+  }
+
+  test("SPARK-53322: checkpointed scans aren't used for SPJ") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")
+
+      val purchase_partitions = Array(identity("item_id"))
+      createTable(purchases, purchasesColumns, purchase_partitions)
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
+        s"(4, 20.0, cast('2020-01-04' as timestamp))")
+
+      for {
+        pushdownValues <- Seq(true, false)
+        checkpointBothScans <- Seq(true, false)
+      } {
+        withSQLConf(
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) {
+          val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
+          val scanDF2 = if (checkpointBothScans) {
+            spark.read.table(s"testcat.ns.$purchases").checkpoint().as("p")
+          } else {
+            spark.read.table(s"testcat.ns.$purchases").as("p")
+          }
+
+          val df = scanDF1
+            .join(scanDF2, col("id") === col("item_id"))
+            .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
+            .orderBy("id", "purchase_price", "sale_price")
+          checkAnswer(
+            df,
+            Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
+          )
+          // 1 shuffle for SORT and 2 shuffles for JOIN are expected.
+          assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3)
+        }
+      }
+    }
+  }
+
+  test("SPARK-53322: checkpointed scans can't shuffle other children on SPJ") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")
+
+      createTable(purchases, purchasesColumns, Array.empty)
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
+        s"(4, 20.0, cast('2020-01-04' as timestamp))")
+
+      Seq(true, false).foreach { pushdownValues =>
+        withSQLConf(
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) {
+          val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
+          val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p")
+
+          val df = scanDF1
+            .join(scanDF2, col("id") === col("item_id"))
+            .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
+            .orderBy("id", "purchase_price", "sale_price")
+          checkAnswer(
+            df,
+            Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
+          )
+          // 1 shuffle for SORT and 2 shuffles for JOIN are expected.
+          assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3)
+        }
+      }
+    }
+  }
+
+  test("SPARK-53322: checkpointed scans can be shuffled by children on SPJ") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")
+
+      createTable(purchases, purchasesColumns, Array(identity("item_id")))
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
+        s"(4, 20.0, cast('2020-01-04' as timestamp))")
+
+      withSQLConf(
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+        SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") {
+        val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
+        val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p")
+
+        val df = scanDF1
+          .join(scanDF2, col("id") === col("item_id"))
+          .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
+          .orderBy("id", "purchase_price", "sale_price")
+        checkAnswer(
+          df,
+          Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
+        )
+
+        // One shuffle for the sort and one shuffle for one side of the JOIN are expected.
+        assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2)
+      }
+    }
+  }
 }
