Commit d53800e

feat: integrate batch coalescer with repartition exec
1 parent 5258352 commit d53800e

File tree

2 files changed: +138 -82 lines changed


datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs

Lines changed: 26 additions & 5 deletions
@@ -35,7 +35,7 @@ use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
 use datafusion_catalog::Session;
 use datafusion_common::cast::as_primitive_array;
-use datafusion_common::{internal_err, not_impl_err};
+use datafusion_common::{internal_err, not_impl_err, DataFusionError};
 use datafusion_expr::expr::{BinaryExpr, Cast};
 use datafusion_functions_aggregate::expr_fn::count;
 use datafusion_physical_expr::EquivalenceProperties;
@@ -134,10 +134,31 @@ impl ExecutionPlan for CustomPlan {
         _partition: usize,
         _context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        Ok(Box::pin(RecordBatchStreamAdapter::new(
-            self.schema(),
-            futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
-        )))
+        if self.batches.is_empty() {
+            Ok(Box::pin(RecordBatchStreamAdapter::new(
+                self.schema(),
+                futures::stream::empty(),
+            )))
+        } else {
+            let schema_captured = self.schema().clone();
+            Ok(Box::pin(RecordBatchStreamAdapter::new(
+                self.schema(),
+                futures::stream::iter(self.batches.clone().into_iter().map(
+                    move |batch| {
+                        let projection: Vec<usize> = schema_captured
+                            .fields()
+                            .iter()
+                            .filter_map(|field| {
+                                batch.schema().index_of(field.name()).ok()
+                            })
+                            .collect();
+                        batch
+                            .project(&projection)
+                            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
+                    },
+                )),
+            )))
+        }
     }

     fn statistics(&self) -> Result<Statistics> {
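For reference, the projection above leans on Arrow's RecordBatch::project, which selects columns by index. A minimal standalone sketch of the same lookup-then-project idea (hypothetical schemas and column names, not part of this commit):

use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    // A batch with two columns, "a" and "b".
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ])),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
            Arc::new(StringArray::from(vec!["x", "y", "z"])) as ArrayRef,
        ],
    )?;

    // Target schema keeps only "b": look up each target field's index in the
    // batch schema and project, as the new scan() closure does.
    let target = Schema::new(vec![Field::new("b", DataType::Utf8, false)]);
    let projection: Vec<usize> = target
        .fields()
        .iter()
        .filter_map(|field| batch.schema().index_of(field.name()).ok())
        .collect();
    let projected = batch.project(&projection)?;
    assert_eq!(projected.num_columns(), 1);
    Ok(())
}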

datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 112 additions & 77 deletions
@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use super::{
     DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
 };
+use crate::coalesce::LimitedBatchCoalescer;
 use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
 use crate::hash_utils::create_hashes;
 use crate::metrics::{BaselineMetrics, SpillMetrics};

@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
                         spill_stream,
                         1, // Each receiver handles one input partition
                         BaselineMetrics::new(&metrics, partition),
+                        context.session_config().batch_size(),
                     )) as SendableRecordBatchStream
                 })
                 .collect::<Vec<_>>();

@@ -959,7 +961,6 @@ impl ExecutionPlan for RepartitionExec {
                     .into_iter()
                     .next()
                     .expect("at least one spill reader should exist");
-
                 Ok(Box::pin(PerPartitionStream::new(
                     schema_captured,
                     rx.into_iter()

@@ -970,6 +971,7 @@ impl ExecutionPlan for RepartitionExec {
                     spill_stream,
                     num_input_partitions,
                     BaselineMetrics::new(&metrics, partition),
+                    context.session_config().batch_size(),
                 )) as SendableRecordBatchStream)
             }
         })
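The repartition streams now read their target batch size from the session config. A small sketch of how that value gets there, assuming only the datafusion-execution crate (the hypothetical helper name is illustrative; it mirrors the test helper changes further down):

use std::sync::Arc;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::TaskContext;

// Build a TaskContext whose session config carries the batch size that
// RepartitionExec reads via context.session_config().batch_size().
fn task_ctx_with_batch_size(batch_size: usize) -> Arc<TaskContext> {
    let session_config = SessionConfig::new().with_batch_size(batch_size);
    Arc::new(TaskContext::default().with_session_config(session_config))
}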
@@ -1427,9 +1429,12 @@ struct PerPartitionStream {

     /// Execution metrics
     baseline_metrics: BaselineMetrics,
+
+    batch_coalescer: LimitedBatchCoalescer,
 }

 impl PerPartitionStream {
+    #[allow(clippy::too_many_arguments)]
     fn new(
         schema: SchemaRef,
         receiver: DistributionReceiver<MaybeBatch>,

@@ -1438,16 +1443,29 @@ impl PerPartitionStream {
         spill_stream: SendableRecordBatchStream,
         num_input_partitions: usize,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Self {
         Self {
-            schema,
+            schema: Arc::clone(&schema),
             receiver,
             _drop_helper: drop_helper,
             reservation,
             spill_stream,
             state: StreamState::ReadingMemory,
             remaining_partitions: num_input_partitions,
             baseline_metrics,
+            batch_coalescer: LimitedBatchCoalescer::new(schema, batch_size, None),
+        }
+    }
+
+    fn flush_remaining_batch(
+        &mut self,
+    ) -> Poll<Option<std::result::Result<RecordBatch, DataFusionError>>> {
+        // Flush any remaining buffered batch
+        match self.batch_coalescer.finish() {
+            Ok(()) => Poll::Ready(self.batch_coalescer.next_completed_batch().map(Ok)),
+
+            Err(e) => Poll::Ready(Some(Err(e))),
         }
     }

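The new field relies on a three-call contract: push_batch() buffers incoming rows, finish() flushes a partial tail, and next_completed_batch() yields coalesced output. Below is a simplified, hypothetical stand-in built on arrow's concat_batches, not the real LimitedBatchCoalescer (which lives in crate::coalesce and also accepts the optional row limit passed as None above):

use std::collections::VecDeque;
use arrow::compute::concat_batches;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

// Illustrative sketch only: buffers small batches until enough rows
// accumulate, then emits one concatenated batch.
struct MiniCoalescer {
    schema: SchemaRef,
    target_batch_size: usize,
    buffered: Vec<RecordBatch>,
    buffered_rows: usize,
    completed: VecDeque<RecordBatch>,
}

impl MiniCoalescer {
    fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
        Self {
            schema,
            target_batch_size,
            buffered: Vec::new(),
            buffered_rows: 0,
            completed: VecDeque::new(),
        }
    }

    /// Buffer a (possibly tiny) batch; once enough rows accumulate,
    /// concatenate them into one completed batch.
    fn push_batch(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
        self.buffered_rows += batch.num_rows();
        self.buffered.push(batch);
        if self.buffered_rows >= self.target_batch_size {
            self.flush()?;
        }
        Ok(())
    }

    /// Flush whatever is still buffered, e.g. when all inputs finish.
    fn finish(&mut self) -> Result<(), ArrowError> {
        self.flush()
    }

    fn next_completed_batch(&mut self) -> Option<RecordBatch> {
        self.completed.pop_front()
    }

    fn flush(&mut self) -> Result<(), ArrowError> {
        if !self.buffered.is_empty() {
            let out = concat_batches(&self.schema, &self.buffered)?;
            self.buffered.clear();
            self.buffered_rows = 0;
            self.completed.push_back(out);
        }
        Ok(())
    }
}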
@@ -1460,75 +1478,82 @@ impl PerPartitionStream {
         let _timer = cloned_time.timer();

         loop {
-            match self.state {
-                StreamState::ReadingMemory => {
-                    // Poll the memory channel for next message
-                    let value = match self.receiver.recv().poll_unpin(cx) {
-                        Poll::Ready(v) => v,
-                        Poll::Pending => {
-                            // Nothing from channel, wait
-                            return Poll::Pending;
-                        }
-                    };
-
-                    match value {
-                        Some(Some(v)) => match v {
-                            Ok(RepartitionBatch::Memory(batch)) => {
-                                // Release memory and return batch
-                                self.reservation
-                                    .lock()
-                                    .shrink(batch.get_array_memory_size());
-                                return Poll::Ready(Some(Ok(batch)));
-                            }
-                            Ok(RepartitionBatch::Spilled) => {
-                                // Batch was spilled, transition to reading from spill stream
-                                // We must block on spill stream until we get the batch
-                                // to preserve ordering
-                                self.state = StreamState::ReadingSpilled;
-                                continue;
-                            }
-                            Err(e) => {
-                                return Poll::Ready(Some(Err(e)));
-                            }
-                        },
-                        Some(None) => {
-                            // One input partition finished
-                            self.remaining_partitions -= 1;
-                            if self.remaining_partitions == 0 {
-                                // All input partitions finished
-                                return Poll::Ready(None);
-                            }
-                            // Continue to poll for more data from other partitions
-                            continue;
-                        }
-                        None => {
-                            // Channel closed unexpectedly
-                            return Poll::Ready(None);
-                        }
-                    }
-                }
-                StreamState::ReadingSpilled => {
-                    // Poll spill stream for the spilled batch
-                    match self.spill_stream.poll_next_unpin(cx) {
-                        Poll::Ready(Some(Ok(batch))) => {
-                            self.state = StreamState::ReadingMemory;
-                            return Poll::Ready(Some(Ok(batch)));
-                        }
-                        Poll::Ready(Some(Err(e))) => {
-                            return Poll::Ready(Some(Err(e)));
-                        }
-                        Poll::Ready(None) => {
-                            // Spill stream ended, keep draining the memory channel
-                            self.state = StreamState::ReadingMemory;
-                        }
-                        Poll::Pending => {
-                            // Spilled batch not ready yet, must wait
-                            // This preserves ordering by blocking until spill data arrives
-                            return Poll::Pending;
-                        }
-                    }
-                }
-            }
+            loop {
+                match self.state {
+                    StreamState::ReadingMemory => {
+                        // Poll the memory channel for next message
+                        let value = match self.receiver.recv().poll_unpin(cx) {
+                            Poll::Ready(v) => v,
+                            Poll::Pending => {
+                                // Nothing from channel, wait
+                                return Poll::Pending;
+                            }
+                        };
+
+                        match value {
+                            Some(Some(v)) => match v {
+                                Ok(RepartitionBatch::Memory(batch)) => {
+                                    // Release memory and return batch
+                                    self.reservation
+                                        .lock()
+                                        .shrink(batch.get_array_memory_size());
+                                    self.batch_coalescer.push_batch(batch)?;
+                                    break;
+                                }
+                                Ok(RepartitionBatch::Spilled) => {
+                                    // Batch was spilled, transition to reading from spill stream
+                                    // We must block on spill stream until we get the batch
+                                    // to preserve ordering
+                                    self.state = StreamState::ReadingSpilled;
+                                    continue;
+                                }
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
+                            },
+                            Some(None) => {
+                                // One input partition finished
+                                self.remaining_partitions -= 1;
+                                if self.remaining_partitions == 0 {
+                                    // All input partitions finished
+                                    return self.flush_remaining_batch();
+                                }
+                                // Continue to poll for more data from other partitions
+                                continue;
+                            }
+                            None => {
+                                // Channel closed unexpectedly
+                                return self.flush_remaining_batch();
+                            }
+                        }
+                    }
+                    StreamState::ReadingSpilled => {
+                        // Poll spill stream for the spilled batch
+                        match self.spill_stream.poll_next_unpin(cx) {
+                            Poll::Ready(Some(Ok(batch))) => {
+                                self.state = StreamState::ReadingMemory;
+                                self.batch_coalescer.push_batch(batch)?;
+                                break;
+                            }
+                            Poll::Ready(Some(Err(e))) => {
+                                return Poll::Ready(Some(Err(e)));
+                            }
+                            Poll::Ready(None) => {
+                                // Spill stream ended, keep draining the memory channel
+                                self.state = StreamState::ReadingMemory;
+                            }
+                            Poll::Pending => {
+                                // Spilled batch not ready yet, must wait
+                                // This preserves ordering by blocking until spill data arrives
+                                return Poll::Pending;
+                            }
+                        }
+                    }
+                }
+            }
+            if let Some(batch) = self.batch_coalescer.next_completed_batch() {
+                return Poll::Ready(Some(Ok(batch)));
+            }
         }
     }
 }
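In the rewritten poll loop, the inner loop only gathers input: each successful push_batch() breaks out so the outer check can drain next_completed_batch(), which is how many small repartitioned batches get merged up to the configured batch_size before a single batch is returned. The two end-of-input paths now return flush_remaining_batch() instead of Poll::Ready(None), so a partially filled buffer is emitted rather than dropped.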
@@ -1575,9 +1600,9 @@ mod tests {
     use datafusion_common::exec_err;
     use datafusion_common::test_util::batches_to_sort_string;
     use datafusion_common_runtime::JoinSet;
+    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use insta::assert_snapshot;
-    use itertools::Itertools;

     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {

@@ -1588,7 +1613,7 @@ mod tests {

         // repartition from 1 input to 4 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4), 8).await?;

         assert_eq!(4, output_partitions.len());
         assert_eq!(13, output_partitions[0].len());

@@ -1608,7 +1633,7 @@ mod tests {

         // repartition from 3 input to 1 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1), 8).await?;

         assert_eq!(1, output_partitions.len());
         assert_eq!(150, output_partitions[0].len());

@@ -1625,7 +1650,7 @@ mod tests {

         // repartition from 3 input to 5 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8).await?;

         assert_eq!(5, output_partitions.len());
         assert_eq!(30, output_partitions[0].len());

@@ -1648,6 +1673,7 @@ mod tests {
             &schema,
             partitions,
             Partitioning::Hash(vec![col("c0", &schema)?], 8),
+            8,
         )
         .await?;

@@ -1670,8 +1696,11 @@ mod tests {
         schema: &SchemaRef,
         input_partitions: Vec<Vec<RecordBatch>>,
         partitioning: Partitioning,
+        batch_size: usize,
     ) -> Result<Vec<Vec<RecordBatch>>> {
-        let task_ctx = Arc::new(TaskContext::default());
+        let session_config = SessionConfig::new().with_batch_size(batch_size);
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
         // create physical plan
         let exec =
             TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;

@@ -1702,7 +1731,8 @@ mod tests {
             vec![partition.clone(), partition.clone(), partition.clone()];

         // repartition from 3 input to 5 output
-        repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
+        repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8)
+            .await
     });

     let output_partitions = handle.join().await.unwrap().unwrap();

@@ -1898,7 +1928,9 @@ mod tests {
     // with different compilers, we will compare the same execution with
     // and without dropping the output stream.
     async fn hash_repartition_with_dropping_output_stream() {
-        let task_ctx = Arc::new(TaskContext::default());
+        let session_config = SessionConfig::new().with_batch_size(4);
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
         let partitioning = Partitioning::Hash(
             vec![Arc::new(crate::expressions::Column::new(
                 "my_awesome_field",

@@ -1950,14 +1982,14 @@ mod tests {
         });
         let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

-        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
-            batch
-                .into_iter()
-                .sorted_by_key(|b| format!("{b:?}"))
-                .collect()
-        }
+        let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
+        let items_set_with_drop: HashSet<&str> =
+            items_vec_with_drop.iter().copied().collect();

-        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
+        assert_eq!(
+            items_set_with_drop.symmetric_difference(&items_set).count(),
+            0
+        );
     }

     fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
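The itertools-based sort comparison is replaced with an order-insensitive set comparison (`items_set` is built earlier in the test from the run without the drop). A tiny illustration of the symmetric_difference idiom, with made-up values:

use std::collections::HashSet;

fn main() {
    let without_drop: HashSet<&str> = ["a", "b", "c"].into_iter().collect();
    let with_drop: HashSet<&str> = ["c", "a", "b"].into_iter().collect();
    // Equal sets have an empty symmetric difference, regardless of order.
    assert_eq!(without_drop.symmetric_difference(&with_drop).count(), 0);
}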
@@ -2396,6 +2428,7 @@ mod test {
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::assert_batches_eq;
+    use datafusion_execution::config::SessionConfig;

     use super::*;
     use crate::test::TestMemoryExec;

@@ -2507,8 +2540,10 @@ mod test {
         let runtime = RuntimeEnvBuilder::default()
             .with_memory_limit(64, 1.0)
             .build_arc()?;
-
-        let task_ctx = TaskContext::default().with_runtime(runtime);
+        let session_config = SessionConfig::new().with_batch_size(4);
+        let task_ctx = TaskContext::default()
+            .with_runtime(runtime)
+            .with_session_config(session_config);
         let task_ctx = Arc::new(task_ctx);

         // Create physical plan with order preservation
