From cc01921ce567d7b3aa13634e9540ac80c5b219ec Mon Sep 17 00:00:00 2001 From: Ahmed Mezghani Date: Thu, 20 Nov 2025 18:45:26 +0100 Subject: [PATCH 1/4] Emit in chunks --- .../src/aggregates/group_values/mod.rs | 3 + .../src/aggregates/group_values/row.rs | 89 +++++-- .../physical-plan/src/aggregates/mod.rs | 228 +++++++++++++++++- .../src/aggregates/order/full.rs | 9 +- .../src/aggregates/order/partial.rs | 9 +- .../physical-plan/src/aggregates/row_hash.rs | 39 ++- 6 files changed, 342 insertions(+), 35 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 4bd7f03506a1..d23f1a9b4500 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -111,6 +111,9 @@ pub trait GroupValues: Send { /// Emits the group values fn emit(&mut self, emit_to: EmitTo) -> Result>; + /// Signals that input is complete and drain mode should be activated + fn input_done(&mut self) {} + /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, batch: &RecordBatch); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 34893fcc4ed9..dbf176f4ef25 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -76,6 +76,13 @@ pub struct GroupValuesRows { /// Random state for creating hashes random_state: RandomState, + + /// State for iterative emission (activated after input is complete) + /// When true, emit() uses offset-based slicing instead of copying remaining rows + drain_mode: bool, + + /// Current offset for drain mode emission (number of rows already emitted) + emission_offset: usize, } impl GroupValuesRows { @@ -107,11 +114,19 @@ impl GroupValuesRows { hashes_buffer: Default::default(), rows_buffer, random_state: crate::aggregates::AGGREGATION_HASH_SEED, + drain_mode: false, + emission_offset: 0, }) } } impl GroupValues for GroupValuesRows { + fn input_done(&mut self) { + self.drain_mode = true; + self.map.clear(); + self.map_size = 0; + } + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { // Convert the group keys into the row format let group_rows = &mut self.rows_buffer; @@ -185,10 +200,22 @@ impl GroupValues for GroupValuesRows { self.len() == 0 } + /// Returns the number of group values. + /// + /// In drain mode (after `input_done()`), returns remaining groups not yet emitted, + /// which matches the accumulator state size for consistency. fn len(&self) -> usize { self.group_values .as_ref() - .map(|group_values| group_values.num_rows()) + .map(|group_values| { + let total_rows = group_values.num_rows(); + if self.drain_mode { + // In drain mode, return remaining rows (not yet emitted) + total_rows.saturating_sub(self.emission_offset) + } else { + total_rows + } + }) .unwrap_or(0) } @@ -206,29 +233,43 @@ impl GroupValues for GroupValuesRows { output } EmitTo::First(n) => { - let groups_rows = group_values.iter().take(n); - let output = self.row_converter.convert_rows(groups_rows)?; - // Clear out first n group keys by copying them to a new Rows. - // TODO file some ticket in arrow-rs to make this more efficient? - let mut new_group_values = self.row_converter.empty_rows(0, 0); - for row in group_values.iter().skip(n) { - new_group_values.push(row); - } - std::mem::swap(&mut new_group_values, &mut group_values); - - self.map.retain(|(_exists_hash, group_idx)| { - // Decrement group index by n - match group_idx.checked_sub(n) { - // Group index was >= n, shift value down - Some(sub) => { - *group_idx = sub; - true - } - // Group index was < n, so remove from table - None => false, + if self.drain_mode { + let start = self.emission_offset; + let end = std::cmp::min(start + n, group_values.num_rows()); + let iter = group_values.iter().skip(start).take(end - start); + let output = self.row_converter.convert_rows(iter)?; + self.emission_offset = end; + if self.emission_offset == group_values.num_rows() { + group_values.clear(); + self.emission_offset = 0; } - }); - output + output + } else { + let groups_rows = group_values.iter().take(n); + let output = self.row_converter.convert_rows(groups_rows)?; + + // Clear out first n group keys by copying them to a new Rows. + // TODO file some ticket in arrow-rs to make this more efficient? + let mut new_group_values = self.row_converter.empty_rows(0, 0); + for row in group_values.iter().skip(n) { + new_group_values.push(row); + } + std::mem::swap(&mut new_group_values, &mut group_values); + + self.map.retain(|(_exists_hash, group_idx)| { + // Decrement group index by n + match group_idx.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => { + *group_idx = sub; + true + } + // Group index was < n, so remove from table + None => false, + } + }); + output + } } }; @@ -255,6 +296,8 @@ impl GroupValues for GroupValuesRows { self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); self.hashes_buffer.shrink_to(count); + self.drain_mode = false; + self.emission_offset = 0; } } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 5fafce0bea16..2692fbda3c53 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1550,8 +1550,8 @@ mod tests { use crate::RecordBatchStream; use arrow::array::{ - DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray, - UInt32Array, UInt64Array, + DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Builder, + LargeListBuilder, StringArray, StructArray, UInt32Array, UInt64Array, }; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::{DataType, Int32Type}; @@ -1572,7 +1572,7 @@ mod tests { use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::PhysicalSortExpr; - use futures::{FutureExt, Stream}; + use futures::{FutureExt, Stream, StreamExt}; use insta::{allow_duplicates, assert_snapshot}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -3145,4 +3145,226 @@ mod tests { run_test_with_spill_pool_if_necessary(20_000, false).await?; Ok(()) } + + #[tokio::test] + async fn test_chunked_group_emission() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("group_id", DataType::UInt32, false), + Field::new("value", DataType::Float64, false), + ])); + + let num_groups = 100_000; + let group_ids: Vec = (0..num_groups).collect(); + let values: Vec = (0..num_groups).map(|i| i as f64).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(group_ids)), + Arc::new(Float64Array::from(values)), + ], + )?; + + let input = + TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?; + + let group_by = PhysicalGroupBy::new_single(vec![( + col("group_id", &schema)?, + "group_id".to_string(), + )]); + + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(value)") + .build()?, + )]; + + // Use a small batch size to force chunked emission + let batch_size = 100; + let session_config = SessionConfig::new().with_batch_size(batch_size); + + let task_ctx = + Arc::new(TaskContext::default().with_session_config(session_config)); + + let aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + group_by, + aggregates, + vec![None], + input, + Arc::clone(&schema), + )?); + + let mut stream = aggregate.execute(0, task_ctx)?; + let mut total_rows = 0; + let mut batch_count = 0; + let mut max_batch_size = 0; + + // Collect all batches and verify they are chunked + while let Some(result) = stream.next().await { + let batch = result?; + let batch_rows = batch.num_rows(); + total_rows += batch_rows; + batch_count += 1; + max_batch_size = max_batch_size.max(batch_rows); + + // Each batch should be <= batch_size (except possibly the last one) + assert!( + batch_rows <= batch_size || batch_count == 1, + "Batch {batch_count} has {batch_rows} rows, expected <= {batch_size}" + ); + } + + // Verify we got all groups + assert_eq!(total_rows, num_groups as usize, "Should emit all groups"); + + // Verify chunking happened (should have multiple batches) + assert!( + batch_count > 1, + "Expected multiple batches for chunked emission, got {batch_count}" + ); + + // Verify no single huge batch was emitted + assert!( + max_batch_size <= batch_size, + "Max batch size {max_batch_size} should be <= {batch_size}" + ); + + Ok(()) + } + + /// Reproducer for the "long poll" issue in group by aggregations. + /// + /// This test demonstrates the difference between: + /// 1. OLD BEHAVIOR (simulated with very large batch_size): Emits all groups at once, + /// causing a long blocking operation before the first batch is returned + /// 2. NEW BEHAVIOR (with small batch_size): Emits groups in chunks, allowing + /// incremental output without blocking the async runtime + #[tokio::test] + async fn test_long_poll_reproducer() -> Result<()> { + use datafusion_common::instant::Instant; + use std::time::Duration; + + let num_groups = 1_000_000; + let schema = Arc::new(Schema::new(vec![ + Field::new("group_id", DataType::UInt32, false), + Field::new("group_name", DataType::Utf8, false), + Field::new( + "group_list", + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + Field::new("value", DataType::Float64, false), + ])); + + // Generate test data + let group_ids: Vec = (0..num_groups).collect(); + let group_names: Vec = + (0..num_groups).map(|i| format!("group_{i}")).collect(); + + let mut list_builder = LargeListBuilder::new(Int64Builder::new()); + for i in 0..num_groups { + list_builder.append_value([Some(i as i64), Some((i + 1) as i64)]); + } + let group_lists = list_builder.finish(); + let values: Vec = (0..num_groups).map(|i| i as f64).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(group_ids)), + Arc::new(StringArray::from(group_names)), + Arc::new(group_lists), + Arc::new(Float64Array::from(values)), + ], + )?; + + let group_by = PhysicalGroupBy::new_single(vec![ + (col("group_id", &schema)?, "group_id".to_string()), + (col("group_name", &schema)?, "group_name".to_string()), + (col("group_list", &schema)?, "group_list".to_string()), + ]); + + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(value)") + .build()?, + )]; + + println!("Testing with {num_groups} groups (UInt32 + String + LargeList keys)"); + + // Helper to run the aggregation with a specific batch size + // Returns (time_to_first_emission, total_batch_count) + let run_scenario = |batch_size: usize| { + let schema = Arc::clone(&schema); + let batch = batch.clone(); + let group_by = group_by.clone(); + let aggregates = aggregates.clone(); + + async move { + let input = TestMemoryExec::try_new_exec( + &[vec![batch]], + Arc::clone(&schema), + None, + )?; + + let session_config = SessionConfig::new().with_batch_size(batch_size); + let task_ctx = + Arc::new(TaskContext::default().with_session_config(session_config)); + + let aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + group_by, + aggregates, + vec![None], + input, + schema, + )?); + + let mut stream = aggregate.execute(0, task_ctx)?; + let start = Instant::now(); + let mut first_emission = None; + let mut batch_count = 0; + + while let Some(result) = stream.next().await { + if first_emission.is_none() { + first_emission = Some(start.elapsed()); + } + result?; + batch_count += 1; + } + + Ok::<(Duration, usize), DataFusionError>(( + first_emission.unwrap_or_default(), + batch_count, + )) + } + }; + + // Case 1: Chunked emission (small batch size) + let (time_chunked, count_chunked) = run_scenario(1024).await?; + println!("Chunked emission (1024): {time_chunked:?} ({count_chunked} batches)"); + + // Case 2: Blocking emission (large batch size) + let (time_blocking, count_blocking) = + run_scenario(num_groups as usize + 1000).await?; + println!("Blocking emission (all): {time_blocking:?} ({count_blocking} batches)"); + + assert!( + count_chunked > 1, + "Chunked emission should produce multiple batches" + ); + assert_eq!( + count_blocking, 1, + "Blocking emission should produce single batch" + ); + + // Example output: + // Testing with 1000000 groups (UInt32 + String + LargeList keys) + // Chunked emission (1024): 2.1316265s (977 batches) + // Blocking emission (all): 2.815402s (1 batches) + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/order/full.rs b/datafusion/physical-plan/src/aggregates/order/full.rs index eb98611f79df..ecfd1e1ad860 100644 --- a/datafusion/physical-plan/src/aggregates/order/full.rs +++ b/datafusion/physical-plan/src/aggregates/order/full.rs @@ -106,7 +106,14 @@ impl GroupOrderingFull { assert!(*current >= n); *current -= n; } - State::Complete => panic!("invalid state: complete"), + State::Complete => { + // When input is complete, we're in "drain mode" where groups are being + // emitted iteratively without receiving new input. In this state: + // - No new groups will be added (input_done() was called) + // - We don't need to track group indices anymore + // - remove_groups() is called as part of emit(EmitTo::First(n)) but has + // no work to do since we're just draining accumulated groups + } } } diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index 476551a7ca21..53172ec6fbc4 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -174,7 +174,14 @@ impl GroupOrderingPartial { assert!(*current_sort >= n); *current_sort -= n; } - State::Complete => panic!("invalid state: complete"), + State::Complete => { + // When input is complete, we're in "drain mode" where groups are being + // emitted iteratively without receiving new input. In this state: + // - No new groups will be added (input_done() was called) + // - We don't need to track group indices anymore + // - remove_groups() is called as part of emit(EmitTo::First(n)) but has + // no work to do since we're just draining accumulated groups + } } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 2e1b70da284d..9b1bff01cbb7 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -64,6 +64,10 @@ pub(crate) enum ExecutionState { /// When producing output, the remaining rows to output are stored /// here and are sliced off as needed in batch_size chunks ProducingOutput(RecordBatch), + /// Iteratively emitting groups from group_values in batch_size chunks. + /// This state is used after input is complete to avoid emitting all groups + /// in a single large batch that would block the async runtime. + DrainingGroups, /// Produce intermediate aggregate state for each input row without /// aggregation. /// @@ -793,14 +797,36 @@ impl Stream for GroupedHashAggregateStream { } } + ExecutionState::DrainingGroups => { + let size = self.batch_size; + let remaining_groups = self.group_values.len(); + let emit_count = size.min(remaining_groups); + match self.emit(EmitTo::First(emit_count), false)? { + Some(batch) => { + if let Some(reduction_factor) = self.reduction_factor.as_ref() + { + reduction_factor.add_part(batch.num_rows()); + } + + return Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))); + } + None => { + self.exec_state = ExecutionState::Done; + continue; + } + } + } + ExecutionState::ProducingOutput(batch) => { - // slice off a part of the batch, if needed - let output_batch; let size = self.batch_size; + let output_batch; (self.exec_state, output_batch) = if batch.num_rows() <= size { ( if self.input_done { - ExecutionState::Done + // All groups consumed, switch to drain mode + ExecutionState::DrainingGroups } // In Partial aggregation, we also need to check // if we should trigger partial skipping @@ -814,8 +840,7 @@ impl Stream for GroupedHashAggregateStream { batch.clone(), ) } else { - // output first batch_size rows - let size = self.batch_size; + // Slice off first batch_size rows let num_remaining = batch.num_rows() - size; let remaining = batch.slice(size, num_remaining); let output = batch.slice(0, size); @@ -1164,11 +1189,11 @@ impl GroupedHashAggregateStream { fn set_input_done_and_produce_output(&mut self) -> Result<()> { self.input_done = true; self.group_ordering.input_done(); + self.group_values.input_done(); let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); self.exec_state = if self.spill_state.spills.is_empty() { - let batch = self.emit(EmitTo::All, false)?; - batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput) + ExecutionState::DrainingGroups } else { // If spill files exist, stream-merge them. self.update_merged_stream()?; From 149c8bab044737dfbea5efe790c9cbb10594889d Mon Sep 17 00:00:00 2001 From: Ahmed Mezghani Date: Mon, 24 Nov 2025 16:38:47 +0100 Subject: [PATCH 2/4] use range-based access for rows --- datafusion/physical-plan/src/aggregates/group_values/row.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index dbf176f4ef25..3278e2af14df 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -236,7 +236,7 @@ impl GroupValues for GroupValuesRows { if self.drain_mode { let start = self.emission_offset; let end = std::cmp::min(start + n, group_values.num_rows()); - let iter = group_values.iter().skip(start).take(end - start); + let iter = (start..end).map(|i| group_values.row(i)); let output = self.row_converter.convert_rows(iter)?; self.emission_offset = end; if self.emission_offset == group_values.num_rows() { From b5d2072c1091e5dbd53310c903508a61a55e190e Mon Sep 17 00:00:00 2001 From: Ahmed Mezghani Date: Mon, 24 Nov 2025 16:39:09 +0100 Subject: [PATCH 3/4] augment test with stats --- .../physical-plan/src/aggregates/mod.rs | 197 +++++++++++++----- 1 file changed, 145 insertions(+), 52 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2692fbda3c53..12e6760ffc3b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -3295,62 +3295,133 @@ mod tests { println!("Testing with {num_groups} groups (UInt32 + String + LargeList keys)"); - // Helper to run the aggregation with a specific batch size - // Returns (time_to_first_emission, total_batch_count) - let run_scenario = |batch_size: usize| { - let schema = Arc::clone(&schema); - let batch = batch.clone(); - let group_by = group_by.clone(); - let aggregates = aggregates.clone(); - - async move { - let input = TestMemoryExec::try_new_exec( - &[vec![batch]], - Arc::clone(&schema), - None, - )?; + // Case 1: Chunked emission with detailed timing + println!("\n=== Chunked emission (batch_size=8192) ==="); + let input_chunked = TestMemoryExec::try_new_exec( + &[vec![batch.clone()]], + Arc::clone(&schema), + None, + )?; + let session_config_chunked = SessionConfig::new().with_batch_size(8192); + let task_ctx_chunked = + Arc::new(TaskContext::default().with_session_config(session_config_chunked)); + let aggregate_chunked = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + group_by.clone(), + aggregates.clone(), + vec![None], + input_chunked, + Arc::clone(&schema), + )?); - let session_config = SessionConfig::new().with_batch_size(batch_size); - let task_ctx = - Arc::new(TaskContext::default().with_session_config(session_config)); + let mut stream_chunked = aggregate_chunked.execute(0, task_ctx_chunked)?; + let start = Instant::now(); + let mut first_emission = None; + let mut batch_count = 0; + let mut prev_batch_time = start; + let mut poll_times_chunked = Vec::new(); - let aggregate = Arc::new(AggregateExec::try_new( - AggregateMode::Single, - group_by, - aggregates, - vec![None], - input, - schema, - )?); + while let Some(result) = stream_chunked.next().await { + let batch = result?; + batch_count += 1; + let batch_time = prev_batch_time.elapsed(); + poll_times_chunked.push(batch_time); + + if first_emission.is_none() { + first_emission = Some(start.elapsed()); + println!( + "First batch arrived at: {:?} ({} rows)", + start.elapsed(), + batch.num_rows() + ); + } - let mut stream = aggregate.execute(0, task_ctx)?; - let start = Instant::now(); - let mut first_emission = None; - let mut batch_count = 0; + prev_batch_time = Instant::now(); + } - while let Some(result) = stream.next().await { - if first_emission.is_none() { - first_emission = Some(start.elapsed()); - } - result?; - batch_count += 1; - } + let count_chunked = batch_count; + let total_chunked = start.elapsed(); + let min_poll_chunked = + poll_times_chunked.iter().min().copied().unwrap_or_default(); + let max_poll_chunked = + poll_times_chunked.iter().max().copied().unwrap_or_default(); + let avg_poll_chunked: Duration = + poll_times_chunked.iter().sum::() / poll_times_chunked.len() as u32; + + println!("Total batches: {count_chunked}"); + println!("Total time: {total_chunked:?}"); + println!("Poll times: min={min_poll_chunked:?}, max={max_poll_chunked:?}, avg={avg_poll_chunked:?}"); + + // Case 2: Blocking emission with detailed timing + println!("\n=== Blocking emission (batch_size > num_groups) ==="); + let input_blocking = TestMemoryExec::try_new_exec( + &[vec![batch.clone()]], + Arc::clone(&schema), + None, + )?; + let session_config_blocking = + SessionConfig::new().with_batch_size(num_groups as usize + 1000); + let task_ctx_blocking = + Arc::new(TaskContext::default().with_session_config(session_config_blocking)); + let aggregate_blocking = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + group_by, + aggregates, + vec![None], + input_blocking, + Arc::clone(&schema), + )?); - Ok::<(Duration, usize), DataFusionError>(( - first_emission.unwrap_or_default(), - batch_count, - )) - } - }; + let mut stream_blocking = aggregate_blocking.execute(0, task_ctx_blocking)?; + let start = Instant::now(); + let mut count_blocking = 0; + let mut prev_batch_time = start; + let mut poll_times_blocking = Vec::new(); + + while let Some(result) = stream_blocking.next().await { + let batch = result?; + count_blocking += 1; + let batch_time = prev_batch_time.elapsed(); + poll_times_blocking.push(batch_time); + println!(" Batch {count_blocking} arrived at: {:?} ({} rows, batch creation took {batch_time:?})", start.elapsed(), batch.num_rows()); + prev_batch_time = Instant::now(); + } + + let time_blocking = start.elapsed(); + let min_poll_blocking = poll_times_blocking + .iter() + .min() + .copied() + .unwrap_or_default(); + let max_poll_blocking = poll_times_blocking + .iter() + .max() + .copied() + .unwrap_or_default(); + let avg_poll_blocking: Duration = poll_times_blocking.iter().sum::() + / poll_times_blocking.len() as u32; + + println!("Total time: {time_blocking:?}"); + println!("Poll times: min={min_poll_blocking:?}, max={max_poll_blocking:?}, avg={avg_poll_blocking:?}"); + + println!("\n=== Summary ==="); + println!("Total execution time:"); + println!(" Chunked: {total_chunked:?}"); + println!(" Blocking: {time_blocking:?}"); + println!( + " Overhead: {:.2}x with chunked", + total_chunked.as_secs_f64() / time_blocking.as_secs_f64() + ); + + println!("\nPoll duration (time between batches):"); + println!(" Chunked: min={min_poll_chunked:?}, max={max_poll_chunked:?}, avg={avg_poll_chunked:?}"); + println!(" Blocking: min={min_poll_blocking:?}, max={max_poll_blocking:?}, avg={avg_poll_blocking:?}"); - // Case 1: Chunked emission (small batch size) - let (time_chunked, count_chunked) = run_scenario(1024).await?; - println!("Chunked emission (1024): {time_chunked:?} ({count_chunked} batches)"); + println!("\nYield behavior:"); + println!(" Chunked: {count_chunked} batches (yields between each)"); + println!(" Blocking: {count_blocking} batch (single long stall)"); - // Case 2: Blocking emission (large batch size) - let (time_blocking, count_blocking) = - run_scenario(num_groups as usize + 1000).await?; - println!("Blocking emission (all): {time_blocking:?} ({count_blocking} batches)"); + println!("Benefit: max poll reduced from {max_poll_blocking:?} to {max_poll_chunked:?}."); assert!( count_chunked > 1, @@ -3362,9 +3433,31 @@ mod tests { ); // Example output: - // Testing with 1000000 groups (UInt32 + String + LargeList keys) - // Chunked emission (1024): 2.1316265s (977 batches) - // Blocking emission (all): 2.815402s (1 batches) + // === Chunked emission (batch_size=8192) === + // First batch arrived at: 2.210163709s (8192 rows) + // Total batches: 123 + // Total time: 2.869591125s + // Poll times: min=369.209µs, max=2.210162417s, avg=23.324541ms + + // === Blocking emission (batch_size > num_groups) === + // Batch 1 arrived at: 2.877907958s (1000000 rows, batch creation took 2.877906208s) + // Total time: 2.8790405s + // Poll times: min=2.877906208s, max=2.877906208s, avg=2.877906208s + + // === Summary === + // Total execution time: + // Chunked: 2.869591125s + // Blocking: 2.8790405s + // Overhead: 1.00x with chunked + + // Poll duration (time between batches): + // Chunked: min=369.209µs, max=2.210162417s, avg=23.324541ms + // Blocking: min=2.877906208s, max=2.877906208s, avg=2.877906208s + + // Yield behavior: + // Chunked: 123 batches (yields between each) + // Blocking: 1 batch (single long stall) + // Benefit: max poll reduced from 2.877906208s to 2.210162417s. Ok(()) } } From c4903b6f46ec45c0a55c3bdc71f6c8cc4f95728d Mon Sep 17 00:00:00 2001 From: Ahmed Mezghani Date: Tue, 25 Nov 2025 10:32:32 +0100 Subject: [PATCH 4/4] add test_sorted_input --- .../physical-plan/src/aggregates/mod.rs | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 12e6760ffc3b..3cf49ac565dc 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1549,6 +1549,7 @@ mod tests { use crate::test::TestMemoryExec; use crate::RecordBatchStream; + use crate::sorts::sort::SortExec; use arrow::array::{ DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Builder, LargeListBuilder, StringArray, StructArray, UInt32Array, UInt64Array, @@ -1571,6 +1572,7 @@ mod tests { use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::sort_expr::LexOrdering; use futures::{FutureExt, Stream, StreamExt}; use insta::{allow_duplicates, assert_snapshot}; @@ -3460,4 +3462,112 @@ mod tests { // Benefit: max poll reduced from 2.877906208s to 2.210162417s. Ok(()) } + + #[tokio::test] + async fn test_sorted_input() -> Result<()> { + // This test triggers emission with drain_mode=false by using sorted input. + let schema = Arc::new(Schema::new(vec![ + Field::new("group_id", DataType::UInt32, false), + Field::new( + "group_list", + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + Field::new("value", DataType::Float64, false), + ])); + + // Create a single large batch with group boundaries within it + // GroupOrdering::Full should detect boundaries and emit incrementally + let num_rows = 100; + let group_ids: Vec = (0..num_rows).map(|i| i / 5).collect(); // 20 groups, 5 rows each + + let mut list_builder = LargeListBuilder::new(Int64Builder::new()); + for i in 0..num_rows { + list_builder.append_value([Some(i as i64), Some((i + 1) as i64)]); + } + let group_lists = list_builder.finish(); + + let values: Vec = (0..num_rows).map(|i| i as f64).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(group_ids)), + Arc::new(group_lists), + Arc::new(Float64Array::from(values)), + ], + ) + .expect("Failed to create batch"); + + let batches = vec![batch]; + + let sort_expr = [PhysicalSortExpr { + expr: col("group_id", &schema).expect("Failed to create column"), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let ordering: LexOrdering = sort_expr.into(); + + let memory_input = + TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None) + .expect("Failed to create input"); + + let sorted_input: Arc = + Arc::new(SortExec::new(ordering, memory_input)); + + let group_by = PhysicalGroupBy::new_single(vec![ + ( + col("group_id", &schema).expect("Failed to create column"), + "group_id".to_string(), + ), + ( + col("group_list", &schema).expect("Failed to create column"), + "group_list".to_string(), + ), + ]); + + let aggregates = vec![Arc::new( + AggregateExprBuilder::new( + count_udaf(), + vec![col("value", &schema).expect("Failed to create column")], + ) + .schema(Arc::clone(&schema)) + .alias("COUNT(value)") + .build() + .expect("Failed to build aggregate"), + )]; + + let task_ctx = Arc::new(TaskContext::default()); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by, + aggregates, + vec![None], + sorted_input, + Arc::clone(&schema), + ) + .expect("Failed to create aggregate"), + ); + + let mut stream = aggregate + .execute(0, task_ctx) + .expect("Failed to execute aggregate"); + + let mut batch_count = 0; + let mut total_rows = 0; + while let Some(result) = stream.next().await { + let batch = result?; + batch_count += 1; + total_rows += batch.num_rows(); + } + + assert!(batch_count > 0, "Should have at least one batch"); + assert!(total_rows > 0, "Should have at least some rows"); + + Ok(()) + } }