@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use super::{
     DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
 };
+use crate::coalesce::LimitedBatchCoalescer;
 use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
 use crate::hash_utils::create_hashes;
 use crate::metrics::{BaselineMetrics, SpillMetrics};
@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
                         spill_stream,
                         1, // Each receiver handles one input partition
                         BaselineMetrics::new(&metrics, partition),
+                        context.session_config().batch_size(),
                     )) as SendableRecordBatchStream
                 })
                 .collect::<Vec<_>>();
@@ -959,7 +961,6 @@ impl ExecutionPlan for RepartitionExec {
                 .into_iter()
                 .next()
                 .expect("at least one spill reader should exist");
-
             Ok(Box::pin(PerPartitionStream::new(
                 schema_captured,
                 rx.into_iter()
@@ -970,6 +971,7 @@ impl ExecutionPlan for RepartitionExec {
                 spill_stream,
                 num_input_partitions,
                 BaselineMetrics::new(&metrics, partition),
+                context.session_config().batch_size(),
             )) as SendableRecordBatchStream)
         }
     })
@@ -1427,9 +1429,12 @@ struct PerPartitionStream {
 
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
+
+    batch_coalescer: LimitedBatchCoalescer,
 }
 
 impl PerPartitionStream {
+    #[allow(clippy::too_many_arguments)]
     fn new(
         schema: SchemaRef,
         receiver: DistributionReceiver<MaybeBatch>,
@@ -1438,16 +1443,29 @@ impl PerPartitionStream {
         spill_stream: SendableRecordBatchStream,
         num_input_partitions: usize,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Self {
         Self {
-            schema,
+            schema: Arc::clone(&schema),
             receiver,
             _drop_helper: drop_helper,
             reservation,
             spill_stream,
             state: StreamState::ReadingMemory,
             remaining_partitions: num_input_partitions,
             baseline_metrics,
+            batch_coalescer: LimitedBatchCoalescer::new(schema, batch_size, None),
+        }
+    }
+
+    fn flush_remaining_batch(
+        &mut self,
+    ) -> Poll<Option<std::result::Result<RecordBatch, DataFusionError>>> {
+        // Flush any remaining buffered batch
+        match self.batch_coalescer.finish() {
+            Ok(()) => Poll::Ready(self.batch_coalescer.next_completed_batch().map(Ok)),
+
+            Err(e) => Poll::Ready(Some(Err(e))),
         }
     }
 
@@ -1460,75 +1478,82 @@ impl PerPartitionStream {
         let _timer = cloned_time.timer();
 
         loop {
-            match self.state {
-                StreamState::ReadingMemory => {
-                    // Poll the memory channel for next message
-                    let value = match self.receiver.recv().poll_unpin(cx) {
-                        Poll::Ready(v) => v,
-                        Poll::Pending => {
-                            // Nothing from channel, wait
-                            return Poll::Pending;
-                        }
-                    };
-
-                    match value {
-                        Some(Some(v)) => match v {
-                            Ok(RepartitionBatch::Memory(batch)) => {
-                                // Release memory and return batch
-                                self.reservation
-                                    .lock()
-                                    .shrink(batch.get_array_memory_size());
-                                return Poll::Ready(Some(Ok(batch)));
+            loop {
+                match self.state {
+                    StreamState::ReadingMemory => {
+                        // Poll the memory channel for next message
+                        let value = match self.receiver.recv().poll_unpin(cx) {
+                            Poll::Ready(v) => v,
+                            Poll::Pending => {
+                                // Nothing from channel, wait
+                                return Poll::Pending;
                             }
-                            Ok(RepartitionBatch::Spilled) => {
-                                // Batch was spilled, transition to reading from spill stream
-                                // We must block on spill stream until we get the batch
-                                // to preserve ordering
-                                self.state = StreamState::ReadingSpilled;
+                        };
+
+                        match value {
+                            Some(Some(v)) => match v {
+                                Ok(RepartitionBatch::Memory(batch)) => {
+                                    // Release memory and return batch
+                                    self.reservation
+                                        .lock()
+                                        .shrink(batch.get_array_memory_size());
+                                    self.batch_coalescer.push_batch(batch)?;
+                                    break;
+                                }
+                                Ok(RepartitionBatch::Spilled) => {
+                                    // Batch was spilled, transition to reading from spill stream
+                                    // We must block on spill stream until we get the batch
+                                    // to preserve ordering
+                                    self.state = StreamState::ReadingSpilled;
+                                    continue;
+                                }
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
+                            },
+                            Some(None) => {
+                                // One input partition finished
+                                self.remaining_partitions -= 1;
+                                if self.remaining_partitions == 0 {
+                                    // All input partitions finished
+                                    return self.flush_remaining_batch();
+                                }
+                                // Continue to poll for more data from other partitions
                                 continue;
                             }
-                            Err(e) => {
-                                return Poll::Ready(Some(Err(e)));
+                            None => {
+                                // Channel closed unexpectedly
+                                return self.flush_remaining_batch();
                             }
-                        },
-                        Some(None) => {
-                            // One input partition finished
-                            self.remaining_partitions -= 1;
-                            if self.remaining_partitions == 0 {
-                                // All input partitions finished
-                                return Poll::Ready(None);
-                            }
-                            // Continue to poll for more data from other partitions
-                            continue;
-                        }
-                        None => {
-                            // Channel closed unexpectedly
-                            return Poll::Ready(None);
                         }
                     }
-                }
-                StreamState::ReadingSpilled => {
-                    // Poll spill stream for the spilled batch
-                    match self.spill_stream.poll_next_unpin(cx) {
-                        Poll::Ready(Some(Ok(batch))) => {
-                            self.state = StreamState::ReadingMemory;
-                            return Poll::Ready(Some(Ok(batch)));
-                        }
-                        Poll::Ready(Some(Err(e))) => {
-                            return Poll::Ready(Some(Err(e)));
-                        }
-                        Poll::Ready(None) => {
-                            // Spill stream ended, keep draining the memory channel
-                            self.state = StreamState::ReadingMemory;
-                        }
-                        Poll::Pending => {
-                            // Spilled batch not ready yet, must wait
-                            // This preserves ordering by blocking until spill data arrives
-                            return Poll::Pending;
+                    StreamState::ReadingSpilled => {
+                        // Poll spill stream for the spilled batch
+                        match self.spill_stream.poll_next_unpin(cx) {
+                            Poll::Ready(Some(Ok(batch))) => {
+                                self.state = StreamState::ReadingMemory;
+                                self.batch_coalescer.push_batch(batch)?;
+                                break;
+                            }
+                            Poll::Ready(Some(Err(e))) => {
+                                return Poll::Ready(Some(Err(e)));
+                            }
+                            Poll::Ready(None) => {
+                                // Spill stream ended, keep draining the memory channel
+                                self.state = StreamState::ReadingMemory;
+                            }
+                            Poll::Pending => {
+                                // Spilled batch not ready yet, must wait
+                                // This preserves ordering by blocking until spill data arrives
+                                return Poll::Pending;
+                            }
                         }
                     }
                 }
             }
+            if let Some(batch) = self.batch_coalescer.next_completed_batch() {
+                return Poll::Ready(Some(Ok(batch)));
+            }
         }
     }
 }
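The hunk above changes `PerPartitionStream` from returning every incoming batch as-is to buffering batches in a `LimitedBatchCoalescer` and only emitting once a batch of the configured size is ready (or the input is exhausted, via `flush_remaining_batch`). The sketch below is illustrative only and is not the DataFusion type: it mimics the same push / next-completed / finish pattern on top of arrow's `concat_batches`, and the names `SimpleCoalescer`, `push`, `next_ready`, and `finish` are hypothetical.

```rust
// Minimal stand-in for the coalescing pattern used by the patched stream.
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::compute::concat_batches;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

struct SimpleCoalescer {
    schema: SchemaRef,
    target_batch_size: usize,
    buffered: Vec<RecordBatch>,
    buffered_rows: usize,
    ready: Vec<RecordBatch>,
}

impl SimpleCoalescer {
    fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
        Self {
            schema,
            target_batch_size,
            buffered: vec![],
            buffered_rows: 0,
            ready: vec![],
        }
    }

    /// Buffer a small batch; once enough rows accumulate, concatenate them
    /// into one output batch of roughly `target_batch_size` rows.
    fn push(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
        self.buffered_rows += batch.num_rows();
        self.buffered.push(batch);
        if self.buffered_rows >= self.target_batch_size {
            self.ready.push(concat_batches(&self.schema, &self.buffered)?);
            self.buffered.clear();
            self.buffered_rows = 0;
        }
        Ok(())
    }

    /// Take the next completed (coalesced) batch, if any.
    fn next_ready(&mut self) -> Option<RecordBatch> {
        if self.ready.is_empty() {
            None
        } else {
            Some(self.ready.remove(0))
        }
    }

    /// Flush whatever is still buffered once the input is exhausted,
    /// mirroring what `flush_remaining_batch` does in the diff above.
    fn finish(&mut self) -> Result<(), ArrowError> {
        if !self.buffered.is_empty() {
            self.ready.push(concat_batches(&self.schema, &self.buffered)?);
            self.buffered.clear();
            self.buffered_rows = 0;
        }
        Ok(())
    }
}

fn main() -> Result<(), ArrowError> {
    let schema: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    let mut coalescer = SimpleCoalescer::new(Arc::clone(&schema), 8);

    // Feed many tiny two-row batches, as a repartitioned stream would.
    for start in (0..20).step_by(2) {
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![start, start + 1]))],
        )?;
        coalescer.push(batch)?;
        if let Some(out) = coalescer.next_ready() {
            println!("emitted coalesced batch with {} rows", out.num_rows());
        }
    }
    // End of input: flush the remainder.
    coalescer.finish()?;
    if let Some(out) = coalescer.next_ready() {
        println!("final batch with {} rows", out.num_rows());
    }
    Ok(())
}
```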
@@ -1575,9 +1600,9 @@ mod tests {
     use datafusion_common::exec_err;
     use datafusion_common::test_util::batches_to_sort_string;
     use datafusion_common_runtime::JoinSet;
+    use datafusion_execution::config::SessionConfig;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use insta::assert_snapshot;
-    use itertools::Itertools;
 
     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {
@@ -1588,7 +1613,7 @@ mod tests {
 
         // repartition from 1 input to 4 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4), 8).await?;
 
         assert_eq!(4, output_partitions.len());
         assert_eq!(13, output_partitions[0].len());
@@ -1608,7 +1633,7 @@ mod tests {
 
         // repartition from 3 input to 1 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1), 8).await?;
 
         assert_eq!(1, output_partitions.len());
         assert_eq!(150, output_partitions[0].len());
@@ -1625,7 +1650,7 @@ mod tests {
 
         // repartition from 3 input to 5 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8).await?;
 
         assert_eq!(5, output_partitions.len());
         assert_eq!(30, output_partitions[0].len());
@@ -1648,6 +1673,7 @@ mod tests {
             &schema,
             partitions,
             Partitioning::Hash(vec![col("c0", &schema)?], 8),
+            8,
         )
         .await?;
 
@@ -1670,8 +1696,11 @@ mod tests {
         schema: &SchemaRef,
         input_partitions: Vec<Vec<RecordBatch>>,
         partitioning: Partitioning,
+        batch_size: usize,
     ) -> Result<Vec<Vec<RecordBatch>>> {
-        let task_ctx = Arc::new(TaskContext::default());
+        let session_config = SessionConfig::new().with_batch_size(batch_size);
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
         // create physical plan
         let exec =
             TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
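For context on the test-helper change above: the new `batch_size` argument reaches the operator through the session configuration, which is the same path `RepartitionExec::execute` reads via `context.session_config().batch_size()`. A minimal sketch of that wiring, using only APIs that appear in this diff (nothing here is part of the patch itself):

```rust
// How a configured batch size flows from SessionConfig through TaskContext
// to whatever operator is handed that context.
use std::sync::Arc;

use datafusion_execution::config::SessionConfig;
use datafusion_execution::TaskContext;

fn main() {
    let session_config = SessionConfig::new().with_batch_size(8);
    let task_ctx =
        Arc::new(TaskContext::default().with_session_config(session_config));

    // An operator given this TaskContext recovers the configured value.
    assert_eq!(task_ctx.session_config().batch_size(), 8);
}
```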
@@ -1702,7 +1731,8 @@ mod tests {
                 vec![partition.clone(), partition.clone(), partition.clone()];
 
             // repartition from 3 input to 5 output
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8)
+                .await
         });
 
         let output_partitions = handle.join().await.unwrap().unwrap();
@@ -1898,7 +1928,9 @@ mod tests {
     // with different compilers, we will compare the same execution with
     // and without dropping the output stream.
     async fn hash_repartition_with_dropping_output_stream() {
-        let task_ctx = Arc::new(TaskContext::default());
+        let session_config = SessionConfig::new().with_batch_size(4);
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
         let partitioning = Partitioning::Hash(
             vec![Arc::new(crate::expressions::Column::new(
                 "my_awesome_field",
@@ -1950,14 +1982,14 @@ mod tests {
         });
         let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
 
-        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
-            batch
-                .into_iter()
-                .sorted_by_key(|b| format!("{b:?}"))
-                .collect()
-        }
+        let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
+        let items_set_with_drop: HashSet<&str> =
+            items_vec_with_drop.iter().copied().collect();
 
-        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
+        assert_eq!(
+            items_set_with_drop.symmetric_difference(&items_set).count(),
+            0
+        );
     }
 
     fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
@@ -2396,6 +2428,7 @@ mod test {
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::assert_batches_eq;
+    use datafusion_execution::config::SessionConfig;
 
     use super::*;
     use crate::test::TestMemoryExec;
@@ -2507,8 +2540,10 @@ mod test {
         let runtime = RuntimeEnvBuilder::default()
             .with_memory_limit(64, 1.0)
             .build_arc()?;
-
-        let task_ctx = TaskContext::default().with_runtime(runtime);
+        let session_config = SessionConfig::new().with_batch_size(4);
+        let task_ctx = TaskContext::default()
+            .with_runtime(runtime)
+            .with_session_config(session_config);
         let task_ctx = Arc::new(task_ctx);
 
         // Create physical plan with order preservation