From 4ddd519b296708553b9bd4f83d62e5af83c3b1b6 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Mon, 29 Jun 2026 17:14:43 +0800 Subject: [PATCH] refactor(hash-aggr): Migrate partial-reduce hash aggregation --- .../aggregates/aggregate_hash_table/common.rs | 23 +- .../aggregate_hash_table/final_table.rs | 12 +- .../aggregates/aggregate_hash_table/mod.rs | 4 +- .../partial_reduce_table.rs | 149 +++++++ .../aggregate_hash_table/partial_table.rs | 6 +- .../physical-plan/src/aggregates/mod.rs | 20 + .../src/aggregates/partial_reduce_stream.rs | 385 ++++++++++++++++++ 7 files changed, 578 insertions(+), 21 deletions(-) create mode 100644 datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_reduce_table.rs create mode 100644 datafusion/physical-plan/src/aggregates/partial_reduce_stream.rs diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 719fbe93e5416..0b44467405b6f 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -36,6 +36,8 @@ use crate::aggregates::{ /// Marker for raw rows -> partial state aggregation. pub(in crate::aggregates) struct PartialMarker; +/// Marker for partial state -> partial state aggregation. +pub(in crate::aggregates) struct PartialReduceMarker; /// Marker for raw rows -> partial state conversion without aggregation. pub(in crate::aggregates) struct PartialSkipMarker; /// Marker for partial state -> final value aggregation. @@ -182,7 +184,7 @@ impl AggregateHashTable { acc + state.group_values.size() + state.batch_group_indices.allocated_size() } - AggregateHashTableState::OutputtingMaterializedFinal(output) => { + AggregateHashTableState::OutputtingMaterialized(output) => { output.memory_size() } AggregateHashTableState::Done => 0, @@ -304,24 +306,25 @@ pub(super) enum AggregateHashTableState { Building(AggregateHashTableBuffer), /// Emitting results directly from group keys and aggregate state. Outputting(AggregateHashTableBuffer), - /// Materialize all the output results, and then incrementally output in the `OutputtingMaterializedFinal` state. + /// Materialize all output rows, then incrementally output slices from the + /// `OutputtingMaterialized` state. /// /// Note this is a temporary solution until the `GroupValues` issue is solved: /// Issue: - OutputtingMaterializedFinal(MaterializedFinalOutput), + OutputtingMaterialized(MaterializedOutput), Done, } -/// Fully evaluated final aggregate output and the next row offset to emit. +/// Fully materialized aggregate output and the next row offset to emit. /// -/// Final aggregate evaluation consumes accumulator state, so final output is -/// materialized once and then sliced to honor `batch_size` across output polls. -pub(super) struct MaterializedFinalOutput { +/// Some output paths consume accumulator state when emitting. Materialize those +/// rows once, then slice them to honor `batch_size` across output polls. +pub(super) struct MaterializedOutput { batch: RecordBatch, offset: usize, } -impl MaterializedFinalOutput { +impl MaterializedOutput { pub(super) fn new(batch: RecordBatch) -> Self { Self { batch, offset: 0 } } @@ -496,7 +499,7 @@ mod tests { use super::*; #[test] - fn materialized_final_output_slices_batches_until_exhausted() -> Result<()> { + fn materialized_output_slices_batches_until_exhausted() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new( "group_col", DataType::Int32, @@ -506,7 +509,7 @@ mod tests { schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], )?; - let mut output = MaterializedFinalOutput::new(batch); + let mut output = MaterializedOutput::new(batch); assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![1, 2]); assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![3, 4]); diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs index c3e4f831c4bbf..d6c8086bec15d 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs @@ -26,7 +26,7 @@ use crate::aggregates::AggregateExec; use super::common::{ AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, FinalMarker, - MaterializedFinalOutput, + MaterializedOutput, }; /// Methods specific to the aggregate hash table used in the final aggregation stage. @@ -68,7 +68,7 @@ impl AggregateHashTable { let output = self.materialize_final_output(state, output_schema)?; Ok(self.emit_next_materialized_batch(output, batch_size)) } - AggregateHashTableState::OutputtingMaterializedFinal(output) => { + AggregateHashTableState::OutputtingMaterialized(output) => { Ok(self.emit_next_materialized_batch(output, batch_size)) } AggregateHashTableState::Done => Ok(None), @@ -82,7 +82,7 @@ impl AggregateHashTable { &self, mut state: AggregateHashTableBuffer, output_schema: SchemaRef, - ) -> Result { + ) -> Result { // Final aggregate evaluation consumes accumulator state. Evaluate all // groups once, then slice the materialized batch on subsequent polls. let emit_to = EmitTo::All; @@ -96,19 +96,19 @@ impl AggregateHashTable { let batch = RecordBatch::try_new(output_schema, output)?; debug_assert!(batch.num_rows() > 0); - Ok(MaterializedFinalOutput::new(batch)) + Ok(MaterializedOutput::new(batch)) } fn emit_next_materialized_batch( &mut self, - mut output: MaterializedFinalOutput, + mut output: MaterializedOutput, batch_size: usize, ) -> Option { let batch = output.next_batch(batch_size); if output.is_exhausted() { self.state = AggregateHashTableState::Done; } else { - self.state = AggregateHashTableState::OutputtingMaterializedFinal(output); + self.state = AggregateHashTableState::OutputtingMaterialized(output); } batch } diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs index eb152f4128896..61dec1a5b46db 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs @@ -17,8 +17,10 @@ mod common; mod final_table; +mod partial_reduce_table; mod partial_table; pub(super) use common::{ - AggregateHashTable, FinalMarker, PartialMarker, PartialSkipMarker, + AggregateHashTable, FinalMarker, PartialMarker, PartialReduceMarker, + PartialSkipMarker, }; diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_reduce_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_reduce_table.rs new file mode 100644 index 0000000000000..f8d1c9da2557e --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_reduce_table.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::EmitTo; + +use crate::aggregates::AggregateExec; + +use super::common::{ + AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, + MaterializedOutput, PartialReduceMarker, +}; + +/// Methods specific to the aggregate hash table used in the partial-reduce stage. +impl AggregateHashTable { + pub(in crate::aggregates) fn new( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + ) -> Result { + Self::new_with_filters( + agg, + partition, + output_schema, + batch_size, + vec![None; agg.aggr_expr.len()], + ) + } + + /// Emits the next batch of aggregated group keys and aggregate states. + /// + /// The output batch size is determined by `self.batch_size`. + /// + /// Returns `Some(batch)` for each emitted batch, `None` when output is + /// exhausted, and an internal error if polled in the `Building` state. + pub(in crate::aggregates) fn next_output_batch( + &mut self, + ) -> Result> { + let output_schema = Arc::clone(&self.output_schema); + let batch_size = self.batch_size; + // Take ownership of the output state. Note `emit_next_materialized_batch` + // updates state after it emits a materialized slice. + match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { + AggregateHashTableState::Outputting(state) => { + if state.group_values.is_empty() { + return Ok(None); + } + + let output = + self.materialize_partial_reduce_output(state, output_schema)?; + Ok(self.emit_next_materialized_batch(output, batch_size)) + } + AggregateHashTableState::OutputtingMaterialized(output) => { + Ok(self.emit_next_materialized_batch(output, batch_size)) + } + AggregateHashTableState::Done => Ok(None), + AggregateHashTableState::Building(_) => { + internal_err!("next_output_batch must be called in the outputting state") + } + } + } + + fn materialize_partial_reduce_output( + &self, + mut state: AggregateHashTableBuffer, + output_schema: SchemaRef, + ) -> Result { + // `state(EmitTo::All)` consumes accumulator state. Emit all groups once, + // then slice the materialized batch on subsequent polls. + let emit_to_all = EmitTo::All; + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(emit_to_all)?; + + for acc in state.accumulators.iter_mut() { + output.extend(acc.state(emit_to_all)?); + } + drop(timer); + + let batch = RecordBatch::try_new(output_schema, output)?; + debug_assert!(batch.num_rows() > 0); + Ok(MaterializedOutput::new(batch)) + } + + fn emit_next_materialized_batch( + &mut self, + mut output: MaterializedOutput, + batch_size: usize, + ) -> Option { + let batch = output.next_batch(batch_size); + if output.is_exhausted() { + self.state = AggregateHashTableState::Done; + } else { + self.state = AggregateHashTableState::OutputtingMaterialized(output); + } + batch + } + + pub(in crate::aggregates) fn aggregate_batch( + &mut self, + batch: &RecordBatch, + ) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + let state = self.state.building_mut(); + + let timer = self.group_by_metrics.aggregation_time.timer(); + for group_values in &evaluated_batch.grouping_set_args { + state + .group_values + .intern(group_values, &mut state.batch_group_indices)?; + let group_indices = &state.batch_group_indices; + let total_num_groups = state.group_values.len(); + + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.merge_batch(values, group_indices, total_num_groups)?; + } + } + drop(timer); + + Ok(()) + } + + pub(in crate::aggregates) fn start_output(&mut self) -> Result<()> { + self.start_outputting(); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 9d226aa28b35f..6c9f44f4fbeee 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -91,10 +91,8 @@ impl AggregateHashTable { AggregateHashTableState::Building(_) => { internal_err!("next_output_batch must be called in the outputting state") } - AggregateHashTableState::OutputtingMaterializedFinal(_) => { - internal_err!( - "partial aggregate output should not materialize final output" - ) + AggregateHashTableState::OutputtingMaterialized(_) => { + internal_err!("partial aggregate output should not materialize output") } } } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 4f5b893578d74..0fc629924530d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -24,6 +24,7 @@ use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::aggregates::{ hash_aggregate::{FinalHashAggregateStream, PartialHashAggregateStream}, no_grouping::AggregateStream, + partial_reduce_stream::PartialReduceHashAggregateStream, row_hash::GroupedHashAggregateStream, topk_stream::GroupedTopKAggregateStream, }; @@ -77,6 +78,7 @@ pub mod group_values; mod hash_aggregate; mod no_grouping; pub mod order; +mod partial_reduce_stream; mod row_hash; mod skip_partial; mod topk; @@ -527,6 +529,9 @@ enum StreamType { /// Partial stage of the hash aggregation /// Input output scheme: initial input -> partial state PartialHash(PartialHashAggregateStream), + /// Partial-reduce stage of the hash aggregation + /// Input output scheme: partial state -> partial state + PartialReduceHash(PartialReduceHashAggregateStream), /// Final stage of the hash aggregation /// Input output scheme: partial state -> final result FinalHash(FinalHashAggregateStream), @@ -550,6 +555,7 @@ impl From for SendableRecordBatchStream { match stream { StreamType::AggregateStream(stream) => Box::pin(stream), StreamType::PartialHash(stream) => Box::pin(stream), + StreamType::PartialReduceHash(stream) => Box::pin(stream), StreamType::FinalHash(stream) => Box::pin(stream), StreamType::GroupedHash(stream) => Box::pin(stream), StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), @@ -1028,6 +1034,12 @@ impl AggregateExec { )?)); } + if self.should_use_partial_reduce_hash_stream() { + return Ok(StreamType::PartialReduceHash( + PartialReduceHashAggregateStream::new(self, context, partition)?, + )); + } + if self.should_use_final_hash_stream(context) { return Ok(StreamType::FinalHash(FinalHashAggregateStream::new( self, context, partition, @@ -1069,6 +1081,14 @@ impl AggregateExec { && self.group_by.is_single() } + fn should_use_partial_reduce_hash_stream(&self) -> bool { + self.mode == AggregateMode::PartialReduce + && self.limit_options.is_none() + && self.input_order_mode == InputOrderMode::Linear + && !self.group_by.is_true_no_grouping() + && self.group_by.is_single() + } + /// See comments in `PartialHashAggregateStream` limit optimization section fn limit_options_supported_by_hash_stream(&self) -> bool { self.limit_options.is_none() || self.is_unordered_unfiltered_group_by_distinct() diff --git a/datafusion/physical-plan/src/aggregates/partial_reduce_stream.rs b/datafusion/physical-plan/src/aggregates/partial_reduce_stream.rs new file mode 100644 index 0000000000000..1a4980c89851a --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/partial_reduce_stream.rs @@ -0,0 +1,385 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Partial-reduce hash aggregation stream implementation. +//! +//! This stream is part of the incremental migration from +//! [`crate::aggregates::row_hash::GroupedHashAggregateStream`]. +//! +//! See issue for details: + +use std::ops::ControlFlow; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use futures::stream::{Stream, StreamExt}; + +use super::AggregateExec; +use super::aggregate_hash_table::{AggregateHashTable, PartialReduceMarker}; +use crate::metrics::{BaselineMetrics, RecordOutput, SpillMetrics}; +use crate::stream::EmptyRecordBatchStream; +use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream}; + +/// Hash aggregation can combine multiple partial stages before final +/// evaluation. This stream implements the partial-reduce stage. +/// +/// # Example +/// +/// SELECT k, AVG(v) FROM t GROUP BY k; +/// +/// ## Plan +/// AggregateExec(stage=final) +/// -- RepartitionExec(hash(k)) +/// ---- AggregateExec(stage=partial_reduce) +/// ------ RepartitionExec(hash(k)) +/// -------- AggregateExec(stage=partial) +/// +/// Note: the example plan is only intended to demonstrate this stream's semantics; +/// the default DataFusion SQL planner does not produce plans in this shape. +/// +/// This stream implements the middle partial-reduce aggregation in the plan above. +/// +/// The motivation is to reduce shuffling traffic in a distributed setting. See +/// +/// +/// ## Partial-Reduce Stage Behavior +/// Input: partial aggregate state rows +/// Output: merged partial aggregate state rows +/// +/// This stage is useful for tree-reduce plans. It consumes the same schema as +/// a final aggregate stage, but emits the same schema as a partial aggregate +/// stage. +pub(crate) struct PartialReduceHashAggregateStream { + /// Output schema: group columns followed by partial aggregate state columns. + schema: SchemaRef, + + /// Input batches containing partial aggregate state rows. + input: SendableRecordBatchStream, + + /// Execution metrics shared with the aggregate plan node. + baseline_metrics: BaselineMetrics, + + /// Memory reservation for group keys and accumulators. + reservation: MemoryReservation, + + /// Tracks the high-level stream lifecycle. The hash table owns the lower-level + /// state for emitting output batches. + state: Option, +} + +/// States for partial-reduce hash aggregation processing. +// The typestate pattern mirrors the final stream and keeps the input/output +// semantics explicit for this mode. +enum PartialReduceHashAggregateState { + ReadingInput { + hash_table: AggregateHashTable, + }, + ProducingOutput { + hash_table: AggregateHashTable, + }, + Done, +} + +type PartialReduceHashAggregatePoll = Poll>>; +type PartialReduceHashAggregateStateTransition = ControlFlow< + ( + PartialReduceHashAggregatePoll, + PartialReduceHashAggregateState, + ), + PartialReduceHashAggregateState, +>; + +impl PartialReduceHashAggregateState { + fn hash_table(&self) -> &AggregateHashTable { + match self { + Self::ReadingInput { hash_table } | Self::ProducingOutput { hash_table } => { + hash_table + } + Self::Done => unreachable!("Done state does not hold a hash table"), + } + } + + fn hash_table_mut(&mut self) -> &mut AggregateHashTable { + match self { + Self::ReadingInput { hash_table } | Self::ProducingOutput { hash_table } => { + hash_table + } + Self::Done => unreachable!("Done state does not hold a hash table"), + } + } + + fn into_hash_table(self) -> AggregateHashTable { + match self { + Self::ReadingInput { hash_table } | Self::ProducingOutput { hash_table } => { + hash_table + } + Self::Done => unreachable!("Done state does not hold a hash table"), + } + } + + fn into_producing_output(self) -> Self { + Self::ProducingOutput { + hash_table: self.into_hash_table(), + } + } + + fn into_done(self) -> Self { + Self::Done + } +} + +impl PartialReduceHashAggregateStream { + pub fn new( + agg: &AggregateExec, + context: &Arc, + partition: usize, + ) -> Result { + debug_assert_eq!(agg.mode, super::AggregateMode::PartialReduce); + debug_assert_eq!(agg.input_order_mode, InputOrderMode::Linear); + + let schema = Arc::clone(&agg.schema); + let input = agg.input.execute(partition, Arc::clone(context))?; + let batch_size = context.session_config().batch_size(); + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + // Preserve the existing aggregate metric surface for this plan node. + let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); + + let hash_table = AggregateHashTable::::new( + agg, + partition, + Arc::clone(&schema), + batch_size, + )?; + + let reservation = + MemoryConsumer::new(format!("PartialReduceHashAggregateStream[{partition}]")) + .register(context.memory_pool()); + + Ok(Self { + schema, + input, + baseline_metrics, + reservation, + state: Some(PartialReduceHashAggregateState::ReadingInput { hash_table }), + }) + } + + fn start_output( + &mut self, + hash_table: &mut AggregateHashTable, + ) -> Result<()> { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + hash_table.start_output() + } + + /// Handle ReadingInput state - aggregate partial state batches into the hash table. + /// + /// See comments at `poll_next()` for details. + /// + /// Returns the next operator state with control flow decision. + fn handle_reading_input( + &mut self, + cx: &mut Context<'_>, + mut original_state: PartialReduceHashAggregateState, + ) -> PartialReduceHashAggregateStateTransition { + debug_assert!(matches!( + &original_state, + PartialReduceHashAggregateState::ReadingInput { .. } + )); + debug_assert!(original_state.hash_table().is_building()); + + match self.input.poll_next_unpin(cx) { + Poll::Pending => ControlFlow::Break((Poll::Pending, original_state)), + // Get a new input batch, aggregate it in the hash table + Poll::Ready(Some(Ok(batch))) => { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + let result = original_state.hash_table_mut().aggregate_batch(&batch); + timer.done(); + + if let Err(e) = result { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )); + } + + if let Err(e) = self + .reservation + .try_resize(original_state.hash_table().memory_size()) + { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )); + } + + ControlFlow::Continue(original_state) + } + Poll::Ready(Some(Err(e))) => { + ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state)) + } + // Input ends, move to output state + Poll::Ready(None) => { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + let result = self.start_output(original_state.hash_table_mut()); + timer.done(); + + match result { + Ok(()) => { + ControlFlow::Continue(original_state.into_producing_output()) + } + Err(e) => { + ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state)) + } + } + } + } + } + + /// Handle ProducingOutput state - emit merged partial aggregate state batches. + /// + /// See comments at `poll_next()` for details. + /// + /// Returns the next operator state with control flow decision. + fn handle_producing_output( + &mut self, + mut original_state: PartialReduceHashAggregateState, + ) -> PartialReduceHashAggregateStateTransition { + debug_assert!(matches!( + &original_state, + PartialReduceHashAggregateState::ProducingOutput { .. } + )); + debug_assert!(!original_state.hash_table().is_building()); + + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + let result = original_state.hash_table_mut().next_output_batch(); + timer.done(); + + match result { + Ok(Some(batch)) => { + let _ = self + .reservation + .try_resize(original_state.hash_table().memory_size()); + debug_assert!(batch.num_rows() > 0); + let next_state = if original_state.hash_table().is_done() { + original_state.into_done() + } else { + original_state + }; + + ControlFlow::Break(( + Poll::Ready(Some(Ok(batch.record_output(&self.baseline_metrics)))), + next_state, + )) + } + Ok(None) => { + let _ = self.reservation.try_resize(0); + ControlFlow::Continue(original_state.into_done()) + } + Err(e) => ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state)), + } + } +} + +impl Stream for PartialReduceHashAggregateStream { + type Item = Result; + + /// Entry point for the partial-reduce hash aggregate state machine. + /// + /// See comments in [`PartialReduceHashAggregateStream`] for high-level ideas. + /// + /// State transition graph: + /// + /// ```text + /// (start) + /// -> ReadingInput + /// The stream starts by polling partial-state input and merging those + /// states into the partial-reduce hash table. + /// + /// ReadingInput + /// -> ReadingInput + /// Aggregate one partial-state input batch, update the inner aggregate + /// hash table, and continue with the next input batch. + /// + /// -> ProducingOutput + /// Input was exhausted. Move to the next state to start outputting + /// merged partial aggregate states. + /// + /// ProducingOutput + /// -> ProducingOutput + /// One merged partial-state output batch was yielded; repeat to + /// continue producing output incrementally. + /// + /// -> Done + /// All merged partial-state output was emitted. + /// + /// Done + /// -> (end) + /// ``` + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + let cur_state = self + .state + .take() + .expect("PartialReduceHashAggregateStream state should not be None"); + + let next_state = match cur_state { + state @ PartialReduceHashAggregateState::ReadingInput { .. } => { + self.handle_reading_input(cx, state) + } + state @ PartialReduceHashAggregateState::ProducingOutput { .. } => { + self.handle_producing_output(state) + } + state @ PartialReduceHashAggregateState::Done => { + let _ = self.reservation.try_resize(0); + self.state = Some(state); + return Poll::Ready(None); + } + }; + + match next_state { + ControlFlow::Continue(next_state) => { + self.state = Some(next_state); + continue; + } + ControlFlow::Break((poll, next_state)) => { + self.state = Some(next_state); + return poll; + } + } + } + } +} + +impl RecordBatchStream for PartialReduceHashAggregateStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +}