
Commit 3dc9515

Emit in chunks
1 parent 0bd127f commit 3dc9515

6 files changed: +341 -35 lines

datafusion/physical-plan/src/aggregates/group_values/mod.rs

Lines changed: 3 additions & 0 deletions
@@ -111,6 +111,9 @@ pub trait GroupValues: Send {
     /// Emits the group values
     fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
 
+    /// Signals that input is complete and drain mode should be activated
+    fn input_done(&mut self) {}
+
     /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage)
     fn clear_shrink(&mut self, batch: &RecordBatch);
 }
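For orientation, the sketch below shows how a caller might drive the new input_done hook together with the existing emit API. The stream-side changes live in other files of this commit and are not shown in this excerpt, so this loop is only an illustrative assumption (the helper name drain_groups is made up), and it assumes the module's existing imports (GroupValues, EmitTo, ArrayRef, Result).

// Hypothetical driver, not part of this commit: once the last input batch has
// been interned, signal completion, then emit group keys in bounded chunks.
fn drain_groups(
    group_values: &mut dyn GroupValues,
    batch_size: usize,
) -> Result<Vec<Vec<ArrayRef>>> {
    group_values.input_done(); // no-op by default; implementations may switch to drain mode
    let mut chunks = Vec::new();
    while group_values.len() > 0 {
        // Each call returns the key columns for at most `batch_size` groups
        chunks.push(group_values.emit(EmitTo::First(batch_size))?);
    }
    Ok(chunks)
}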

datafusion/physical-plan/src/aggregates/group_values/row.rs

Lines changed: 66 additions & 23 deletions
@@ -76,6 +76,13 @@ pub struct GroupValuesRows {
 
     /// Random state for creating hashes
     random_state: RandomState,
+
+    /// State for iterative emission (activated after input is complete)
+    /// When true, emit() uses offset-based slicing instead of copying remaining rows
+    drain_mode: bool,
+
+    /// Current offset for drain mode emission (number of rows already emitted)
+    emission_offset: usize,
 }
 
 impl GroupValuesRows {
@@ -107,11 +114,19 @@ impl GroupValuesRows {
             hashes_buffer: Default::default(),
             rows_buffer,
             random_state: crate::aggregates::AGGREGATION_HASH_SEED,
+            drain_mode: false,
+            emission_offset: 0,
         })
     }
 }
 
 impl GroupValues for GroupValuesRows {
+    fn input_done(&mut self) {
+        self.drain_mode = true;
+        self.map.clear();
+        self.map_size = 0;
+    }
+
     fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
         // Convert the group keys into the row format
         let group_rows = &mut self.rows_buffer;
@@ -185,10 +200,22 @@ impl GroupValues for GroupValuesRows {
         self.len() == 0
     }
 
+    /// Returns the number of group values.
+    ///
+    /// In drain mode (after `input_done()`), returns remaining groups not yet emitted,
+    /// which matches the accumulator state size for consistency.
    fn len(&self) -> usize {
         self.group_values
             .as_ref()
-            .map(|group_values| group_values.num_rows())
+            .map(|group_values| {
+                let total_rows = group_values.num_rows();
+                if self.drain_mode {
+                    // In drain mode, return remaining rows (not yet emitted)
+                    total_rows.saturating_sub(self.emission_offset)
+                } else {
+                    total_rows
+                }
+            })
             .unwrap_or(0)
     }
 
@@ -206,29 +233,43 @@ impl GroupValues for GroupValuesRows {
                 output
             }
             EmitTo::First(n) => {
-                let groups_rows = group_values.iter().take(n);
-                let output = self.row_converter.convert_rows(groups_rows)?;
-                // Clear out first n group keys by copying them to a new Rows.
-                // TODO file some ticket in arrow-rs to make this more efficient?
-                let mut new_group_values = self.row_converter.empty_rows(0, 0);
-                for row in group_values.iter().skip(n) {
-                    new_group_values.push(row);
-                }
-                std::mem::swap(&mut new_group_values, &mut group_values);
-
-                self.map.retain(|(_exists_hash, group_idx)| {
-                    // Decrement group index by n
-                    match group_idx.checked_sub(n) {
-                        // Group index was >= n, shift value down
-                        Some(sub) => {
-                            *group_idx = sub;
-                            true
-                        }
-                        // Group index was < n, so remove from table
-                        None => false,
+                if self.drain_mode {
+                    let start = self.emission_offset;
+                    let end = std::cmp::min(start + n, group_values.num_rows());
+                    let iter = group_values.iter().skip(start).take(end - start);
+                    let output = self.row_converter.convert_rows(iter)?;
+                    self.emission_offset = end;
+                    if self.emission_offset == group_values.num_rows() {
+                        group_values.clear();
+                        self.emission_offset = 0;
                     }
-                });
-                output
+                    output
+                } else {
+                    let groups_rows = group_values.iter().take(n);
+                    let output = self.row_converter.convert_rows(groups_rows)?;
+
+                    // Clear out first n group keys by copying them to a new Rows.
+                    // TODO file some ticket in arrow-rs to make this more efficient?
+                    let mut new_group_values = self.row_converter.empty_rows(0, 0);
+                    for row in group_values.iter().skip(n) {
+                        new_group_values.push(row);
+                    }
+                    std::mem::swap(&mut new_group_values, &mut group_values);
+
+                    self.map.retain(|(_exists_hash, group_idx)| {
+                        // Decrement group index by n
+                        match group_idx.checked_sub(n) {
+                            // Group index was >= n, shift value down
+                            Some(sub) => {
+                                *group_idx = sub;
+                                true
+                            }
+                            // Group index was < n, so remove from table
+                            None => false,
+                        }
+                    });
+                    output
+                }
             }
         };
 
@@ -255,6 +296,8 @@ impl GroupValues for GroupValuesRows {
         self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
         self.hashes_buffer.clear();
         self.hashes_buffer.shrink_to(count);
+        self.drain_mode = false;
+        self.emission_offset = 0;
     }
 }
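To make the offset arithmetic above easier to follow, here is a minimal standalone model of the drain-mode path (plain Rust, no Arrow or DataFusion types; the function name drain_ranges is illustrative). In drain mode each emit(EmitTo::First(n)) converts rows [emission_offset, end) and only advances the offset, instead of rebuilding the remaining rows and rewriting the hash table as the non-drain path does.

// Standalone model of the offset-based drain: report the [start, end) row
// ranges that successive EmitTo::First(n) calls would convert, advancing
// `emission_offset` instead of copying the remaining rows.
fn drain_ranges(total_rows: usize, n: usize) -> Vec<(usize, usize)> {
    let mut ranges = Vec::new();
    let mut emission_offset = 0;
    while emission_offset < total_rows {
        let end = std::cmp::min(emission_offset + n, total_rows);
        ranges.push((emission_offset, end));
        // Remaining groups, as len() reports them in drain mode:
        // total_rows.saturating_sub(end)
        emission_offset = end;
    }
    ranges
}

fn main() {
    // 250 groups drained 100 at a time -> three slices, the last one short
    assert_eq!(drain_ranges(250, 100), vec![(0, 100), (100, 200), (200, 250)]);
    println!("{:?}", drain_ranges(250, 100));
}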

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 225 additions & 3 deletions
@@ -1550,8 +1550,8 @@ mod tests {
     use crate::RecordBatchStream;
 
     use arrow::array::{
-        DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray,
-        UInt32Array, UInt64Array,
+        DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Builder,
+        LargeListBuilder, StringArray, StructArray, UInt32Array, UInt64Array,
     };
     use arrow::compute::{concat_batches, SortOptions};
     use arrow::datatypes::{DataType, Int32Type};
@@ -1572,7 +1572,7 @@
     use datafusion_physical_expr::Partitioning;
     use datafusion_physical_expr::PhysicalSortExpr;
 
-    use futures::{FutureExt, Stream};
+    use futures::{FutureExt, Stream, StreamExt};
     use insta::{allow_duplicates, assert_snapshot};
 
     // Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -3145,4 +3145,226 @@
         run_test_with_spill_pool_if_necessary(20_000, false).await?;
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_chunked_group_emission() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("group_id", DataType::UInt32, false),
+            Field::new("value", DataType::Float64, false),
+        ]));
+
+        let num_groups = 100_000;
+        let group_ids: Vec<u32> = (0..num_groups).collect();
+        let values: Vec<f64> = (0..num_groups).map(|i| i as f64).collect();
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(UInt32Array::from(group_ids)),
+                Arc::new(Float64Array::from(values)),
+            ],
+        )?;
+
+        let input =
+            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
+
+        let group_by = PhysicalGroupBy::new_single(vec![(
+            col("group_id", &schema)?,
+            "group_id".to_string(),
+        )]);
+
+        let aggregates = vec![Arc::new(
+            AggregateExprBuilder::new(count_udaf(), vec![col("value", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias("COUNT(value)")
+                .build()?,
+        )];
+
+        // Use a small batch size to force chunked emission
+        let batch_size = 100;
+        let session_config = SessionConfig::new().with_batch_size(batch_size);
+
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
+
+        let aggregate = Arc::new(AggregateExec::try_new(
+            AggregateMode::Single,
+            group_by,
+            aggregates,
+            vec![None],
+            input,
+            Arc::clone(&schema),
+        )?);
+
+        let mut stream = aggregate.execute(0, task_ctx)?;
+        let mut total_rows = 0;
+        let mut batch_count = 0;
+        let mut max_batch_size = 0;
+
+        // Collect all batches and verify they are chunked
+        while let Some(result) = stream.next().await {
+            let batch = result?;
+            let batch_rows = batch.num_rows();
+            total_rows += batch_rows;
+            batch_count += 1;
+            max_batch_size = max_batch_size.max(batch_rows);
+
+            // Each batch should be <= batch_size (except possibly the last one)
+            assert!(
+                batch_rows <= batch_size || batch_count == 1,
+                "Batch {batch_count} has {batch_rows} rows, expected <= {batch_size}"
+            );
+        }
+
+        // Verify we got all groups
+        assert_eq!(total_rows, num_groups as usize, "Should emit all groups");
+
+        // Verify chunking happened (should have multiple batches)
+        assert!(
+            batch_count > 1,
+            "Expected multiple batches for chunked emission, got {batch_count}"
+        );
+
+        // Verify no single huge batch was emitted
+        assert!(
+            max_batch_size <= batch_size,
+            "Max batch size {max_batch_size} should be <= {batch_size}"
+        );
+
+        Ok(())
+    }
+
+    /// Reproducer for the "long poll" issue in group by aggregations.
+    ///
+    /// This test demonstrates the difference between:
+    /// 1. OLD BEHAVIOR (simulated with very large batch_size): Emits all groups at once,
+    ///    causing a long blocking operation before the first batch is returned
+    /// 2. NEW BEHAVIOR (with small batch_size): Emits groups in chunks, allowing
+    ///    incremental output without blocking the async runtime
+    #[tokio::test]
+    async fn test_long_poll_reproducer() -> Result<()> {
+        use datafusion_common::instant::Instant;
+        use std::time::Duration;
+
+        let num_groups = 1_000_000;
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("group_id", DataType::UInt32, false),
+            Field::new("group_name", DataType::Utf8, false),
+            Field::new(
+                "group_list",
+                DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
+                false,
+            ),
+            Field::new("value", DataType::Float64, false),
+        ]));
+
+        // Generate test data
+        let group_ids: Vec<u32> = (0..num_groups).collect();
+        let group_names: Vec<String> =
+            (0..num_groups).map(|i| format!("group_{i}")).collect();
+
+        let mut list_builder = LargeListBuilder::new(Int64Builder::new());
+        for i in 0..num_groups {
+            list_builder.append_value([Some(i as i64), Some((i + 1) as i64)]);
+        }
+        let group_lists = list_builder.finish();
+        let values: Vec<f64> = (0..num_groups).map(|i| i as f64).collect();
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(UInt32Array::from(group_ids)),
+                Arc::new(StringArray::from(group_names)),
+                Arc::new(group_lists),
+                Arc::new(Float64Array::from(values)),
+            ],
+        )?;
+
+        let group_by = PhysicalGroupBy::new_single(vec![
+            (col("group_id", &schema)?, "group_id".to_string()),
+            (col("group_name", &schema)?, "group_name".to_string()),
+            (col("group_list", &schema)?, "group_list".to_string()),
+        ]);
+
+        let aggregates = vec![Arc::new(
+            AggregateExprBuilder::new(count_udaf(), vec![col("value", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias("COUNT(value)")
+                .build()?,
+        )];
+
+        println!("Testing with {num_groups} groups (UInt32 + String + LargeList keys)");
+
+        // Helper to run the aggregation with a specific batch size
+        // Returns (time_to_first_emission, total_batch_count)
+        let run_scenario = |batch_size: usize| {
+            let schema = Arc::clone(&schema);
+            let batch = batch.clone();
+            let group_by = group_by.clone();
+            let aggregates = aggregates.clone();
+
+            async move {
+                let input = TestMemoryExec::try_new_exec(
+                    &[vec![batch]],
+                    Arc::clone(&schema),
+                    None,
+                )?;
+
+                let session_config = SessionConfig::new().with_batch_size(batch_size);
+                let task_ctx =
+                    Arc::new(TaskContext::default().with_session_config(session_config));
+
+                let aggregate = Arc::new(AggregateExec::try_new(
+                    AggregateMode::Single,
+                    group_by,
+                    aggregates,
+                    vec![None],
+                    input,
+                    schema,
+                )?);
+
+                let mut stream = aggregate.execute(0, task_ctx)?;
+                let start = Instant::now();
+                let mut first_emission = None;
+                let mut batch_count = 0;
+
+                while let Some(result) = stream.next().await {
+                    if first_emission.is_none() {
+                        first_emission = Some(start.elapsed());
+                    }
+                    result?;
+                    batch_count += 1;
+                }
+
+                Ok::<(Duration, usize), DataFusionError>((
+                    first_emission.unwrap_or_default(),
+                    batch_count,
+                ))
+            }
+        };
+
+        // Case 1: Chunked emission (small batch size)
+        let (time_chunked, count_chunked) = run_scenario(1024).await?;
+        println!("Chunked emission (1024): {time_chunked:?} ({count_chunked} batches)");
+
+        // Case 2: Blocking emission (large batch size)
+        let (time_blocking, count_blocking) =
+            run_scenario(num_groups as usize + 1000).await?;
+        println!("Blocking emission (all): {time_blocking:?} ({count_blocking} batches)");
+
+        assert!(
+            count_chunked > 1,
+            "Chunked emission should produce multiple batches"
+        );
+        assert_eq!(
+            count_blocking, 1,
+            "Blocking emission should produce single batch"
+        );
+
+        // Example output:
+        // Testing with 1000000 groups (UInt32 + String + LargeList keys)
+        // Chunked emission (1024): 2.1316265s (977 batches)
+        // Blocking emission (all): 2.815402s (1 batches)
+        Ok(())
+    }
 }
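As a sanity check on the example output above, the batch counts follow from a ceiling division of the group count by the configured batch size, assuming every emitted batch except possibly the last holds exactly batch_size groups. A tiny standalone check (plain Rust, not part of the commit; expected_chunks is an illustrative name):

// Expected chunk counts for the two tests above, by ceiling division.
fn expected_chunks(num_groups: usize, batch_size: usize) -> usize {
    num_groups.div_ceil(batch_size)
}

fn main() {
    // test_chunked_group_emission: 100_000 groups at batch size 100 -> 1000 batches
    assert_eq!(expected_chunks(100_000, 100), 1000);
    // test_long_poll_reproducer: 1_000_000 groups at batch size 1024 -> 977 batches,
    // matching the "977 batches" in the example output
    assert_eq!(expected_chunks(1_000_000, 1024), 977);
}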
