@@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.functions.{col, max}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf._
 import org.apache.spark.sql.types._
@@ -2626,4 +2627,159 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       assert(scans.forall(_.inputRDD.partitions.length == 2))
     }
   }
+
+  test("SPARK-53322: checkpointed scans avoid shuffles for aggregates") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+      val scanDF = spark.read.table(s"testcat.ns.$items").checkpoint()
+      val df = scanDF.groupBy("id").agg(max("price").as("res")).select("res")
+      checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))
+
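+      // Grouping by the partition column should need no exchange here, since
+      // the scan's key-grouped partitioning is expected to survive the checkpoint.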
+      val shuffles = collectAllShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.isEmpty,
+        "should not contain shuffle when grouping by partition values")
+    }
+  }
+
+  test("SPARK-53322: checkpointed scans aren't used for SPJ") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")
+
+      val purchasesPartitions = Array(identity("item_id"))
+      createTable(purchases, purchasesColumns, purchasesPartitions)
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
+        s"(4, 20.0, cast('2020-01-04' as timestamp))")
+
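+      // Exercise all combinations: partition-value pushdown on/off, and one or
+      // both sides checkpointed. SPJ should not kick in for any of them.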
+      for {
+        pushdownValues <- Seq(true, false)
+        checkpointBothScans <- Seq(true, false)
+      } {
+        withSQLConf(
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) {
+          val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
+          val scanDF2 = if (checkpointBothScans) {
+            spark.read.table(s"testcat.ns.$purchases").checkpoint().as("p")
+          } else {
+            spark.read.table(s"testcat.ns.$purchases").as("p")
+          }
+
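+          // With the items side checkpointed, the planner cannot align partition
+          // values across the two scans, so both join sides are expected to shuffle.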
+          val df = scanDF1
+            .join(scanDF2, col("id") === col("item_id"))
+            .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
+            .orderBy("id", "purchase_price", "sale_price")
+          checkAnswer(
+            df,
+            Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
+          )
+          // 1 shuffle for SORT and 2 shuffles for JOIN are expected.
+          assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3)
+        }
+      }
+    }
+  }
+
+  test("SPARK-53322: checkpointed scans can't shuffle other children on SPJ") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")
+
+      createTable(purchases, purchasesColumns, Array.empty)
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
+        s"(4, 20.0, cast('2020-01-04' as timestamp))")
+
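+      // The purchases table is unpartitioned. V2_BUCKETING_SHUFFLE can normally
+      // redistribute such a side to match a key-grouped scan, but the checkpointed
+      // items scan should not be able to drive that, so all three shuffles remain.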
+      Seq(true, false).foreach { pushdownValues =>
+        withSQLConf(
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) {
+          val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
+          val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p")
+
+          val df = scanDF1
+            .join(scanDF2, col("id") === col("item_id"))
+            .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
+            .orderBy("id", "purchase_price", "sale_price")
+          checkAnswer(
+            df,
+            Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
+          )
+          // 1 shuffle for SORT and 2 shuffles for JOIN are expected.
+          assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3)
+        }
+      }
+    }
+  }
+
+  test("SPARK-53322: checkpointed scans can be shuffled by children on SPJ") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")
+
+      createTable(purchases, purchasesColumns, Array(identity("item_id")))
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
+        s"(4, 20.0, cast('2020-01-04' as timestamp))")
+
+      withSQLConf(
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+        SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") {
+        val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
+        val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p")
+
+        val df = scanDF1
+          .join(scanDF2, col("id") === col("item_id"))
+          .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
+          .orderBy("id", "purchase_price", "sale_price")
+        checkAnswer(
+          df,
+          Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
+        )
+
+        // One shuffle for the sort and one shuffle for one side of the JOIN are expected.
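+        // The checkpointed items side can be redistributed to match the
+        // key-grouped purchases scan, so only that side needs a join shuffle.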
+        assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2)
+      }
+    }
+  }
 }