diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
index c80c0b4bf54b..7fa62fb551aa 100644
--- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
@@ -35,7 +35,7 @@ use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
 use datafusion_catalog::Session;
 use datafusion_common::cast::as_primitive_array;
-use datafusion_common::{internal_err, not_impl_err};
+use datafusion_common::{internal_err, not_impl_err, DataFusionError};
 use datafusion_expr::expr::{BinaryExpr, Cast};
 use datafusion_functions_aggregate::expr_fn::count;
 use datafusion_physical_expr::EquivalenceProperties;
@@ -134,10 +134,31 @@ impl ExecutionPlan for CustomPlan {
         _partition: usize,
         _context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        Ok(Box::pin(RecordBatchStreamAdapter::new(
-            self.schema(),
-            futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
-        )))
+        if self.batches.is_empty() {
+            Ok(Box::pin(RecordBatchStreamAdapter::new(
+                self.schema(),
+                futures::stream::empty(),
+            )))
+        } else {
+            let schema_captured = self.schema().clone();
+            Ok(Box::pin(RecordBatchStreamAdapter::new(
+                self.schema(),
+                futures::stream::iter(self.batches.clone().into_iter().map(
+                    move |batch| {
+                        let projection: Vec<usize> = schema_captured
+                            .fields()
+                            .iter()
+                            .filter_map(|field| {
+                                batch.schema().index_of(field.name()).ok()
+                            })
+                            .collect();
+                        batch
+                            .project(&projection)
+                            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
+                    },
+                )),
+            )))
+        }
     }
 
     fn statistics(&self) -> Result<Statistics> {
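The hunk above changes CustomPlan::execute to project each batch onto the plan schema, so a batch whose columns are ordered differently from the schema no longer flows through unchanged. A minimal standalone sketch of that projection technique, assuming arrow-rs (align_to_schema is an illustrative name, not part of the patch):

use arrow::datatypes::Schema;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;

fn align_to_schema(batch: &RecordBatch, target: &Schema) -> ArrowResult<RecordBatch> {
    // Map each target field to its position in the batch, skipping fields
    // the batch does not carry.
    let projection: Vec<usize> = target
        .fields()
        .iter()
        .filter_map(|field| batch.schema().index_of(field.name()).ok())
        .collect();
    // project() reorders (and drops) columns according to the index list.
    batch.project(&projection)
}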
diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index 843d975c7d76..31faca1d8700 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use super::{
     DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
 };
+use crate::coalesce::LimitedBatchCoalescer;
 use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
 use crate::hash_utils::create_hashes;
 use crate::metrics::{BaselineMetrics, SpillMetrics};
@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
                         spill_stream,
                         1, // Each receiver handles one input partition
                         BaselineMetrics::new(&metrics, partition),
+                        context.session_config().batch_size(),
                     )) as SendableRecordBatchStream
                 })
                 .collect::<Vec<_>>();
@@ -959,7 +961,6 @@
                     .into_iter()
                     .next()
                     .expect("at least one spill reader should exist");
-
                 Ok(Box::pin(PerPartitionStream::new(
                     schema_captured,
                     rx.into_iter()
@@ -970,6 +971,7 @@
                     spill_stream,
                     num_input_partitions,
                     BaselineMetrics::new(&metrics, partition),
+                    context.session_config().batch_size(),
                 )) as SendableRecordBatchStream)
             }
         })
@@ -1427,9 +1429,12 @@ struct PerPartitionStream {
 
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
+
+    batch_coalescer: LimitedBatchCoalescer,
 }
 
 impl PerPartitionStream {
+    #[allow(clippy::too_many_arguments)]
     fn new(
         schema: SchemaRef,
         receiver: DistributionReceiver<MaybeBatch>,
@@ -1438,9 +1443,10 @@ impl PerPartitionStream {
         spill_stream: SendableRecordBatchStream,
         num_input_partitions: usize,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Self {
         Self {
-            schema,
+            schema: Arc::clone(&schema),
             receiver,
             _drop_helper: drop_helper,
             reservation,
@@ -1448,6 +1454,18 @@ impl PerPartitionStream {
             state: StreamState::ReadingMemory,
             remaining_partitions: num_input_partitions,
             baseline_metrics,
+            batch_coalescer: LimitedBatchCoalescer::new(schema, batch_size, None),
+        }
+    }
+
+    fn flush_remaining_batch(
+        &mut self,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        // Flush any remaining buffered batch
+        match self.batch_coalescer.finish() {
+            Ok(()) => Poll::Ready(self.batch_coalescer.next_completed_batch().map(Ok)),
+
+            Err(e) => Poll::Ready(Some(Err(e))),
         }
     }
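The PerPartitionStream changes above and the poll_next rewrite below funnel every batch, whether it came from memory or from spill, through the new LimitedBatchCoalescer, returning from poll_next only once a full-sized batch is ready and flushing the remainder when all inputs finish. A rough sketch of that accumulate-then-emit idea in plain arrow-rs; SimpleCoalescer is an illustrative stand-in, not the real crate::coalesce::LimitedBatchCoalescer API:

use arrow::compute::concat_batches;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

struct SimpleCoalescer {
    target_rows: usize,
    buffered: Vec<RecordBatch>,
    buffered_rows: usize,
}

impl SimpleCoalescer {
    // Buffer a batch, emitting one merged batch once enough rows accumulate.
    fn push(&mut self, batch: RecordBatch) -> Result<Option<RecordBatch>, ArrowError> {
        self.buffered_rows += batch.num_rows();
        self.buffered.push(batch);
        if self.buffered_rows >= self.target_rows {
            self.flush()
        } else {
            Ok(None)
        }
    }

    // Emit whatever is buffered, e.g. when all input partitions are done.
    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.buffered.is_empty() {
            return Ok(None);
        }
        let schema = self.buffered[0].schema();
        let merged = concat_batches(&schema, &self.buffered)?;
        self.buffered.clear();
        self.buffered_rows = 0;
        Ok(Some(merged))
    }
}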
@@ -1460,75 +1478,82 @@
         let _timer = cloned_time.timer();
 
         loop {
-            match self.state {
-                StreamState::ReadingMemory => {
-                    // Poll the memory channel for next message
-                    let value = match self.receiver.recv().poll_unpin(cx) {
-                        Poll::Ready(v) => v,
-                        Poll::Pending => {
-                            // Nothing from channel, wait
-                            return Poll::Pending;
-                        }
-                    };
-
-                    match value {
-                        Some(Some(v)) => match v {
-                            Ok(RepartitionBatch::Memory(batch)) => {
-                                // Release memory and return batch
-                                self.reservation
-                                    .lock()
-                                    .shrink(batch.get_array_memory_size());
-                                return Poll::Ready(Some(Ok(batch)));
+            loop {
+                match self.state {
+                    StreamState::ReadingMemory => {
+                        // Poll the memory channel for next message
+                        let value = match self.receiver.recv().poll_unpin(cx) {
+                            Poll::Ready(v) => v,
+                            Poll::Pending => {
+                                // Nothing from channel, wait
+                                return Poll::Pending;
                             }
-                            Ok(RepartitionBatch::Spilled) => {
-                                // Batch was spilled, transition to reading from spill stream
-                                // We must block on spill stream until we get the batch
-                                // to preserve ordering
-                                self.state = StreamState::ReadingSpilled;
+                        };
+
+                        match value {
+                            Some(Some(v)) => match v {
+                                Ok(RepartitionBatch::Memory(batch)) => {
+                                    // Release memory and buffer the batch in the coalescer
+                                    self.reservation
+                                        .lock()
+                                        .shrink(batch.get_array_memory_size());
+                                    self.batch_coalescer.push_batch(batch)?;
+                                    break;
+                                }
+                                Ok(RepartitionBatch::Spilled) => {
+                                    // Batch was spilled, transition to reading from spill stream
+                                    // We must block on spill stream until we get the batch
+                                    // to preserve ordering
+                                    self.state = StreamState::ReadingSpilled;
+                                    continue;
+                                }
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
+                            },
+                            Some(None) => {
+                                // One input partition finished
+                                self.remaining_partitions -= 1;
+                                if self.remaining_partitions == 0 {
+                                    // All input partitions finished
+                                    return self.flush_remaining_batch();
+                                }
+                                // Continue to poll for more data from other partitions
                                 continue;
                             }
-                            Err(e) => {
-                                return Poll::Ready(Some(Err(e)));
+                            None => {
+                                // Channel closed unexpectedly
+                                return self.flush_remaining_batch();
                             }
-                        },
-                        Some(None) => {
-                            // One input partition finished
-                            self.remaining_partitions -= 1;
-                            if self.remaining_partitions == 0 {
-                                // All input partitions finished
-                                return Poll::Ready(None);
-                            }
-                            // Continue to poll for more data from other partitions
-                            continue;
-                        }
-                        None => {
-                            // Channel closed unexpectedly
-                            return Poll::Ready(None);
                         }
                     }
-                }
-                StreamState::ReadingSpilled => {
-                    // Poll spill stream for the spilled batch
-                    match self.spill_stream.poll_next_unpin(cx) {
-                        Poll::Ready(Some(Ok(batch))) => {
-                            self.state = StreamState::ReadingMemory;
-                            return Poll::Ready(Some(Ok(batch)));
-                        }
-                        Poll::Ready(Some(Err(e))) => {
-                            return Poll::Ready(Some(Err(e)));
-                        }
-                        Poll::Ready(None) => {
-                            // Spill stream ended, keep draining the memory channel
-                            self.state = StreamState::ReadingMemory;
-                        }
-                        Poll::Pending => {
-                            // Spilled batch not ready yet, must wait
-                            // This preserves ordering by blocking until spill data arrives
-                            return Poll::Pending;
+                    StreamState::ReadingSpilled => {
+                        // Poll spill stream for the spilled batch
+                        match self.spill_stream.poll_next_unpin(cx) {
+                            Poll::Ready(Some(Ok(batch))) => {
+                                self.state = StreamState::ReadingMemory;
+                                self.batch_coalescer.push_batch(batch)?;
+                                break;
+                            }
+                            Poll::Ready(Some(Err(e))) => {
+                                return Poll::Ready(Some(Err(e)));
+                            }
+                            Poll::Ready(None) => {
+                                // Spill stream ended, keep draining the memory channel
+                                self.state = StreamState::ReadingMemory;
+                            }
+                            Poll::Pending => {
+                                // Spilled batch not ready yet, must wait
+                                // This preserves ordering by blocking until spill data arrives
+                                return Poll::Pending;
+                            }
                         }
                     }
                 }
             }
+            if let Some(batch) = self.batch_coalescer.next_completed_batch() {
+                return Poll::Ready(Some(Ok(batch)));
+            }
         }
     }
 }
@@ -1575,9 +1600,9 @@ mod tests {
     use datafusion_common::exec_err;
     use datafusion_common::test_util::batches_to_sort_string;
     use datafusion_common_runtime::JoinSet;
+    use datafusion_execution::config::SessionConfig;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use insta::assert_snapshot;
-    use itertools::Itertools;
 
     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {
@@ -1588,7 +1613,7 @@
 
         // repartition from 1 input to 4 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4), 8).await?;
 
         assert_eq!(4, output_partitions.len());
         assert_eq!(13, output_partitions[0].len());
@@ -1608,7 +1633,7 @@
 
         // repartition from 3 input to 1 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1), 8).await?;
 
         assert_eq!(1, output_partitions.len());
         assert_eq!(150, output_partitions[0].len());
@@ -1625,7 +1650,7 @@
 
         // repartition from 3 input to 5 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8).await?;
 
         assert_eq!(5, output_partitions.len());
         assert_eq!(30, output_partitions[0].len());
@@ -1648,6 +1673,7 @@
             &schema,
             partitions,
             Partitioning::Hash(vec![col("c0", &schema)?], 8),
+            8,
         )
         .await?;
 
@@ -1670,8 +1696,11 @@
         schema: &SchemaRef,
         input_partitions: Vec<Vec<RecordBatch>>,
         partitioning: Partitioning,
+        batch_size: usize,
     ) -> Result<Vec<Vec<RecordBatch>>> {
-        let task_ctx = Arc::new(TaskContext::default());
+        let session_config = SessionConfig::new().with_batch_size(batch_size);
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
         // create physical plan
         let exec =
             TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
@@ -1702,7 +1731,8 @@
             vec![partition.clone(), partition.clone(), partition.clone()];
 
             // repartition from 3 input to 5 output
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8)
+                .await
         });
 
         let output_partitions = handle.join().await.unwrap().unwrap();
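The test hunks above and below thread an explicit batch size into RepartitionExec through the session config, since output batch boundaries now depend on coalescing rather than on input batch boundaries. The pattern in isolation, using the same datafusion_execution APIs the tests import (task_ctx_with_batch_size is an illustrative helper name):

use std::sync::Arc;

use datafusion_execution::config::SessionConfig;
use datafusion_execution::TaskContext;

fn task_ctx_with_batch_size(batch_size: usize) -> Arc<TaskContext> {
    let session_config = SessionConfig::new().with_batch_size(batch_size);
    Arc::new(TaskContext::default().with_session_config(session_config))
}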
@@ -1898,7 +1928,9 @@
     // with different compilers, we will compare the same execution with
     // and without dropping the output stream.
     async fn hash_repartition_with_dropping_output_stream() {
-        let task_ctx = Arc::new(TaskContext::default());
+        let session_config = SessionConfig::new().with_batch_size(4);
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
         let partitioning = Partitioning::Hash(
             vec![Arc::new(crate::expressions::Column::new(
                 "my_awesome_field",
@@ -1950,14 +1982,14 @@
         });
         let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
 
-        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
-            batch
-                .into_iter()
-                .sorted_by_key(|b| format!("{b:?}"))
-                .collect()
-        }
+        let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
+        let items_set_with_drop: HashSet<&str> =
+            items_vec_with_drop.iter().copied().collect();
 
-        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
+        assert_eq!(
+            items_set_with_drop.symmetric_difference(&items_set).count(),
+            0
+        );
     }
 
     fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
@@ -2396,6 +2428,7 @@
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::assert_batches_eq;
+    use datafusion_execution::config::SessionConfig;
 
     use super::*;
     use crate::test::TestMemoryExec;
@@ -2507,8 +2540,10 @@
         let runtime = RuntimeEnvBuilder::default()
             .with_memory_limit(64, 1.0)
             .build_arc()?;
-
-        let task_ctx = TaskContext::default().with_runtime(runtime);
+        let session_config = SessionConfig::new().with_batch_size(4);
+        let task_ctx = TaskContext::default()
+            .with_runtime(runtime)
+            .with_session_config(session_config);
         let task_ctx = Arc::new(task_ctx);
 
         // Create physical plan with order preservation
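With coalescing, the two executions in the dropped-stream test above can slice the same rows into different batch boundaries, so the test now compares row sets instead of sorted batch debug strings. The comparison in isolation (assert_same_rows is an illustrative helper; the row strings would come from str_batches_to_vec):

use std::collections::HashSet;

fn assert_same_rows(left: &[&str], right: &[&str]) {
    let left_set: HashSet<&str> = left.iter().copied().collect();
    let right_set: HashSet<&str> = right.iter().copied().collect();
    // The symmetric difference is empty iff both runs produced the same rows.
    assert_eq!(left_set.symmetric_difference(&right_set).count(), 0);
}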